-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #30 from f-dangel/development
Prepare initial release
- Loading branch information
Showing
69 changed files
with
5,005 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,4 @@ channels: | |
dependencies: | ||
- python=3.8.10 | ||
- pip: | ||
- -e . | ||
- -e .[lint] | ||
- -e .[test] | ||
- -e .[lint,test,doc] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,5 @@ | |
*.egg-info | ||
.eggs | ||
/.coverage | ||
**.DS_Store | ||
site |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Read the Docs configuration file for MkDocs projects | ||
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details | ||
|
||
# Required | ||
version: 2 | ||
|
||
# Set the version of Python and other tools you might need | ||
build: | ||
os: ubuntu-22.04 | ||
tools: | ||
python: "3.8" | ||
|
||
mkdocs: | ||
configuration: mkdocs.yml | ||
|
||
# Optionally declare the Python requirements required to build your docs | ||
python: | ||
install: | ||
- method: pip | ||
path: . | ||
extra_requirements: | ||
- doc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2023 Felix Dangel | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# <img alt="Einconv:" src="./docs/logo.png" height="90"> Convolutions Through the Lens of Tensor Networks | ||
|
||
This package offers `einsum`-based implementations of convolutions and related | ||
operations in PyTorch. | ||
|
||
Its name is inspired by [this](https://github.com/pfnet-research/einconv) Github | ||
repository which represented the starting point for our work. | ||
|
||
## Installation | ||
Install from PyPI via `pip` | ||
|
||
```sh | ||
pip install einconv | ||
``` | ||
|
||
## Examples | ||
|
||
- [Basic | ||
example](https://einconv.readthedocs.io/en/latest/tutorials/basic_conv2d/) | ||
|
||
- For more tutorials, check out the | ||
[docs](https://einconv.readthedocs.io/en/latest/) | ||
|
||
## Features & Usage | ||
|
||
In general, `einconv`'s goals are: | ||
|
||
- Full hyper-parameter support (stride, padding, dilation, groups, etc.) | ||
- Support for any dimension (e.g. 5d-convolution) | ||
- Optimizations via symbolic simplification | ||
|
||
### Modules | ||
|
||
`einconv` provides `einsum`-based implementations of the following PyTorch modules: | ||
|
||
| `torch` module | `einconv` module | | ||
|-------------------|--------------------| | ||
| `nn.Conv{1,2,3}d` | `modules.ConvNd` | | ||
| `nn.Unfold` | `modules.UnfoldNd` | | ||
|
||
They work in exactly the same way as their PyTorch equivalents. | ||
|
||
### Functionals | ||
|
||
`einconv` provides `einsum`-based implementations of the following PyTorch functionals: | ||
|
||
| `torch` functional | `einconv` functional | | ||
|------------------------------|------------------------| | ||
| `nn.functional.conv{1,2,3}d` | `functionals.convNd` | | ||
| `nn.functional.unfold` | `functionals.unfoldNd` | | ||
|
||
They work in exactly the same way as their PyTorch equivalents. | ||
|
||
### Einsum Expressions | ||
`einconv` can generate `einsum` expressions (equation, operands, and output | ||
shape) for the following operations: | ||
|
||
- Forward pass of `N`-dimensional convolution | ||
- Backward pass (input and weight VJPs) of `N`-dimensional convolution | ||
- Input unfolding (`im2col/unfold`) for inputs of `N`-dimensional convolution | ||
- Input-based Kronecker factors of Fisher approximations for convolutions (KFC | ||
and KFAC-reduce) | ||
|
||
These can then be evaluated with `einsum`. For instance, the `einsum` expression | ||
for the forward pass of an `N`-dimensional convolution is | ||
|
||
```py | ||
from torch import einsum | ||
from einconv.expressions import convNd_forward | ||
|
||
equation, operands, shape = convNd_forward.einsum_expression(...) | ||
result = einsum(equation, *operands).reshape(shape) | ||
``` | ||
|
||
All expressions follow this pattern. | ||
|
||
### Symbolic Simplification | ||
|
||
Some operations (e.g. dense convolutions) can be optimized via symbolic | ||
simplifications. This is turned on by default as it generally improves | ||
performance. You can also generate a non-optimized expression and simplify it: | ||
|
||
```py | ||
from einconv import simplify | ||
|
||
equation, operands, shape = convNd_forward.einsum_expression(..., simplify=False) | ||
equation, operands = simplify(equation, operands) | ||
result = einsum(equation, *operands).reshape(shape) | ||
``` | ||
|
||
Sometimes it might be better to inspect the non-simplified expression to see how | ||
indices relate to operands. | ||
|
||
## Citation | ||
|
||
If you find the `einconv` package useful for your research, consider mentioning | ||
the accompanying article | ||
|
||
```bib | ||
@article{dangel2023convolutions, | ||
title = {Convolutions Through the Lens of Tensor Networks}, | ||
author = {Dangel, Felix}, | ||
year = 2023, | ||
} | ||
``` | ||
## Limitations | ||
|
||
- Currently, none of the underlying operations (computation of index pattern | ||
tensors, generation of einsum equations and shapes, simplification) is cached. | ||
This consumes additional time, although it should usually take much less time | ||
than evaluating an expression via `einsum`. | ||
|
||
- At the moment, the code to perform expression simplifications is coupled with | ||
PyTorch. I am planning to address this in the future by switching the | ||
implementation to a symbolic approach which will also allow efficient caching. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Changelog | ||
All notable changes to this project will be documented in this file. | ||
|
||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), | ||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). | ||
|
||
## [Unreleased] | ||
|
||
## [0.1.0] - 2023-07-03 | ||
|
||
Initial release | ||
|
||
[Unreleased]: https://github.com/f-dangel/einconv/compare/0.1.0...HEAD | ||
[0.1.0]: https://github.com/f-dangel/einconv/releases/tag/0.1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
:::einconv.expressions.convNd_forward | ||
:::einconv.expressions.convNd_input_vjp | ||
:::einconv.expressions.convNd_weight_vjp | ||
:::einconv.expressions.convNd_unfold | ||
:::einconv.expressions.convNd_kfc | ||
:::einconv.expressions.convNd_kfac_reduce |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
::: einconv.functionals.convNd | ||
::: einconv.functionals.unfoldNd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
::: einconv.index_pattern |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
::: einconv.modules.ConvNd | ||
options: | ||
members: | ||
- forward | ||
- from_nn_Conv | ||
|
||
:::einconv.modules.UnfoldNd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
:::einconv.simplify |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Under preparation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../README.md |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
```py | ||
--8<-- "docs/tutorials/basic_conv2d.py" | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
"""Compare forward pass of a 2d convolution layer.""" | ||
|
||
from torch import allclose, manual_seed, rand | ||
from torch.nn import Conv2d | ||
|
||
from einconv.modules import ConvNd | ||
|
||
manual_seed(0) # make deterministic | ||
|
||
x = rand(10, 4, 28, 28) # random input | ||
conv_params = { | ||
"in_channels": 4, | ||
"out_channels": 8, | ||
"kernel_size": 4, # can also use tuple | ||
"padding": 1, # can also use tuple, or string | ||
"stride": 3, # can also use tuple | ||
"dilation": 2, # can also use tuple | ||
"groups": 2, | ||
"bias": True, | ||
} | ||
N = 2 # convolution dimension | ||
|
||
torch_layer = Conv2d(**conv_params) | ||
ein_layer = ConvNd(N, **conv_params) | ||
ein_layer.weight.data = torch_layer.weight.data | ||
ein_layer.bias.data = torch_layer.bias.data | ||
|
||
assert allclose(torch_layer(x), ein_layer(x), rtol=1e-4, atol=1e-6) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,9 @@ | ||
"""einconv library.""" | ||
"""Einsum implementations of convolutions and related operations.""" | ||
|
||
from einconv.conv_index_pattern import index_pattern | ||
from einconv.simplifications import simplify | ||
|
||
def hello(name): | ||
"""Say hello to a name. | ||
Args: | ||
name (str): Name to say hello to. | ||
""" | ||
print(f"Hello, {name}") | ||
__all__ = [ | ||
"index_pattern", | ||
"simplify", | ||
] |
Oops, something went wrong.