From 16344289d2038d605747b095a690d50cfa64b3e9 Mon Sep 17 00:00:00 2001 From: kengz Date: Sat, 26 Jun 2021 11:11:53 -0700 Subject: [PATCH 1/2] feat(dict): replace TensorTuple with tensor dict - replace all namedtuple TensorTuple with safer dict of tensors to make module pickle-safe - use native Python functions instead of pydash to improve performance by reducing executions in Python (all in torch) BREAKING CHANGE: TensorTuple is replaced with dict of tensors --- README.md | 10 +++---- setup.py | 2 +- test/module/test_dag.py | 18 +++++-------- test/module/test_fork.py | 9 +++---- test/module/test_merge.py | 4 +-- test/test_module_builder.py | 15 +++++------ test/test_net_util.py | 5 ++-- torcharc/__init__.py | 1 - torcharc/module/dag.py | 8 +++--- torcharc/module/fork.py | 17 +++++------- torcharc/module/merge.py | 14 +++++----- torcharc/module_builder.py | 53 ++++++++++++++++++------------------- torcharc/net_util.py | 20 +++++--------- 13 files changed, 76 insertions(+), 100 deletions(-) diff --git a/README.md b/README.md index c7492d8..9d58982 100644 --- a/README.md +++ b/README.md @@ -359,10 +359,8 @@ model = torcharc.build(arc) batch_size = 16 dag_in_shape = arc['dag_in_shape'] -data = {'image': torch.rand([batch_size, *dag_in_shape['image']]), 'vector': torch.rand([batch_size, *dag_in_shape['vector']])} -# convert from a dict of Tensors into a TensorTuple - a namedtuple -xs = torcharc.to_namedtuple(data) -# returns TensorTuple if output is multi-model, Tensor otherwise +xs = {'image': torch.rand([batch_size, *dag_in_shape['image']]), 'vector': torch.rand([batch_size, *dag_in_shape['vector']])} +# returns dict if output is multi-model, Tensor otherwise ys = model(xs) ``` @@ -406,9 +404,9 @@ DAGNet(

-DAG module accepts a `TensorTuple` (example below) as input, and the module selects its input by matching its own name in the arc and the `in_name`, then carry forward the output together with any unconsumed inputs. +DAG module accepts a `dict` (example below) as input, and the module selects its input by matching its own name in the arc and the `in_name`, then carry forward the output together with any unconsumed inputs. -For example, the input `xs` with keys `image, vector` passes through the first `image` module, and the output becomes `TensorTuple(image=image_module(xs.image), vector=xs.vector)`. This is then passed through the remainder of the modules in the arc as declared. +For example, the input `xs` with keys `image, vector` passes through the first `image` module, and the output becomes `{'image': image_module(xs.image), 'vector': xs.vector}`. This is then passed through the remainder of the modules in the arc as declared. ## Development diff --git a/setup.py b/setup.py index 9258551..6c6b74b 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ def run_tests(self): setup( name='torcharc', - version='0.0.6', + version='1.0.0', description='Build PyTorch networks by specifying architectures.', long_description='https://github.com/kengz/torcharc', keywords='torcharc', diff --git a/test/module/test_dag.py b/test/module/test_dag.py index f71576d..596729e 100644 --- a/test/module/test_dag.py +++ b/test/module/test_dag.py @@ -1,6 +1,5 @@ from torcharc import arc_ref, net_util from torcharc.module import dag -import pydash as ps import torch @@ -37,7 +36,7 @@ def test_dag_reusefork(): xs = net_util.get_rand_tensor(in_shapes) model = dag.DAGNet(arc) ys = model(xs) - assert ps.is_tuple(ys) + assert isinstance(ys, dict) def test_dag_splitfork(): @@ -46,7 +45,7 @@ def test_dag_splitfork(): xs = net_util.get_rand_tensor(in_shapes) model = dag.DAGNet(arc) ys = model(xs) - assert ps.is_tuple(ys) + assert isinstance(ys, dict) def test_dag_merge_fork(): @@ -54,9 +53,8 @@ def test_dag_merge_fork(): in_shapes = arc['dag_in_shape'] xs = net_util.get_rand_tensor(in_shapes) model = dag.DAGNet(arc) - ys = model(xs._asdict()) # test dict input for tracing ys = model(xs) - assert ps.is_tuple(ys) + assert isinstance(ys, dict) def test_dag_fork_merge(): @@ -74,7 +72,7 @@ def test_dag_reuse_fork_forward(): xs = net_util.get_rand_tensor(in_shapes) model = dag.DAGNet(arc) ys = model(xs) - assert ps.is_tuple(ys) + assert isinstance(ys, dict) def test_dag_split_fork_forward(): @@ -83,7 +81,7 @@ def test_dag_split_fork_forward(): xs = net_util.get_rand_tensor(in_shapes) model = dag.DAGNet(arc) ys = model(xs) - assert ps.is_tuple(ys) + assert isinstance(ys, dict) def test_dag_merge_forward_split(): @@ -91,9 +89,8 @@ def test_dag_merge_forward_split(): in_shapes = arc['dag_in_shape'] xs = net_util.get_rand_tensor(in_shapes) model = dag.DAGNet(arc) - ys = model(xs._asdict()) # test dict input for tracing ys = model(xs) - assert ps.is_tuple(ys) + assert isinstance(ys, dict) def test_dag_hydra(): @@ -101,6 +98,5 @@ def test_dag_hydra(): in_shapes = arc['dag_in_shape'] xs = net_util.get_rand_tensor(in_shapes) model = dag.DAGNet(arc) - ys = model(xs._asdict()) # test dict input for tracing ys = model(xs) - assert ps.is_tuple(ys) + assert isinstance(ys, dict) diff --git a/test/module/test_fork.py b/test/module/test_fork.py index 3a63317..f92c2df 100644 --- a/test/module/test_fork.py +++ b/test/module/test_fork.py @@ -1,5 +1,4 @@ from torcharc.module.fork import Fork, ReuseFork, SplitFork -import pydash as ps import pytest import torch @@ -14,8 +13,8 @@ def test_reuse_fork(names, x): fork = ReuseFork(names) assert isinstance(fork, Fork) ys = fork(x) - assert ps.is_tuple(ys) - assert ys._fields == tuple(names) + assert isinstance(ys, dict) + assert list(ys) == names @pytest.mark.parametrize('shapes,x', [ @@ -28,5 +27,5 @@ def test_split_fork(shapes, x): fork = SplitFork(shapes) assert isinstance(fork, Fork) ys = fork(x) - assert ps.is_tuple(ys) - assert ys._fields == tuple(shapes.keys()) + assert isinstance(ys, dict) + assert ys.keys() == shapes.keys() diff --git a/test/module/test_merge.py b/test/module/test_merge.py index a8b99d1..767543b 100644 --- a/test/module/test_merge.py +++ b/test/module/test_merge.py @@ -16,7 +16,7 @@ def test_concat_merge(xs, out_shape): merge = ConcatMerge() assert isinstance(merge, Merge) - y = merge(net_util.to_namedtuple(xs)) + y = merge(xs) assert y.shape == torch.Size(out_shape) @@ -56,5 +56,5 @@ def test_film_affine_transform(feature): def test_film_merge(names, shapes, xs): merge = FiLMMerge(names, shapes) assert isinstance(merge, Merge) - y = merge(net_util.to_namedtuple(xs)) + y = merge(xs) assert y.shape == xs[names['feature']].shape diff --git a/test/test_module_builder.py b/test/test_module_builder.py index c92d7fd..3ad10aa 100644 --- a/test/test_module_builder.py +++ b/test/test_module_builder.py @@ -1,7 +1,6 @@ from fixture.net import CONV1D_ARC, CONV2D_ARC, CONV3D_ARC, LINEAR_ARC from torcharc import module_builder, net_util from torch import nn -import pydash as ps import pytest import torch @@ -27,7 +26,7 @@ ]) def test_get_init_fn(init, activation): init_fn = module_builder.get_init_fn(init, activation) - assert ps.is_function(init_fn) + assert callable(init_fn) @pytest.mark.parametrize('arc,nn_class', [ @@ -248,11 +247,11 @@ def test_carry_forward_tensor(arc, xs): net_util.get_rand_tensor({'vector': [LINEAR_ARC['in_features']], 'image': CONV2D_ARC['in_shape']}), ) ]) -def test_carry_forward_tensor_tuple_default(arc, xs): +def test_carry_forward_dict_default(arc, xs): module = module_builder.build_module(arc) - assert ps.is_tuple(xs) + assert isinstance(xs, dict) ys = module_builder.carry_forward(module, xs) - assert ps.is_tuple(ys) + assert isinstance(ys, dict) @pytest.mark.parametrize('xs', [ @@ -281,8 +280,8 @@ def test_carry_forward_tensor_tuple_default(arc, xs): ['image', 'vector'], ) ]) -def test_carry_forward_tensor_tuple(arc, xs, in_names): +def test_carry_forward_dict(arc, xs, in_names): module = module_builder.build_module(arc) - assert ps.is_tuple(xs) + assert isinstance(xs, dict) ys = module_builder.carry_forward(module, xs, in_names) - assert isinstance(ys, (torch.Tensor, tuple)) + assert isinstance(ys, (torch.Tensor, dict)) diff --git a/test/test_net_util.py b/test/test_net_util.py index d316253..c1dcea9 100644 --- a/test/test_net_util.py +++ b/test/test_net_util.py @@ -1,7 +1,6 @@ from fixture.net import CONV1D_ARC, CONV2D_ARC, CONV3D_ARC, LINEAR_ARC from torcharc import module_builder, net_util from torch import nn -import pydash as ps import pytest import torch @@ -39,7 +38,7 @@ def test_get_rand_tensor(shape, batch_size, tensor_shape): ]) def test_get_rand_tensor_dict(shapes, batch_size, tensor_shapes): xs = net_util.get_rand_tensor(shapes, batch_size) - assert ps.is_tuple(xs) + assert isinstance(xs, dict) for name, tensor_shape in tensor_shapes.items(): - x = getattr(xs, name) + x = xs[name] assert list(x.shape) == tensor_shape diff --git a/torcharc/__init__.py b/torcharc/__init__.py index 055cc96..6fd4fd9 100644 --- a/torcharc/__init__.py +++ b/torcharc/__init__.py @@ -1,5 +1,4 @@ from torcharc import module_builder -from torcharc.net_util import to_namedtuple from torcharc.module import dag from torch import nn diff --git a/torcharc/module/dag.py b/torcharc/module/dag.py index 225aa29..28791c6 100644 --- a/torcharc/module/dag.py +++ b/torcharc/module/dag.py @@ -1,7 +1,7 @@ # build DAG of nn modules from torcharc import module_builder, net_util from torch import nn -from typing import NamedTuple, Union +from typing import Union import pydash as ps import torch @@ -30,10 +30,8 @@ def __init__(self, arc: dict) -> None: xs = module_builder.carry_forward(module, xs, m_arc.get('in_names')) self.module_dict.update({name: module}) - def forward(self, xs: Union[torch.Tensor, NamedTuple]) -> Union[torch.Tensor, NamedTuple]: - # jit.trace will spread args on encountering a namedtuple, thus xs needs to be passed as dict then converted back into namedtuple - if ps.is_dict(xs): # guard to convert dict xs into namedtuple - xs = net_util.to_namedtuple(xs) + def forward(self, xs: Union[torch.Tensor, dict]) -> Union[torch.Tensor, dict]: + # safe for jit.trace for name, module in self.module_dict.items(): m_arc = self.arc[name] xs = module_builder.carry_forward(module, xs, m_arc.get('in_names')) diff --git a/torcharc/module/fork.py b/torcharc/module/fork.py index 5c972db..284e942 100644 --- a/torcharc/module/fork.py +++ b/torcharc/module/fork.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod -from collections import namedtuple from torch import nn -from typing import Dict, List, NamedTuple +from typing import Dict, List import pydash as ps import torch @@ -10,21 +9,20 @@ class Fork(ABC, nn.Module): '''A Fork module forks one tensor into a dict of multiple tensors.''' @abstractmethod - def forward(self, x: torch.Tensor) -> NamedTuple: # pragma: no cover + def forward(self, x: torch.Tensor) -> dict: # pragma: no cover raise NotImplementedError class ReuseFork(Fork): - '''Fork layer to reuse a tensor multiple times via ref in TensorTuple''' + '''Fork layer to reuse a tensor multiple times via ref in dict''' def __init__(self, names: List[str]) -> None: super().__init__() self.names = names self.num_reuse = len(names) - self.TensorTuple = namedtuple('TensorTuple', names) - def forward(self, x: torch.Tensor) -> NamedTuple: - return self.TensorTuple(*[x] * self.num_reuse) + def forward(self, x: torch.Tensor) -> dict: + return dict(zip(self.names, [x] * self.num_reuse)) class SplitFork(Fork): @@ -34,7 +32,6 @@ def __init__(self, shapes: Dict[str, List[int]]) -> None: super().__init__() self.shapes = shapes self.split_size = ps.flatten(self.shapes.values()) - self.TensorTuple = namedtuple('TensorTuple', shapes.keys()) - def forward(self, x: torch.Tensor) -> NamedTuple: - return self.TensorTuple(*x.split(self.split_size, dim=1)) + def forward(self, x: torch.Tensor) -> dict: + return dict(zip(self.shapes, x.split(self.split_size, dim=1))) diff --git a/torcharc/module/merge.py b/torcharc/module/merge.py index b639f63..106647c 100644 --- a/torcharc/module/merge.py +++ b/torcharc/module/merge.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from torch import nn -from typing import Dict, List, NamedTuple +from typing import Dict, List import torch @@ -8,15 +8,15 @@ class Merge(ABC, nn.Module): '''A Merge module merges a dict of tensors into one tensor''' @abstractmethod - def forward(self, xs: NamedTuple) -> torch.Tensor: # pragma: no cover + def forward(self, xs: dict) -> torch.Tensor: # pragma: no cover raise NotImplementedError class ConcatMerge(Merge): '''Merge layer to merge a dict of tensors by concatenating along dim=1. Reverse of Split''' - def forward(self, xs: NamedTuple) -> torch.Tensor: - return torch.cat(xs, dim=1) + def forward(self, xs: dict) -> torch.Tensor: + return torch.cat(list(xs.values()), dim=1) class FiLMMerge(Merge): @@ -43,10 +43,10 @@ def affine_transform(cls, feature: torch.Tensor, conditioner_scale: torch.Tensor view_shape = list(conditioner_scale.shape) + [1] * (feature.dim() - conditioner_scale.dim()) return conditioner_scale.view(*view_shape) * feature + conditioner_shift.view(*view_shape) - def forward(self, xs: NamedTuple) -> torch.Tensor: + def forward(self, xs: dict) -> torch.Tensor: '''Apply FiLM affine transform on feature using conditioner''' - feature = getattr(xs, self.feature_name) - conditioner = getattr(xs, self.conditioner_name) + feature = xs[self.feature_name] + conditioner = xs[self.conditioner_name] conditioner_scale = self.conditioner_scale(conditioner) conditioner_shift = self.conditioner_shift(conditioner) return self.affine_transform(feature, conditioner_scale, conditioner_shift) diff --git a/torcharc/module_builder.py b/torcharc/module_builder.py index dee25c6..84887b2 100644 --- a/torcharc/module_builder.py +++ b/torcharc/module_builder.py @@ -1,10 +1,9 @@ # build neural networks modularly from torch import nn -from torcharc import net_util from torcharc import optim from torcharc.module import fork, merge, sequential from torcharc.module.transformer import pytorch_tst, tst -from typing import Callable, List, Optional, NamedTuple, Union +from typing import Callable, List, Optional, Union import inspect import pydash as ps import torch @@ -29,11 +28,11 @@ def get_init_fn(init: Union[str, dict], activation: Optional[str] = None) -> Cal def init_fn(module: nn.Module) -> None: if init is None: return - elif ps.is_string(init): + elif isinstance(init, str): init_type = init init_kwargs = {} else: - assert ps.is_dict(init) + assert isinstance(init, dict) init_type = init['type'] init_kwargs = ps.omit(init, 'type') fn = getattr(nn.init, init_type) @@ -70,60 +69,60 @@ def build_module(arc: dict) -> nn.Module: return module -def infer_in_shape(arc: dict, xs: Union[torch.Tensor, NamedTuple]) -> None: +def infer_in_shape(arc: dict, xs: Union[torch.Tensor, dict]) -> None: '''Infer the input shape(s) for arc depending on its type and the input tensor. This updates the arc with the appropriate key.''' nn_type = arc['type'] if nn_type == 'Linear': - if ps.is_tuple(xs): - in_names = arc.get('in_names', xs._fields[:1]) - xs = getattr(xs, in_names[0]) + if isinstance(xs, dict): + in_names = arc.get('in_names', list(xs)[:1]) + xs = xs[in_names[0]] assert isinstance(xs, torch.Tensor) assert len(xs.shape) == 2, f'xs shape {xs.shape} is not meant for {nn_type} layer' in_features = xs.shape[1] arc.update(in_features=in_features) elif nn_type.startswith('Conv') or nn_type == 'transformer': - if ps.is_tuple(xs): - in_names = arc.get('in_names', xs._fields[:1]) - xs = getattr(xs, in_names[0]) + if isinstance(xs, dict): + in_names = arc.get('in_names', list(xs)[:1]) + xs = xs[in_names[0]] assert isinstance(xs, torch.Tensor) assert len(xs.shape) >= 2, f'xs shape {xs.shape} is not meant for {nn_type} layer' in_shape = list(xs.shape)[1:] arc.update(in_shape=in_shape) elif nn_type == 'FiLMMerge': - assert ps.is_tuple(xs) + assert isinstance(xs, dict) assert len(arc['in_names']) == 2, 'FiLMMerge in_names should only specify 2 keys for feature and conditioner' - shapes = {name: list(x.shape)[1:] for name, x in xs._asdict().items() if name in arc['in_names']} + shapes = {name: list(x.shape)[1:] for name, x in xs.items() if name in arc['in_names']} arc.update(shapes=shapes) else: pass -def carry_forward(module: nn.Module, xs: Union[torch.Tensor, NamedTuple], in_names: Optional[List[str]] = None) -> Union[torch.Tensor, NamedTuple]: +def carry_forward(module: nn.Module, xs: Union[torch.Tensor, dict], in_names: Optional[List[str]] = None) -> Union[torch.Tensor, dict]: ''' - Main method to call module.forward and handle tensor and namedtuple input/output + Main method to call module.forward and handle tensor and dict input/output If xs and ys are tensors, forward as usual - If xs or ys is namedtuple, then arc.in_names must specify the inputs names to be used in forward, and any unused names will be carried with the output, which will be namedtuple. + If xs or ys is dict, then arc.in_names must specify the inputs names to be used in forward, and any unused names will be carried with the output, which will be dict. ''' - if ps.is_tuple(xs): + if isinstance(xs, dict): if in_names is None: # use the first by default - in_names = xs._fields[:1] + in_names = list(xs)[:1] if len(in_names) == 1: # single input is tensor - m_xs = getattr(xs, in_names[0]) - else: # multi input is namedtuple of tensors - m_xs = net_util.to_namedtuple({name: getattr(xs, name) for name in in_names}) + m_xs = xs[in_names[0]] + else: # multi input is dict of tensors + m_xs = {name: xs[name] for name in in_names} ys = module(m_xs) - # any unused_xs must be carried with the output as namedtuple - d_xs = xs._asdict() + # any unused_xs must be carried with the output as dict + d_xs = xs unused_d_xs = ps.omit(d_xs, in_names) if unused_d_xs: - if ps.is_tuple(ys): - d_ys = {**ys._asdict(), **unused_d_xs} - else: # when formed as namedtuple, single output will use the first of in_names + if isinstance(ys, dict): + d_ys = {**ys, **unused_d_xs} + else: # when formed as dict, single output will use the first of in_names d_ys = {**{in_names[0]: ys}, **unused_d_xs} - ys = net_util.to_namedtuple(d_ys) + ys = d_ys else: ys = module(xs) return ys diff --git a/torcharc/net_util.py b/torcharc/net_util.py index f4adb77..e312fa1 100644 --- a/torcharc/net_util.py +++ b/torcharc/net_util.py @@ -1,7 +1,5 @@ -from collections import namedtuple from torch import nn -from typing import Dict, List, NamedTuple, Union -import pydash as ps +from typing import Dict, List, Union import torch @@ -18,20 +16,14 @@ def get_layer_names(nn_layers: List[nn.Module]) -> List[str]: return [nn_layer._get_name() for nn_layer in nn_layers] -def _get_rand_tensor(shape: Union[list, tuple], batch_size: int = 4) -> torch.Tensor: +def _get_rand_tensor(shape: Union[list, dict], batch_size: int = 4) -> torch.Tensor: '''Get a random tensor given a shape and a batch size''' return torch.rand([batch_size] + list(shape)) -def get_rand_tensor(shapes: Union[List[int], Dict[str, list]], batch_size: int = 4) -> Union[torch.Tensor, NamedTuple]: - '''Get a random tensor tuple with default batch size for a dict of shapes''' - if ps.is_dict(shapes): - TensorTuple = namedtuple('TensorTuple', shapes.keys()) - return TensorTuple(*[_get_rand_tensor(shape, batch_size) for shape in shapes.values()]) +def get_rand_tensor(shapes: Union[List[int], Dict[str, list]], batch_size: int = 4) -> Union[torch.Tensor, dict]: + '''Get a random tensor dict with default batch size for a dict of shapes''' + if isinstance(shapes, dict): + return {k: _get_rand_tensor(shape, batch_size) for k, shape in shapes.items()} else: return _get_rand_tensor(shapes, batch_size) - - -def to_namedtuple(data: dict, name='NamedTensor') -> NamedTuple: - '''Convert a dictionary to namedtuple.''' - return namedtuple(name, data)(**data) From 6fccf2a2f3895e823b27dcdd18f163bf3bdf149d Mon Sep 17 00:00:00 2001 From: kengz Date: Sat, 26 Jun 2021 11:19:24 -0700 Subject: [PATCH 2/2] chore(ci): update github actions CI and tag-release --- .github/workflows/ci.yml | 54 ++++++++++++++++++++++++------- .github/workflows/tag-release.yml | 30 +++++++++++++++++ setup.py | 3 -- 3 files changed, 73 insertions(+), 14 deletions(-) create mode 100644 .github/workflows/tag-release.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5eb1080..645e497 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,10 +4,39 @@ on: push: branches: [main] pull_request: - branches: [main] + branches: ["**"] jobs: + lint: + runs-on: ubuntu-latest + + steps: + - name: Check out Git repository + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - uses: liskin/gh-problem-matcher-wrap@v1 + with: + action: add + linters: flake8 + + - name: Lint with flake8 + run: | + pip install flake8 + # exit-zero treats all errors as warnings. + flake8 . --ignore=E501 --count --exit-zero --statistics + + - uses: liskin/gh-problem-matcher-wrap@v1 + with: + action: remove + linters: flake8 + build: + needs: lint runs-on: ubuntu-latest steps: @@ -35,16 +64,19 @@ jobs: conda info conda list - - name: Setup flake8 annotations - uses: rbialon/flake8-annotations@v1 - - name: Lint with flake8 - shell: bash -l {0} - run: | - pip install flake8 - # exit-zero treats all errors as warnings. - flake8 . --ignore=E501 --count --exit-zero --statistics + - uses: liskin/gh-problem-matcher-wrap@v1 + with: + action: add + linters: pytest - name: Run tests shell: bash -l {0} - run: | - python setup.py test + run: python setup.py test | tee pytest-coverage.txt + + - name: Post coverage to PR comment + uses: coroo/pytest-coverage-commentator@v1.0.2 + + - uses: liskin/gh-problem-matcher-wrap@v1 + with: + action: add + linters: pytest diff --git a/.github/workflows/tag-release.yml b/.github/workflows/tag-release.yml new file mode 100644 index 0000000..b187658 --- /dev/null +++ b/.github/workflows/tag-release.yml @@ -0,0 +1,30 @@ +# tag ref: https://github.com/marketplace/actions/github-tag#bumping +# commit msg format for tag: https://github.com/angular/angular.js/blob/master/DEVELOPERS.md#-git-commit-guidelines +name: Tag and release version + +on: + push: + branches: [main] + +jobs: + tag_and_release: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Bump version and push tag + id: tag_version + uses: mathieudutour/github-tag-action@v5.1 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + tag_prefix: '' + + - name: Create a GitHub release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ steps.tag_version.outputs.new_tag }} + release_name: Release ${{ steps.tag_version.outputs.new_tag }} + body: ${{ steps.tag_version.outputs.changelog }} diff --git a/setup.py b/setup.py index 6c6b74b..5a242b8 100644 --- a/setup.py +++ b/setup.py @@ -11,9 +11,6 @@ '--log-file-level=INFO', '--no-flaky-report', '--timeout=300', - '--cov-report=html', - '--cov-report=term', - '--cov-report=xml', '--cov=torcharc', 'test', ]