[REF] Inline equation/operands/shape creation
f-dangel committed Jul 3, 2023
1 parent 8c25fbd commit 76beba0
Showing 6 changed files with 98 additions and 391 deletions.
77 changes: 15 additions & 62 deletions einconv/expressions/convNd_forward.py
@@ -40,50 +40,23 @@ def einsum_expression(
Output shape: ``[batch_size, out_channels, *output_sizes]``.
"""
N = x.dim() - 2
equation = _equation(N)
operands, shape = _operands_and_shape(
x, weight, stride=stride, padding=padding, dilation=dilation, groups=groups
)
return equation, operands, shape

# construct einsum equation
x_str = "n g c_in " + " ".join([f"i{i}" for i in range(N)])
pattern_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)]
weight_str = "g c_out c_in " + " ".join([f"k{i}" for i in range(N)])
lhs = ",".join([x_str, *pattern_strs, weight_str])

def _operands_and_shape(
x: Tensor,
weight: Tensor,
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[int, str, Tuple[int, ...]] = 0,
dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
) -> Tuple[List[Union[Tensor, Parameter]], Tuple[int, ...]]:
"""Prepare operands for contraction with einsum.
rhs = "n g c_out " + " ".join([f"o{i}" for i in range(N)])

Args:
x: Convolution input. Has shape ``[batch_size, in_channels, *input_sizes]``
where ``len(input_sizes) == N``.
weight: Kernel of the convolution. Has shape ``[out_channels,
in_channels / groups, *kernel_size]`` where ``kernel_size`` is an
``N``-tuple of kernel dimensions.
stride: Stride of the convolution. Can be a single integer (shared along all
spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
padding: Padding of the convolution. Can be a single integer (shared along
all spatial dimensions), an ``N``-tuple of integers, or a string.
Default: ``0``. Allowed strings are ``'same'`` and ``'valid'``.
dilation: Dilation of the convolution. Can be a single integer (shared along
all spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
groups: In how many groups to split the input channels. Default: ``1``.
equation = "->".join([lhs, rhs])
equation = translate_to_torch(equation)

Returns:
Tensor list containing the operands in order un-grouped input, index patterns, \
un-grouped weight.
Output shape.
"""
N = x.dim() - 2
input_size = x.shape[2:]
kernel_size = weight.shape[2:]
# construct einsum operands
patterns = create_conv_index_patterns(
N,
input_size,
kernel_size,
input_size=x.shape[2:],
kernel_size=weight.shape[2:],
stride=stride,
padding=padding,
dilation=dilation,
@@ -94,30 +67,10 @@ def _operands_and_shape(
weight_ungrouped = rearrange(weight, "(g c_out) ... -> g c_out ...", g=groups)
operands = [x_ungrouped, *patterns, weight_ungrouped]

output_sizes = [p.shape[1] for p in patterns]
# construct output shape
output_size = [p.shape[1] for p in patterns]
batch_size = x.shape[0]
out_channels = weight.shape[0]
shape = (batch_size, out_channels, *output_sizes)

return operands, shape


def _equation(N: int) -> str:
"""Generate einsum equation for convolution.
Args:
N: Convolution dimension.
shape = (batch_size, out_channels, *output_size)

Returns:
Einsum equation for N-dimensional convolution. Operand order is un-grouped \
input, patterns, un-grouped weight.
"""
x_str = "n g c_in " + " ".join([f"i{i}" for i in range(N)])
pattern_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)]
weight_str = "g c_out c_in " + " ".join([f"k{i}" for i in range(N)])
lhs = ",".join([x_str, *pattern_strs, weight_str])

rhs = "n g c_out " + " ".join([f"o{i}" for i in range(N)])

equation = "->".join([lhs, rhs])
return translate_to_torch(equation)
return equation, operands, shape
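
The refactor keeps the public contract of `einsum_expression`: it still returns an `(equation, operands, shape)` triple that is meant to be contracted with `torch.einsum` and reshaped to `shape`. Below is a minimal sketch of that usage with made-up tensor sizes, checked against `torch.nn.functional.conv2d`; the reshape-to-`shape` step and the exact agreement with `conv2d` are assumptions based on the docstring, not taken from this diff.

```python
import torch
from torch.nn.functional import conv2d

from einconv.expressions.convNd_forward import einsum_expression

# hypothetical 2d example: batch 3, 4 input channels, 8x8 images, 2 groups
x = torch.rand(3, 4, 8, 8)       # [batch_size, in_channels, *input_sizes]
weight = torch.rand(6, 2, 3, 3)  # [out_channels, in_channels / groups, *kernel_size]

equation, operands, shape = einsum_expression(
    x, weight, stride=1, padding=1, dilation=1, groups=2
)
# contract, then restore [batch_size, out_channels, *output_sizes]
result = torch.einsum(equation, *operands).reshape(shape)

assert torch.allclose(result, conv2d(x, weight, stride=1, padding=1, groups=2), atol=1e-6)
```
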
80 changes: 12 additions & 68 deletions einconv/expressions/convNd_input_vjp.py
@@ -45,59 +45,23 @@ def einsum_expression(
Output shape: ``[batch_size, in_channels, *input_sizes]``
"""
N = weight.dim() - 2
equation = _equation(N)
operands, shape = _operands_and_shape(
weight,
v,
input_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
return equation, operands, shape

# construct einsum equation
v_str = "n g c_out " + " ".join([f"o{i}" for i in range(N)])
pattern_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)]
weight_str = "g c_out c_in " + " ".join([f"k{i}" for i in range(N)])
lhs = ",".join([v_str, *pattern_strs, weight_str])

def _operands_and_shape(
weight: Union[Tensor, Parameter],
v: Tensor,
input_size: Union[int, Tuple[int, ...]],
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[int, str, Tuple[int, ...]] = 0,
dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
) -> Tuple[List[Union[Tensor, Parameter]], Tuple[int, ...]]:
"""Prepare operands for contraction with einsum.
rhs = "n g c_in " + " ".join([f"i{i}" for i in range(N)])

Args:
weight: Kernel of the convolution. Has shape ``[out_channels,
in_channels / groups, *kernel_size]`` where ``kernel_size`` is an
``N``-tuple of kernel dimensions.
v: Vector multiplied by the Jacobian. Has shape
``[batch_size, out_channels, *output_sizes]``
where ``len(output_sizes) == N`` (same shape as the convolution's output).
input_size: Spatial dimensions of the convolution. Can be a single integer
(shared along all spatial dimensions), or an ``N``-tuple of integers.
stride: Stride of the convolution. Can be a single integer (shared along all
spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
padding: Padding of the convolution. Can be a single integer (shared along
all spatial dimensions), an ``N``-tuple of integers, or a string.
Default: ``0``. Allowed strings are ``'same'`` and ``'valid'``.
dilation: Dilation of the convolution. Can be a single integer (shared along
all spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
groups: In how many groups to split the input channels. Default: ``1``.
equation = "->".join([lhs, rhs])
equation = translate_to_torch(equation)

Returns:
Tensor list containing the operands in order un-grouped vector, patterns, \
un-grouped weight.
Output shape
"""
N = weight.dim() - 2
kernel_size = weight.shape[2:]
# construct einsum operands
patterns = create_conv_index_patterns(
N,
input_size,
kernel_size,
kernel_size=weight.shape[2:],
stride=stride,
padding=padding,
dilation=dilation,
@@ -108,30 +72,10 @@ def _operands_and_shape(
weight_ungrouped = rearrange(weight, "(g c_out) ... -> g c_out ...", g=groups)
operands = [v_ungrouped, *patterns, weight_ungrouped]

# construct output shape
batch_size = v.shape[0]
group_in_channels = weight.shape[1]
t_input_size = _tuple(input_size, N)
shape = (batch_size, groups * group_in_channels, *t_input_size)

return operands, shape


def _equation(N: int) -> str:
"""Return the einsum equation for an input VJP.
Args:
N: Convolution dimension.
Returns:
Einsum equation for the input VJP of N-dimensional convolution. Operand \
order is un-grouped vector, patterns, un-grouped weight.
"""
v_str = "n g c_out " + " ".join([f"o{i}" for i in range(N)])
pattern_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)]
weight_str = "g c_out c_in " + " ".join([f"k{i}" for i in range(N)])
lhs = ",".join([v_str, *pattern_strs, weight_str])

rhs = "n g c_in " + " ".join([f"i{i}" for i in range(N)])

equation = "->".join([lhs, rhs])
return translate_to_torch(equation)
return equation, operands, shape
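
Analogously to the forward expression, the input-VJP triple can be evaluated with `torch.einsum` and reshaped to `shape`; the result is expected to match the vector-Jacobian product autograd computes for the convolution input. A hedged sketch with made-up sizes follows (the reshape convention and the autograd comparison are assumptions, not part of this diff).

```python
import torch
from torch.nn.functional import conv2d

from einconv.expressions.convNd_input_vjp import einsum_expression

x = torch.rand(3, 4, 8, 8, requires_grad=True)  # [batch_size, in_channels, *input_sizes]
weight = torch.rand(6, 2, 3, 3)                 # [out_channels, in_channels / groups, *kernel_size]

output = conv2d(x, weight, stride=1, padding=1, groups=2)
v = torch.rand_like(output)                     # vector in output space

equation, operands, shape = einsum_expression(
    weight, v, input_size=(8, 8), stride=1, padding=1, dilation=1, groups=2
)
vjp = torch.einsum(equation, *operands).reshape(shape)

# compare against the VJP obtained via autograd
(vjp_autograd,) = torch.autograd.grad(output, x, grad_outputs=v)
assert torch.allclose(vjp, vjp_autograd, atol=1e-6)
```
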
87 changes: 20 additions & 67 deletions einconv/expressions/convNd_kfac_reduce.py
@@ -47,48 +47,30 @@ def einsum_expression(
in_channels // groups * tot_kernel_sizes]``
"""
N = x.dim() - 2
equation = _equation(N)
operands, shape = _operands_and_shape(
x, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups
)
return equation, operands, shape

# construct einsum equation
x1_str = "n g c_in " + " ".join([f"i{i}" for i in range(N)])
x2_str = "n g c_in_ " + " ".join([f"i{i}_" for i in range(N)])
pattern1_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)]
pattern2_strs: List[str] = [f"k{i}_ o{i}_ i{i}_" for i in range(N)]
scale_str = "s"
lhs = ",".join([x1_str, *pattern1_strs, *pattern2_strs, x2_str, scale_str])

def _operands_and_shape(
x: Tensor,
kernel_size: Union[int, Tuple[int, ...]],
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[int, str, Tuple[int, ...]] = 0,
dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
) -> Tuple[List[Tensor], Tuple[int, ...]]:
"""Generate einsum operands for KFAC-reduce factor.
rhs = (
"g c_in "
+ " ".join([f"k{i}" for i in range(N)])
+ " c_in_ "
+ " ".join([f"k{i}_" for i in range(N)])
)

Args:
x: Convolution input. Has shape ``[batch_size, in_channels, *input_sizes]``
where ``len(input_sizes) == N``.
kernel_size: Kernel dimensions. Can be a single integer (shared along all
spatial dimensions), or an ``N``-tuple of integers.
stride: Stride of the convolution. Can be a single integer (shared along all
spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
padding: Padding of the convolution. Can be a single integer (shared along
all spatial dimensions), an ``N``-tuple of integers, or a string.
Default: ``0``. Allowed strings are ``'same'`` and ``'valid'``.
dilation: Dilation of the convolution. Can be a single integer (shared along
all spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
groups: In how many groups to split the input channels. Default: ``1``.
equation = "->".join([lhs, rhs])
equation = translate_to_torch(equation)

Returns:
Tensor list containing the operands. Convention: Un-grouped input, patterns, \
un-grouped input, patterns, normalization scaling.
Output shape
"""
N = x.dim() - 2
input_size = x.shape[2:]
# construct einsum operands
patterns = create_conv_index_patterns(
N,
input_size,
kernel_size,
input_size=x.shape[2:],
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
@@ -101,6 +83,7 @@ def _operands_and_shape(
scale = Tensor([1.0 / (batch_size * output_tot_size**2)]).to(x.device).to(x.dtype)
operands = [x_ungrouped, *patterns, *patterns, x_ungrouped, scale]

# construct output shape
t_kernel_size = _tuple(kernel_size, N)
kernel_tot_size = int(Tensor(t_kernel_size).int().prod())
in_channels = x.shape[1]
@@ -110,34 +93,4 @@ def _operands_and_shape(
in_channels // groups * kernel_tot_size,
)

return operands, shape


def _equation(N: int) -> str:
"""Generate einsum equation for KFAC reduce factor.
The arguments are
``input, *index_patterns, *index_patterns, input, scale -> output``.
Args:
N: Convolution dimension.
Returns:
Einsum equation for KFAC reduce factor of N-dimensional convolution.
"""
x1_str = "n g c_in " + " ".join([f"i{i}" for i in range(N)])
x2_str = "n g c_in_ " + " ".join([f"i{i}_" for i in range(N)])
pattern1_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)]
pattern2_strs: List[str] = [f"k{i}_ o{i}_ i{i}_" for i in range(N)]
scale_str = "s"
lhs = ",".join([x1_str, *pattern1_strs, *pattern2_strs, x2_str, scale_str])

rhs = (
"g c_in "
+ " ".join([f"k{i}" for i in range(N)])
+ " c_in_ "
+ " ".join([f"k{i}_" for i in range(N)])
)

equation = "->".join([lhs, rhs])
return translate_to_torch(equation)
return equation, operands, shape
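
The KFAC-reduce factor follows the same pattern: contract the operands with `torch.einsum` and reshape to the documented per-group shape `[groups, in_channels // groups * tot_kernel_sizes, in_channels // groups * tot_kernel_sizes]`. A minimal sketch with made-up sizes is given below; only the output shape is checked, and the reshape step is an assumption based on the returned `shape`.

```python
import torch

from einconv.expressions.convNd_kfac_reduce import einsum_expression

x = torch.rand(3, 4, 8, 8)  # [batch_size, in_channels, *input_sizes]

equation, operands, shape = einsum_expression(
    x, kernel_size=3, stride=1, padding=1, dilation=1, groups=2
)
factor = torch.einsum(equation, *operands).reshape(shape)

# 2 groups, 4 // 2 input channels per group, 3 * 3 kernel elements
assert factor.shape == (2, 2 * 9, 2 * 9)
```
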