Skip to content

Commit

Permalink
[REF] Extract instantiation of index pattern tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Jul 3, 2023
1 parent 9d3055d commit 273e999
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 125 deletions.
4 changes: 1 addition & 3 deletions einconv/conv_index_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from torch import Tensor, arange, eye, logical_and, nonzero, ones_like, zeros
from torch.nn.functional import conv1d

from einconv.utils import get_conv_output_size, get_conv_paddings

cpu = torch.device("cpu")
from einconv.utils import cpu, get_conv_output_size, get_conv_paddings


def index_pattern(
Expand Down
33 changes: 12 additions & 21 deletions einconv/expressions/convNd_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.nn import Parameter

from einconv import index_pattern
from einconv.expressions.utils import create_conv_index_patterns
from einconv.utils import _tuple, get_letters


Expand Down Expand Up @@ -78,29 +79,19 @@ def _operands_and_shape(
un-grouped weight.
Output shape.
"""
input_size = tuple(x.shape[2:])
kernel_size = tuple(weight.shape[2:])

# convert into tuple format
N = x.dim() - 2
t_stride: Tuple[int, ...] = _tuple(stride, N)
t_padding: Union[Tuple[int, ...], str] = (
padding if isinstance(padding, str) else _tuple(padding, N)
input_size = x.shape[2:]
kernel_size = weight.shape[2:]
patterns = create_conv_index_patterns(
N,
input_size,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
device=weight.device,
dtype=weight.dtype,
)
t_dilation: Tuple[int, ...] = _tuple(dilation, N)

patterns = [
index_pattern(
input_size[n],
kernel_size[n],
stride=t_stride[n],
padding=padding if isinstance(padding, str) else t_padding[n],
dilation=t_dilation[n],
device=x.device,
dtype=x.dtype,
)
for n in range(N)
]
x_ungrouped = rearrange(x, "n (g c_in) ... -> n g c_in ...", g=groups)
weight_ungrouped = rearrange(weight, "(g c_out) ... -> g c_out ...", g=groups)
operands = [x_ungrouped, *patterns, weight_ungrouped]
Expand Down
31 changes: 11 additions & 20 deletions einconv/expressions/convNd_input_vjp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch import Tensor
from torch.nn import Parameter

from einconv import index_pattern
from einconv.expressions.utils import create_conv_index_patterns
from einconv.utils import _tuple, get_letters


Expand Down Expand Up @@ -92,34 +92,25 @@ def _operands_and_shape(
un-grouped weight.
Output shape
"""
# convert into tuple format
N = weight.dim() - 2
kernel_size = weight.shape[2:]
t_input_size: Tuple[int, ...] = _tuple(input_size, N)
t_dilation: Tuple[int, ...] = _tuple(dilation, N)
t_padding: Union[Tuple[int, ...], str] = (
padding if isinstance(padding, str) else _tuple(padding, N)
patterns = create_conv_index_patterns(
N,
input_size,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
device=weight.device,
dtype=weight.dtype,
)
t_stride: Tuple[int, ...] = _tuple(stride, N)

patterns: List[Tensor] = [
index_pattern(
t_input_size[n],
kernel_size[n],
stride=t_stride[n],
padding=t_padding if isinstance(t_padding, str) else t_padding[n],
dilation=t_dilation[n],
device=weight.device,
dtype=weight.dtype,
)
for n in range(N)
]
v_ungrouped = rearrange(v, "n (g c_out) ... -> n g c_out ...", g=groups)
weight_ungrouped = rearrange(weight, "(g c_out) ... -> g c_out ...", g=groups)
operands = [v_ungrouped, *patterns, weight_ungrouped]

batch_size = v.shape[0]
group_in_channels = weight.shape[1]
t_input_size = _tuple(input_size, N)
shape = (batch_size, groups * group_in_channels, *t_input_size)

return operands, shape
Expand Down
32 changes: 12 additions & 20 deletions einconv/expressions/convNd_kfac_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch import Tensor

from einconv import index_pattern
from einconv.expressions.utils import create_conv_index_patterns
from einconv.utils import _tuple, get_letters


Expand Down Expand Up @@ -83,34 +84,25 @@ def _operands_and_shape(
un-grouped input, patterns, normalization scaling.
Output shape
"""
# convert into tuple format
N = x.dim() - 2
input_sizes = x.shape[2:]
t_kernel_size: Tuple[int, ...] = _tuple(kernel_size, N)
t_dilation: Tuple[int, ...] = _tuple(dilation, N)
t_padding: Union[Tuple[int, ...], str] = (
padding if isinstance(padding, str) else _tuple(padding, N)
input_size = x.shape[2:]
patterns = create_conv_index_patterns(
N,
input_size,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
device=x.device,
dtype=x.dtype,
)
t_stride: Tuple[int, ...] = _tuple(stride, N)

patterns: List[Tensor] = [
index_pattern(
input_sizes[n],
t_kernel_size[n],
stride=t_stride[n],
padding=t_padding if isinstance(t_padding, str) else t_padding[n],
dilation=t_dilation[n],
device=x.device,
dtype=x.dtype,
)
for n in range(N)
]
x_ungrouped = rearrange(x, "n (g c_in) ... -> n g c_in ...", g=groups)
output_tot_size = Tensor([p.shape[1] for p in patterns]).int().prod()
batch_size = x.shape[0]
scale = Tensor([1.0 / (batch_size * output_tot_size**2)]).to(x.device).to(x.dtype)
operands = [x_ungrouped, *patterns, *patterns, x_ungrouped, scale]

t_kernel_size = _tuple(kernel_size, N)
kernel_tot_size = int(Tensor(t_kernel_size).int().prod())
in_channels = x.shape[1]
shape = (
Expand Down
32 changes: 12 additions & 20 deletions einconv/expressions/convNd_kfc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch import Tensor

from einconv import index_pattern
from einconv.expressions.utils import create_conv_index_patterns
from einconv.utils import _tuple, get_letters


Expand Down Expand Up @@ -83,33 +84,24 @@ def _operands_and_shape(
un-grouped input, patterns, normalization scaling.
Output shape
"""
# convert into tuple format
N = x.dim() - 2
input_sizes = x.shape[2:]
t_kernel_size: Tuple[int, ...] = _tuple(kernel_size, N)
t_dilation: Tuple[int, ...] = _tuple(dilation, N)
t_padding: Union[Tuple[int, ...], str] = (
padding if isinstance(padding, str) else _tuple(padding, N)
input_size = x.shape[2:]
patterns = create_conv_index_patterns(
N,
input_size,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
device=x.device,
dtype=x.dtype,
)
t_stride: Tuple[int, ...] = _tuple(stride, N)

patterns: List[Tensor] = [
index_pattern(
input_sizes[n],
t_kernel_size[n],
stride=t_stride[n],
padding=t_padding if isinstance(t_padding, str) else t_padding[n],
dilation=t_dilation[n],
device=x.device,
dtype=x.dtype,
)
for n in range(N)
]
x_ungrouped = rearrange(x, "n (g c_in) ... -> n g c_in ...", g=groups)
batch_size = x.shape[0]
scale = Tensor([1.0 / batch_size]).to(x.device).to(x.dtype)
operands = [x_ungrouped, *patterns, *patterns, x_ungrouped, scale]

t_kernel_size = _tuple(kernel_size, N)
kernel_tot_sizes = int(Tensor(t_kernel_size).int().prod())
in_channels = x.shape[1]
shape = (
Expand Down
33 changes: 12 additions & 21 deletions einconv/expressions/convNd_unfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch import Tensor

from einconv import index_pattern
from einconv.expressions.utils import create_conv_index_patterns
from einconv.utils import _tuple, get_letters


Expand Down Expand Up @@ -67,31 +67,22 @@ def _operands_and_shape(
Tensor list containing the operands in order input, patterns.
Output shape.
"""
# convert into tuple format
N = x.dim() - 2
input_sizes = x.shape[2:]
t_kernel_size: Tuple[int, ...] = _tuple(kernel_size, N)
t_dilation: Tuple[int, ...] = _tuple(dilation, N)
t_padding: Union[Tuple[int, ...], str] = (
padding if isinstance(padding, str) else _tuple(padding, N)
input_size = x.shape[2:]
patterns = create_conv_index_patterns(
N,
input_size,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
device=x.device,
dtype=x.dtype,
)
t_stride: Tuple[int, ...] = _tuple(stride, N)

patterns: List[Tensor] = [
index_pattern(
input_sizes[n],
t_kernel_size[n],
stride=t_stride[n],
padding=t_padding if isinstance(t_padding, str) else t_padding[n],
dilation=t_dilation[n],
device=x.device,
dtype=x.dtype,
)
for n in range(N)
]
operands = [x, *patterns]

output_tot_size = int(Tensor([p.shape[1] for p in patterns]).int().prod())
t_kernel_size = _tuple(kernel_size, N)
kernel_tot_size = int(Tensor(t_kernel_size).int().prod())
batch_size, in_channels = x.shape[:2]
shape = (batch_size, in_channels * kernel_tot_size, output_tot_size)
Expand Down
32 changes: 12 additions & 20 deletions einconv/expressions/convNd_weight_vjp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch import Tensor

from einconv import index_pattern
from einconv.expressions.utils import create_conv_index_patterns
from einconv.utils import _tuple, get_letters


Expand Down Expand Up @@ -88,34 +89,25 @@ def _operands_and_shape(
Einsum operands in order un-grouped input, patterns, un-grouped vector
Output shape
"""
# convert into tuple format
N = x.dim() - 2
input_sizes = x.shape[2:]
t_kernel_size: Tuple[int, ...] = _tuple(kernel_size, N)
t_dilation: Tuple[int, ...] = _tuple(dilation, N)
t_padding: Union[Tuple[int, ...], str] = (
padding if isinstance(padding, str) else _tuple(padding, N)
input_size = x.shape[2:]
patterns = create_conv_index_patterns(
N,
input_size,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
device=x.device,
dtype=x.dtype,
)
t_stride: Tuple[int, ...] = _tuple(stride, N)

patterns: List[Tensor] = [
index_pattern(
input_sizes[n],
t_kernel_size[n],
stride=t_stride[n],
padding=t_padding if isinstance(t_padding, str) else t_padding[n],
dilation=t_dilation[n],
device=x.device,
dtype=x.dtype,
)
for n in range(N)
]
x_ungrouped = rearrange(x, "n (g c_in) ... -> n g c_in ...", g=groups)
v_ungrouped = rearrange(v, "n (g c_out) ... -> n g c_out ...", g=groups)
operands = [x_ungrouped, *patterns, v_ungrouped]

in_channels = x.shape[1]
out_channels = v.shape[1]
t_kernel_size = _tuple(kernel_size, N)
shape = (out_channels, in_channels // groups, *t_kernel_size)

return operands, shape
Expand Down
61 changes: 61 additions & 0 deletions einconv/expressions/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Utility functions for creating einsum expressions."""

from typing import List, Tuple, Union

import torch
from torch import Tensor

from einconv import index_pattern
from einconv.utils import _tuple, cpu


def create_conv_index_patterns(
    N: int,
    input_size: Union[int, Tuple[int, ...]],
    kernel_size: Union[int, Tuple[int, ...]],
    stride: Union[int, Tuple[int, ...]] = 1,
    padding: Union[int, str, Tuple[int, ...]] = 0,
    dilation: Union[int, Tuple[int, ...]] = 1,
    device: torch.device = cpu,
    dtype: torch.dtype = torch.bool,
) -> List[Tensor]:
    """Build one index pattern tensor per spatial dimension of a convolution.

    Every hyper-parameter is first expanded with ``_tuple(..., N)``; a string
    ``padding`` is the exception and is handed unchanged to every dimension.
    ``index_pattern`` is then called once for each of the ``N`` dimensions.

    Args:
        N: Convolution dimension.
        input_size: Spatial dimensions of the convolution. Either one integer
            (shared along all spatial dimensions) or an ``N``-tuple of integers.
        kernel_size: Kernel dimensions. Either one integer (shared along all
            spatial dimensions) or an ``N``-tuple of integers.
        stride: Stride of the convolution. Either one integer (shared along all
            spatial dimensions) or an ``N``-tuple of integers. Default: ``1``.
        padding: Padding of the convolution. Either one integer (shared along
            all spatial dimensions), an ``N``-tuple of integers, or a string.
            Default: ``0``. Allowed strings are ``'same'`` and ``'valid'``.
        dilation: Dilation of the convolution. Either one integer (shared along
            all spatial dimensions) or an ``N``-tuple of integers. Default: ``1``.
        device: Device to create the tensors on. Default: ``'cpu'``.
        dtype: Data type of the pattern tensors. Default: ``torch.bool``.

    Returns:
        List of index pattern tensors for dimensions ``1, ..., N``.
    """
    # Expand the hyper-parameters to per-dimension form.
    sizes = _tuple(input_size, N)
    kernels = _tuple(kernel_size, N)
    strides = _tuple(stride, N)
    # A string padding mode is not indexable per dimension; keep it as-is.
    padding_is_str = isinstance(padding, str)
    paddings = padding if padding_is_str else _tuple(padding, N)
    dilations = _tuple(dilation, N)

    patterns: List[Tensor] = []
    for dim in range(N):
        dim_padding = paddings if padding_is_str else paddings[dim]
        patterns.append(
            index_pattern(
                sizes[dim],
                kernels[dim],
                stride=strides[dim],
                padding=dim_padding,
                dilation=dilations[dim],
                device=device,
                dtype=dtype,
            )
        )
    return patterns
Loading

0 comments on commit 273e999

Please sign in to comment.