Merge pull request #6 from kengz/1.0.0
1.0.0 replace TensorTuple with tensor dict
kengz authored Jun 26, 2021
2 parents 0be2905 + 6fccf2a commit 1853726
Showing 15 changed files with 149 additions and 114 deletions.
54 changes: 43 additions & 11 deletions .github/workflows/ci.yml
@@ -4,10 +4,39 @@ on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
    branches: ["**"]

jobs:
  lint:
    runs-on: ubuntu-latest

    steps:
      - name: Check out Git repository
        uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: 3.8

      - uses: liskin/gh-problem-matcher-wrap@v1
        with:
          action: add
          linters: flake8

      - name: Lint with flake8
        run: |
          pip install flake8
          # exit-zero treats all errors as warnings.
          flake8 . --ignore=E501 --count --exit-zero --statistics
      - uses: liskin/gh-problem-matcher-wrap@v1
        with:
          action: remove
          linters: flake8

  build:
    needs: lint
    runs-on: ubuntu-latest

    steps:
@@ -35,16 +64,19 @@ jobs:
          conda info
          conda list
      - name: Setup flake8 annotations
        uses: rbialon/flake8-annotations@v1
      - name: Lint with flake8
        shell: bash -l {0}
        run: |
          pip install flake8
          # exit-zero treats all errors as warnings.
          flake8 . --ignore=E501 --count --exit-zero --statistics
      - uses: liskin/gh-problem-matcher-wrap@v1
        with:
          action: add
          linters: pytest

      - name: Run tests
        shell: bash -l {0}
        run: |
          python setup.py test
        run: python setup.py test | tee pytest-coverage.txt

      - name: Post coverage to PR comment
        uses: coroo/[email protected]

      - uses: liskin/gh-problem-matcher-wrap@v1
        with:
          action: add
          linters: pytest
30 changes: 30 additions & 0 deletions .github/workflows/tag-release.yml
@@ -0,0 +1,30 @@
# tag ref: https://github.com/marketplace/actions/github-tag#bumping
# commit msg format for tag: https://github.com/angular/angular.js/blob/master/DEVELOPERS.md#-git-commit-guidelines
name: Tag and release version

on:
  push:
    branches: [main]

jobs:
  tag_and_release:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2

      - name: Bump version and push tag
        id: tag_version
        uses: mathieudutour/[email protected]
        with:
          github_token: ${{ secrets.GITHUB_TOKEN }}
          tag_prefix: ''

      - name: Create a GitHub release
        uses: actions/create-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          tag_name: ${{ steps.tag_version.outputs.new_tag }}
          release_name: Release ${{ steps.tag_version.outputs.new_tag }}
          body: ${{ steps.tag_version.outputs.changelog }}
10 changes: 4 additions & 6 deletions README.md
@@ -359,10 +359,8 @@ model = torcharc.build(arc)

batch_size = 16
dag_in_shape = arc['dag_in_shape']
data = {'image': torch.rand([batch_size, *dag_in_shape['image']]), 'vector': torch.rand([batch_size, *dag_in_shape['vector']])}
# convert from a dict of Tensors into a TensorTuple - a namedtuple
xs = torcharc.to_namedtuple(data)
# returns TensorTuple if output is multi-model, Tensor otherwise
xs = {'image': torch.rand([batch_size, *dag_in_shape['image']]), 'vector': torch.rand([batch_size, *dag_in_shape['vector']])}
# returns dict if output is multi-model, Tensor otherwise
ys = model(xs)
```

@@ -406,9 +404,9 @@ DAGNet(
</p>
</details>

DAG module accepts a `TensorTuple` (example below) as input, and the module selects its input by matching its own name in the arc and the `in_name`, then carry forward the output together with any unconsumed inputs.
The DAG module accepts a `dict` (example below) as input; each module selects its input by matching its own name in the arc and the `in_name`, then carries the output forward together with any unconsumed inputs.

For example, the input `xs` with keys `image, vector` passes through the first `image` module, and the output becomes `TensorTuple(image=image_module(xs.image), vector=xs.vector)`. This is then passed through the remainder of the modules in the arc as declared.
For example, the input `xs` with keys `image, vector` passes through the first `image` module, and the output becomes `{'image': image_module(xs['image']), 'vector': xs['vector']}`. This is then passed through the remaining modules in the arc as declared.
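
As a minimal sketch of this carry-forward behavior (illustrative only; `image_module`, `vector_module`, and the tensor shapes below are hypothetical stand-ins rather than modules built from this arc):

```python
import torch
from torch import nn

# hypothetical stand-ins for modules declared in the arc
image_module = nn.Conv2d(3, 16, kernel_size=3)
vector_module = nn.Linear(8, 16)

xs = {'image': torch.rand(16, 3, 20, 20), 'vector': torch.rand(16, 8)}

# the 'image' module consumes xs['image']; 'vector' is carried forward unconsumed
xs = {'image': image_module(xs['image']), 'vector': xs['vector']}
# the 'vector' module then consumes xs['vector'] in the same way
xs = {'image': xs['image'], 'vector': vector_module(xs['vector'])}
```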

## Development

5 changes: 1 addition & 4 deletions setup.py
@@ -11,9 +11,6 @@
    '--log-file-level=INFO',
    '--no-flaky-report',
    '--timeout=300',
    '--cov-report=html',
    '--cov-report=term',
    '--cov-report=xml',
    '--cov=torcharc',
    'test',
]
@@ -35,7 +32,7 @@ def run_tests(self):

setup(
    name='torcharc',
    version='0.0.6',
    version='1.0.0',
    description='Build PyTorch networks by specifying architectures.',
    long_description='https://github.com/kengz/torcharc',
    keywords='torcharc',
18 changes: 7 additions & 11 deletions test/module/test_dag.py
@@ -1,6 +1,5 @@
from torcharc import arc_ref, net_util
from torcharc.module import dag
import pydash as ps
import torch


@@ -37,7 +36,7 @@ def test_dag_reusefork():
    xs = net_util.get_rand_tensor(in_shapes)
    model = dag.DAGNet(arc)
    ys = model(xs)
    assert ps.is_tuple(ys)
    assert isinstance(ys, dict)


def test_dag_splitfork():
@@ -46,17 +45,16 @@ def test_dag_splitfork():
    xs = net_util.get_rand_tensor(in_shapes)
    model = dag.DAGNet(arc)
    ys = model(xs)
    assert ps.is_tuple(ys)
    assert isinstance(ys, dict)


def test_dag_merge_fork():
    arc = arc_ref.REF_ARCS['merge_fork']
    in_shapes = arc['dag_in_shape']
    xs = net_util.get_rand_tensor(in_shapes)
    model = dag.DAGNet(arc)
    ys = model(xs._asdict())  # test dict input for tracing
    ys = model(xs)
    assert ps.is_tuple(ys)
    assert isinstance(ys, dict)


def test_dag_fork_merge():
@@ -74,7 +72,7 @@ def test_dag_reuse_fork_forward():
    xs = net_util.get_rand_tensor(in_shapes)
    model = dag.DAGNet(arc)
    ys = model(xs)
    assert ps.is_tuple(ys)
    assert isinstance(ys, dict)


def test_dag_split_fork_forward():
@@ -83,24 +81,22 @@ def test_dag_split_fork_forward():
    xs = net_util.get_rand_tensor(in_shapes)
    model = dag.DAGNet(arc)
    ys = model(xs)
    assert ps.is_tuple(ys)
    assert isinstance(ys, dict)


def test_dag_merge_forward_split():
    arc = arc_ref.REF_ARCS['merge_forward_split']
    in_shapes = arc['dag_in_shape']
    xs = net_util.get_rand_tensor(in_shapes)
    model = dag.DAGNet(arc)
    ys = model(xs._asdict())  # test dict input for tracing
    ys = model(xs)
    assert ps.is_tuple(ys)
    assert isinstance(ys, dict)


def test_dag_hydra():
    arc = arc_ref.REF_ARCS['hydra']
    in_shapes = arc['dag_in_shape']
    xs = net_util.get_rand_tensor(in_shapes)
    model = dag.DAGNet(arc)
    ys = model(xs._asdict())  # test dict input for tracing
    ys = model(xs)
    assert ps.is_tuple(ys)
    assert isinstance(ys, dict)
9 changes: 4 additions & 5 deletions test/module/test_fork.py
@@ -1,5 +1,4 @@
from torcharc.module.fork import Fork, ReuseFork, SplitFork
import pydash as ps
import pytest
import torch

@@ -14,8 +13,8 @@ def test_reuse_fork(names, x):
    fork = ReuseFork(names)
    assert isinstance(fork, Fork)
    ys = fork(x)
    assert ps.is_tuple(ys)
    assert ys._fields == tuple(names)
    assert isinstance(ys, dict)
    assert list(ys) == names


@pytest.mark.parametrize('shapes,x', [
@@ -28,5 +27,5 @@ def test_split_fork(shapes, x):
    fork = SplitFork(shapes)
    assert isinstance(fork, Fork)
    ys = fork(x)
    assert ps.is_tuple(ys)
    assert ys._fields == tuple(shapes.keys())
    assert isinstance(ys, dict)
    assert ys.keys() == shapes.keys()
4 changes: 2 additions & 2 deletions test/module/test_merge.py
@@ -16,7 +16,7 @@
def test_concat_merge(xs, out_shape):
    merge = ConcatMerge()
    assert isinstance(merge, Merge)
    y = merge(net_util.to_namedtuple(xs))
    y = merge(xs)
    assert y.shape == torch.Size(out_shape)


@@ -56,5 +56,5 @@ def test_film_affine_transform(feature):
def test_film_merge(names, shapes, xs):
    merge = FiLMMerge(names, shapes)
    assert isinstance(merge, Merge)
    y = merge(net_util.to_namedtuple(xs))
    y = merge(xs)
    assert y.shape == xs[names['feature']].shape
15 changes: 7 additions & 8 deletions test/test_module_builder.py
@@ -1,7 +1,6 @@
from fixture.net import CONV1D_ARC, CONV2D_ARC, CONV3D_ARC, LINEAR_ARC
from torcharc import module_builder, net_util
from torch import nn
import pydash as ps
import pytest
import torch

@@ -27,7 +26,7 @@
])
def test_get_init_fn(init, activation):
    init_fn = module_builder.get_init_fn(init, activation)
    assert ps.is_function(init_fn)
    assert callable(init_fn)


@pytest.mark.parametrize('arc,nn_class', [
@@ -248,11 +247,11 @@ def test_carry_forward_tensor(arc, xs):
        net_util.get_rand_tensor({'vector': [LINEAR_ARC['in_features']], 'image': CONV2D_ARC['in_shape']}),
    )
])
def test_carry_forward_tensor_tuple_default(arc, xs):
def test_carry_forward_dict_default(arc, xs):
    module = module_builder.build_module(arc)
    assert ps.is_tuple(xs)
    assert isinstance(xs, dict)
    ys = module_builder.carry_forward(module, xs)
    assert ps.is_tuple(ys)
    assert isinstance(ys, dict)


@pytest.mark.parametrize('xs', [
@@ -281,8 +280,8 @@ def test_carry_forward_tensor_tuple_default(arc, xs):
        ['image', 'vector'],
    )
])
def test_carry_forward_tensor_tuple(arc, xs, in_names):
def test_carry_forward_dict(arc, xs, in_names):
    module = module_builder.build_module(arc)
    assert ps.is_tuple(xs)
    assert isinstance(xs, dict)
    ys = module_builder.carry_forward(module, xs, in_names)
    assert isinstance(ys, (torch.Tensor, tuple))
    assert isinstance(ys, (torch.Tensor, dict))
5 changes: 2 additions & 3 deletions test/test_net_util.py
@@ -1,7 +1,6 @@
from fixture.net import CONV1D_ARC, CONV2D_ARC, CONV3D_ARC, LINEAR_ARC
from torcharc import module_builder, net_util
from torch import nn
import pydash as ps
import pytest
import torch

@@ -39,7 +38,7 @@ def test_get_rand_tensor(shape, batch_size, tensor_shape):
])
def test_get_rand_tensor_dict(shapes, batch_size, tensor_shapes):
    xs = net_util.get_rand_tensor(shapes, batch_size)
    assert ps.is_tuple(xs)
    assert isinstance(xs, dict)
    for name, tensor_shape in tensor_shapes.items():
        x = getattr(xs, name)
        x = xs[name]
        assert list(x.shape) == tensor_shape
1 change: 0 additions & 1 deletion torcharc/__init__.py
@@ -1,5 +1,4 @@
from torcharc import module_builder
from torcharc.net_util import to_namedtuple
from torcharc.module import dag
from torch import nn

8 changes: 3 additions & 5 deletions torcharc/module/dag.py
@@ -1,7 +1,7 @@
# build DAG of nn modules
from torcharc import module_builder, net_util
from torch import nn
from typing import NamedTuple, Union
from typing import Union
import pydash as ps
import torch

@@ -30,10 +30,8 @@ def __init__(self, arc: dict) -> None:
            xs = module_builder.carry_forward(module, xs, m_arc.get('in_names'))
            self.module_dict.update({name: module})

    def forward(self, xs: Union[torch.Tensor, NamedTuple]) -> Union[torch.Tensor, NamedTuple]:
        # jit.trace will spread args on encountering a namedtuple, thus xs needs to be passed as dict then converted back into namedtuple
        if ps.is_dict(xs):  # guard to convert dict xs into namedtuple
            xs = net_util.to_namedtuple(xs)
    def forward(self, xs: Union[torch.Tensor, dict]) -> Union[torch.Tensor, dict]:
        # safe for jit.trace
        for name, module in self.module_dict.items():
            m_arc = self.arc[name]
            xs = module_builder.carry_forward(module, xs, m_arc.get('in_names'))
17 changes: 7 additions & 10 deletions torcharc/module/fork.py
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from collections import namedtuple
from torch import nn
from typing import Dict, List, NamedTuple
from typing import Dict, List
import pydash as ps
import torch

@@ -10,21 +9,20 @@ class Fork(ABC, nn.Module):
    '''A Fork module forks one tensor into a dict of multiple tensors.'''

    @abstractmethod
    def forward(self, x: torch.Tensor) -> NamedTuple:  # pragma: no cover
    def forward(self, x: torch.Tensor) -> dict:  # pragma: no cover
        raise NotImplementedError


class ReuseFork(Fork):
    '''Fork layer to reuse a tensor multiple times via ref in TensorTuple'''
    '''Fork layer to reuse a tensor multiple times via ref in dict'''

    def __init__(self, names: List[str]) -> None:
        super().__init__()
        self.names = names
        self.num_reuse = len(names)
        self.TensorTuple = namedtuple('TensorTuple', names)

    def forward(self, x: torch.Tensor) -> NamedTuple:
        return self.TensorTuple(*[x] * self.num_reuse)
    def forward(self, x: torch.Tensor) -> dict:
        return dict(zip(self.names, [x] * self.num_reuse))


class SplitFork(Fork):
@@ -34,7 +32,6 @@ def __init__(self, shapes: Dict[str, List[int]]) -> None:
        super().__init__()
        self.shapes = shapes
        self.split_size = ps.flatten(self.shapes.values())
        self.TensorTuple = namedtuple('TensorTuple', shapes.keys())

    def forward(self, x: torch.Tensor) -> NamedTuple:
        return self.TensorTuple(*x.split(self.split_size, dim=1))
    def forward(self, x: torch.Tensor) -> dict:
        return dict(zip(self.shapes, x.split(self.split_size, dim=1)))
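
For illustration (not part of the diff): a minimal usage sketch of the reworked fork layers above; the names `'a'`, `'b'` and the tensor shapes are made-up example values.

```python
import torch

from torcharc.module.fork import ReuseFork, SplitFork

x = torch.rand(4, 5)  # batch of 4, feature dim 5

# ReuseFork returns the same tensor under each requested name
reuse = ReuseFork(['a', 'b'])
ys = reuse(x)
assert ys['a'] is ys['b']

# SplitFork splits dim 1 by the declared shapes: 5 -> 2 + 3
split = SplitFork({'a': [2], 'b': [3]})
ys = split(x)
assert ys['a'].shape == (4, 2) and ys['b'].shape == (4, 3)
```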
