From 8cc761e6eb181532a497aefd2ba73e0ce80b9c8d Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Thu, 6 Jun 2024 11:49:51 -0400 Subject: [PATCH] [ADD] einsum expression for transpose convolution's KFAC-reduce (#45) * [ADD] Utility function to compute input sizes of a convolution * [ADD] Einsum expression and functional for transpose input unfolding * [DEL] Print statements * [DOC] Polish docstrings * [DOC] Minor polish * [FIX] Too long lines * [ADD] einsum expression for transpose convolution's KFAC-reduce * [FIX] Long line * [REF] Minor polish * [FIX] flake8 --- docs/api/expressions.md | 1 + einconv/expressions/convNd_kfac_reduce.py | 32 +++- einconv/expressions/convNd_kfc.py | 24 ++- .../conv_transposeNd_kfac_reduce.py | 173 ++++++++++++++++++ .../expressions/conv_transposeNd_unfold.py | 21 +++ mkdocs.yml | 4 +- .../conv_transposeNd_kfac_reduce_cases.py | 117 ++++++++++++ test/expressions/test_convNd_kfac_reduce.py | 2 +- .../test_conv_transposeNd_kfac_reduce.py | 63 +++++++ test/functionals/test_unfold_transpose.py | 31 ++++ 10 files changed, 459 insertions(+), 9 deletions(-) create mode 100644 einconv/expressions/conv_transposeNd_kfac_reduce.py create mode 100644 test/expressions/conv_transposeNd_kfac_reduce_cases.py create mode 100644 test/expressions/test_conv_transposeNd_kfac_reduce.py diff --git a/docs/api/expressions.md b/docs/api/expressions.md index 7a753d3..f177d89 100644 --- a/docs/api/expressions.md +++ b/docs/api/expressions.md @@ -5,3 +5,4 @@ :::einconv.expressions.convNd_kfc :::einconv.expressions.convNd_kfac_reduce :::einconv.expressions.conv_transposeNd_unfold +:::einconv.expressions.conv_transposeNd_kfac_reduce diff --git a/einconv/expressions/convNd_kfac_reduce.py b/einconv/expressions/convNd_kfac_reduce.py index 6ad10db..dd6d766 100644 --- a/einconv/expressions/convNd_kfac_reduce.py +++ b/einconv/expressions/convNd_kfac_reduce.py @@ -2,8 +2,10 @@ KFAC-reduce was introduced by: -- Eschenhagen, R. (2022). Kronecker-factored approximate curvature for linear - weight-sharing layers, Master thesis. +- [Eschenhagen, R., Immer, A., Turner, R. E., Schneider, F., & Hennig, P. + (2023). Kronecker-factored approximate curvature for modern neural network + architectures. In Advances in Neural Information Processing Systems (NeurIPS)]\ +(https://arxiv.org/abs/2311.00636). """ from typing import List, Tuple, Union @@ -27,6 +29,26 @@ def einsum_expression( ) -> Tuple[str, List[Tensor], Tuple[int, ...]]: """Generate einsum expression of input-based KFAC-reduce factor for convolution. + Let $\\mathbf{X}\\in\\mathbb{R}^{C_\\text{in}\\times I_1\\times I_2\\times\\dots}$ + denote the input of a convolution. The unfolded input $[[\\mathbf{X}]]$ + has dimension $(C_\\text{in} \\cdot K_1 \\cdot K_2 \\cdots) \\times (O_1 \\cdot O_2 + \\cdots)$ where $K_i$ and $O_i$ are the kernel and output sizes of the convolution. + The input-based KFAC-reduce factor is the batch-averaged outer product + of the column-averaged unfolded input, + + $$ + \\hat{\\mathbf{\\Omega}} = + \\frac{1}{B \\cdot (O_1 \\cdot O_2 \\cdots)^2} \\sum_{b=1}^B + ( [[\\mathbf{X}_b]]^\\top \\mathbf{1} ) + ( [[\\mathbf{X}_b]]^\\top \\mathbf{1} )^\\top + \\in \\mathbb{R}^{(C_\\text{in} \\cdot K_1 \\cdot K_2 \\cdots) \\times + (C_\\text{in} \\cdot K_1 \\cdot K_2 \\cdots)} + \\,, + $$ + + where $B$ is the batch size and $\\mathbf{X}_b$ is the convolution's input from the + $b$th data point. + Args: x: Convolution input. Has shape ``[batch_size, in_channels, *input_sizes]`` where ``len(input_sizes) == N``. @@ -46,8 +68,8 @@ def einsum_expression( Einsum equation Einsum operands in order un-grouped input, patterns, un-grouped input, \ patterns, normalization scaling - Output shape: ``[groups, in_channels //groups * tot_kernel_sizes,\ - in_channels //groups * tot_kernel_sizes]`` + Output shape: ``[groups, in_channels // groups * tot_kernel_sizes,\ + in_channels // groups * tot_kernel_sizes]`` """ N = x.dim() - 2 @@ -83,7 +105,7 @@ def einsum_expression( x_ungrouped = rearrange(x, "n (g c_in) ... -> n g c_in ...", g=groups) output_tot_size = Tensor([p.shape[1] for p in patterns]).int().prod() batch_size = x.shape[0] - scale = Tensor([1.0 / (batch_size * output_tot_size**2)]).to(x.device).to(x.dtype) + scale = Tensor([1.0 / (batch_size * output_tot_size**2)]).to(x.device, x.dtype) operands = [x_ungrouped, *patterns, *patterns, x_ungrouped, scale] # construct output shape diff --git a/einconv/expressions/convNd_kfc.py b/einconv/expressions/convNd_kfc.py index bb422e1..20c0071 100644 --- a/einconv/expressions/convNd_kfc.py +++ b/einconv/expressions/convNd_kfc.py @@ -2,8 +2,9 @@ KFC was introduced by: -- Grosse, R., & Martens, J. (2016). A Kronecker-factored approximate Fisher matrix - for convolution layers. International Conference on Machine Learning (ICML). +- [Grosse, R., & Martens, J. (2016). A Kronecker-factored approximate Fisher matrix + for convolution layers. International Conference on Machine Learning (ICML).]\ +(https://arxiv.org/abs/1602.01407) """ from typing import List, Tuple, Union @@ -27,6 +28,25 @@ def einsum_expression( ) -> Tuple[str, List[Tensor], Tuple[int, ...]]: """Generate einsum expression of input-based KFC factor for convolution. + Let $\\mathbf{X}\\in\\mathbb{R}^{C_\\text{in}\\times I_1\\times I_2\\times\\dots}$ + denote the input of a convolution. The unfolded input $[[\\mathbf{X}]]$ + has dimension $(C_\\text{in} \\cdot K_1 \\cdot K_2 \\cdots) \\times (O_1 \\cdot O_2 + \\cdots)$ where $K_i$ and $O_i$ are the kernel and output sizes of the convolution. + The input-based KFC factor is the batch-averaged outer product of the unfolded + input, + + $$ + \\mathbf{\\Omega} = + \\frac{1}{B} \\sum_{b=1}^B + [[\\mathbf{X}_b]] [[\\mathbf{X}_b]]^\\top + \\in \\mathbb{R}^{(C_\\text{in} \\cdot K_1 \\cdot K_2 \\cdots) \\times + (C_\\text{in} \\cdot K_1 \\cdot K_2 \\cdots)} + \\,, + $$ + + where $B$ is the batch size and $\\mathbf{X}_b$ is the convolution's input from the + $b$th data point. + Args: x: Convolution input. Has shape ``[batch_size, in_channels, *input_sizes]`` where ``len(input_sizes) == N``. diff --git a/einconv/expressions/conv_transposeNd_kfac_reduce.py b/einconv/expressions/conv_transposeNd_kfac_reduce.py new file mode 100644 index 0000000..ba397d0 --- /dev/null +++ b/einconv/expressions/conv_transposeNd_kfac_reduce.py @@ -0,0 +1,173 @@ +"""Input-based factor of the KFAC-reduce approximation for transpose convolutions. + +KFAC-reduce was introduced by: + +- [Eschenhagen, R., Immer, A., Turner, R. E., Schneider, F., & Hennig, P. + (2023). Kronecker-factored approximate curvature for modern neural network + architectures. In Advances in Neural Information Processing Systems (NeurIPS)]\ +(https://arxiv.org/abs/2311.00636). +""" + +from typing import List, Optional, Tuple, Union + +from einops import rearrange +from torch import Tensor + +import einconv +from einconv.expressions.utils import create_conv_index_patterns, translate_to_torch +from einconv.utils import _tuple, get_conv_input_size + + +def einsum_expression( + x: Tensor, + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[int, Tuple[int, ...]] = 0, + output_padding: Union[int, Tuple[int, ...]] = 0, + output_size: Optional[Union[int, Tuple[int, ...]]] = None, + dilation: Union[int, Tuple[int, ...]] = 1, + groups: int = 1, + simplify: bool = True, +) -> Tuple[str, List[Tensor], Tuple[int, ...]]: + """Generate einsum expr. of input-based KFAC-reduce factor for transp. convolution. + + We describe the `N`d transpose convolution using its associated `N`d convolution + which maps an input of shape `[batch_size, in_channels, *input_sizes]` to an output + of shape `[batch_size, out_channels, *output_sizes]`. The transpose convolution's + input has shape `[batch_size, out_channels, *output_sizes]` and the output has shape + `[batch_size, in_channels, *input_sizes]`. + + Let $\\mathbf{X}\\in\\mathbb{R}^{C_\\text{out}\\times O_1\\times O_2\\times\\dots}$ + denote the input of a transpose convolution. The unfolded input $[[\\mathbf{X} + ]]_\\top$ has dimension $(C_\\text{out} \\cdot K_1 \\cdot K_2 \\cdots) \\times + (I_1 \\cdot I_2 \\cdots)$ where $K_i$ and $I_i$ are the kernel and input sizes of + the associated convolution. The input-based KFAC-reduce factor is the batch-averaged + outer product of the column-averaged unfolded input, + + $$ + \\hat{\\mathbf{\\Omega}} = + \\frac{1}{B \\cdot (I_1 \\cdot I_2 \\cdots)^2} \\sum_{b=1}^B + ( [[\\mathbf{X}_b]]^\\top_\\top \\mathbf{1} ) + ( [[\\mathbf{X}_b]]^\\top_\\top \\mathbf{1} )^\\top + \\in \\mathbb{R}^{(C_\\text{out} \\cdot K_1 \\cdot K_2 \\cdots) \\times + (C_\\text{out} \\cdot K_1 \\cdot K_2 \\cdots)} + \\,, + $$ + + where $B$ is the batch size and $\\mathbf{X}_b$ is the transpose convolution's + input from the $b$th data point. + + Args: + x: Input tensor of shape `[batch_size, out_channels, *output_sizes]`. + kernel_size: Size of the convolutional kernel. Can be a single integer (shared + along all spatial dimensions), or an `N`-tuple of integers. + stride: Stride of the associated convolution. Can be a single integer (shared + along all spatial dimensions), or an `N`-tuple of integers. Default: `1`. + padding: Padding of the associated convolution. Can be a single integer (shared + along all spatial dimensions), or an `N`-tuple of integers. Default: `0`. + output_padding: Number of unused pixels at the end of the spatial domain. + This is used to resolve the ambiguity that a convolution can map different + input sizes to the same output size if its stride is different from 1. + Instead of specifying this argument, you can directly specify the output + size of the transpose convolution (i.e. the input size of the associated + convolution via the `output_size` argument). Can be a single integer + (shared along all spatial dimensions), or an `N`-tuple. Default: `0`. + output_size: Size of the output of the transpose convolution (i.e. the input + size of the associated convolution). Specifying this argument will override + the `output_padding` argument. Can be a single integer (shared along all + spatial dimensions), or an `N`-tuple of integers. Default: `None`. + dilation: Dilation of the convolution. Can be a single integer (shared along + all spatial dimensions), or an `N`-tuple of integers. Default: `1`. + groups: In how many groups to split the channels. Default: `1`. + simplify: Whether to simplify the einsum expression. Default: `True`. + + Returns: + Einsum equation + Einsum operands in order un-grouped input, patterns, un-grouped input, \ + patterns, normalization scaling + Output shape: `[groups, out_channels // groups * tot_kernel_sizes,\ + out_channels // groups * tot_kernel_sizes]` + """ + N = x.dim() - 2 + + # construct einsum equation + x1_str = "n g c_out " + " ".join([f"o{i}" for i in range(N)]) + x2_str = "n g c_out_ " + " ".join([f"o{i}_" for i in range(N)]) + pattern1_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)] + pattern2_strs: List[str] = [f"k{i}_ o{i}_ i{i}_" for i in range(N)] + scale_str = "s" + lhs = ",".join([x1_str, *pattern1_strs, *pattern2_strs, x2_str, scale_str]) + rhs = ( + "g c_out " + + " ".join([f"k{i}" for i in range(N)]) + + " c_out_ " + + " ".join([f"k{i}_" for i in range(N)]) + ) + equation = "->".join([lhs, rhs]) + equation = translate_to_torch(equation) + + conv_output_size = x.shape[2:] + t_kernel_size = _tuple(kernel_size, N) + t_stride = _tuple(stride, N) + t_padding = _tuple(padding, N) + t_dilation = _tuple(dilation, N) + + # infer output_padding from convolution's input size + if output_size is not None: + t_output_size = _tuple(output_size, N) + t_output_padding = tuple( + output_size - get_conv_input_size(out, K, S, P, 0, D) + for output_size, out, K, S, P, D in zip( + t_output_size, + conv_output_size, + t_kernel_size, + t_stride, + t_padding, + t_dilation, + ) + ) + else: + t_output_padding = _tuple(output_padding, N) + + conv_input_size = tuple( + get_conv_input_size(out, K, S, P, output_padding, D) + for out, K, S, P, output_padding, D in zip( + conv_output_size, + t_kernel_size, + t_stride, + t_padding, + t_output_padding, + t_dilation, + ) + ) + + # construct einsum operands + patterns = create_conv_index_patterns( + N, + input_size=conv_input_size, + kernel_size=t_kernel_size, + stride=t_stride, + padding=t_padding, + dilation=dilation, + device=x.device, + dtype=x.dtype, + ) + x_ungrouped = rearrange(x, "n (g c_in) ... -> n g c_in ...", g=groups) + conv_input_tot_size = Tensor(conv_input_size).int().prod() + batch_size, out_channels = x.shape[:2] + scale = Tensor([1.0 / (batch_size * conv_input_tot_size**2)]).to(x.device, x.dtype) + operands = [x_ungrouped, *patterns, *patterns, x_ungrouped, scale] + + # construct output shape + t_kernel_size = _tuple(kernel_size, N) + kernel_tot_size = int(Tensor(t_kernel_size).int().prod()) + shape = ( + groups, + out_channels // groups * kernel_tot_size, + out_channels // groups * kernel_tot_size, + ) + + if simplify: + equation, operands = einconv.simplify(equation, operands) + + return equation, operands, shape diff --git a/einconv/expressions/conv_transposeNd_unfold.py b/einconv/expressions/conv_transposeNd_unfold.py index 09d7172..9c10779 100644 --- a/einconv/expressions/conv_transposeNd_unfold.py +++ b/einconv/expressions/conv_transposeNd_unfold.py @@ -30,6 +30,27 @@ def einsum_expression( has shape `[batch_size, out_channels, *output_sizes]` and the output has shape `[batch_size, in_channels, *input_sizes]`. + Let $\\mathbf{X}\\in\\mathbb{R}^{C_\\text{out}\\times O_1\\times O_2\\times\\dots}$ + denote the input of a transpose convolution, $\\mathbf{W} \\in \\mathbb{R}^{ + C_\\text{out} \\times C_\\text{in} \\times K_1\\times K_2\\times\\dots}$ its kernel + and $\\mathbf{Y}\\in\\mathbb{R}^{C_\\text{in}\\times I_1\\times I_2\\times\\dots}$ + its output. The unfolded input $[[\\mathbf{X}]]_\\top$ has dimension + $(C_\\text{out} \\cdot K_1 \\cdot K_2 \\cdots) \\times (I_1 \\cdot I_2 \\cdots)$ and + can be used to express transpose convolution as matrix multiplication, + + $$ + \\mathrm{mat}(\\mathbf{Y}) + = + \\mathrm{mat}(\\mathbf{W}) + [[\\mathbf{X})]]_\\top + \\,, + $$ + + where $\\mathrm{mat}(\\mathbf{Y}) \\in \\mathbb{R}^{C_\\text{in}\\times (I_1\\cdot + I_2 \\cdots)}$ and $\\mathrm{mat}(\\mathbf{W}) \\in \\mathbb{R}^{C_\\text{in}\\times + (C_\\text{out} \\cdot K_1\\cdot K_2 \\cdots)}$ are matrix views of $\\mathbf{Y}, + \\mathbf{W}$ (note that $\\mathbf{W}$ must also be transposed before matricizing). + Args: x: Input to the `N`d transpose convolution. Has shape `[batch_size, in_channels, *input_sizes]` where `len(input_sizes) == N`. diff --git a/mkdocs.yml b/mkdocs.yml index 1c29bef..a950584 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,8 +1,10 @@ site_name: Einconv -site_url: https://example.com # TODO Fill in the link from the hosting platform +site_url: https://einconv.readthedocs.io repo_url: https://github.com/f-dangel/einconv/ repo_name: f-dangel/einconv site_author: Felix Dangel +watch: + - einconv nav: - Getting Started: index.md - Tutorials: diff --git a/test/expressions/conv_transposeNd_kfac_reduce_cases.py b/test/expressions/conv_transposeNd_kfac_reduce_cases.py new file mode 100644 index 0000000..6ec5bb2 --- /dev/null +++ b/test/expressions/conv_transposeNd_kfac_reduce_cases.py @@ -0,0 +1,117 @@ +"""Test cases for einsum expression of input-based KFAC-reduce for conv. transpose.""" + +from test.utils import make_id + +from torch import rand + +TRANSPOSE_KFAC_REDUCE_1D_CASES = [ + # no kwargs + { + "seed": 0, + # (batch_size, in_channels, num_pixels) + "input_fn": lambda: rand(2, 3, 8), + "kernel_size": 5, + "kwargs": {}, + }, + # non-default stride + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8), + "kernel_size": 5, + "kwargs": {"stride": 2}, + }, + # non-default stride, groups + { + "seed": 0, + "input_fn": lambda: rand(2, 4, 8), + "kernel_size": 5, + "kwargs": {"stride": 3, "groups": 2}, + }, + # non-default padding + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8), + "kernel_size": 5, + "kwargs": {"padding": 2}, + }, + # non-default output padding + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8), + "kernel_size": 5, + "kwargs": {"output_padding": 1, "stride": 2}, + }, + # non-default dilation + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8), + "kernel_size": 5, + "kwargs": {"dilation": 2}, + }, + # non-default arguments supplied as tuple + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8), + "kernel_size": (3,), + "kwargs": { + "padding": (1,), + "stride": (2,), + "dilation": (1,), + "output_padding": (1,), + }, + }, +] +TRANSPOSE_KFAC_REDUCE_1D_IDS = [ + make_id(case) for case in TRANSPOSE_KFAC_REDUCE_1D_CASES +] + +TRANSPOSE_KFAC_REDUCE_2D_CASES = [ + # no kwargs + { + "seed": 0, + # (batch_size, in_channels, num_pixels_h, num_pixels_w) + "input_fn": lambda: rand(2, 3, 8, 7), + "kernel_size": 5, + "kwargs": {}, + }, + # non-default kwargs supplied as tuple + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8, 7), + "kernel_size": (5, 3), + "kwargs": { + "padding": (1, 2), + "stride": (2, 3), + "dilation": (2, 1), + "output_padding": (1, 2), + }, + }, +] +TRANSPOSE_KFAC_REDUCE_2D_IDS = [ + make_id(case) for case in TRANSPOSE_KFAC_REDUCE_2D_CASES +] + +TRANSPOSE_KFAC_REDUCE_3D_CASES = [ + # no kwargs + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8, 7, 6), + "kernel_size": 5, + "kwargs": {}, + }, + # non-default kwargs supplied as tuple + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8, 7, 6), + "kernel_size": (5, 3, 2), + "kwargs": { + "padding": (0, 1, 2), + "stride": (3, 2, 1), + "dilation": (1, 2, 3), + "output_padding": (2, 0, 0), + }, + }, +] +TRANSPOSE_KFAC_REDUCE_3D_IDS = [ + make_id(case) for case in TRANSPOSE_KFAC_REDUCE_3D_CASES +] diff --git a/test/expressions/test_convNd_kfac_reduce.py b/test/expressions/test_convNd_kfac_reduce.py index e6c7677..083e142 100644 --- a/test/expressions/test_convNd_kfac_reduce.py +++ b/test/expressions/test_convNd_kfac_reduce.py @@ -50,7 +50,7 @@ def test_einsum_expression(case: Dict, device: torch.device, simplify: bool): avg_unfolded_x = unfolded_x.mean(dim=-1) groups = kwargs.get("groups", 1) avg_unfolded_x = rearrange(avg_unfolded_x, "n (g c_in_k) -> n g c_in_k", g=groups) - kfac_unfold = einsum("ngi,ngj->gij", avg_unfolded_x, avg_unfolded_x) / (batch_size) + kfac_unfold = einsum("ngi,ngj->gij", avg_unfolded_x, avg_unfolded_x) / batch_size equation, operands, shape = convNd_kfac_reduce.einsum_expression( x, kernel_size, **kwargs, simplify=simplify diff --git a/test/expressions/test_conv_transposeNd_kfac_reduce.py b/test/expressions/test_conv_transposeNd_kfac_reduce.py new file mode 100644 index 0000000..79fb637 --- /dev/null +++ b/test/expressions/test_conv_transposeNd_kfac_reduce.py @@ -0,0 +1,63 @@ +"""Tests ``einconv.expressions.conv_transposeNd_kfac_reduce``.""" + +from test.expressions.conv_transposeNd_kfac_reduce_cases import ( + TRANSPOSE_KFAC_REDUCE_1D_CASES, + TRANSPOSE_KFAC_REDUCE_1D_IDS, + TRANSPOSE_KFAC_REDUCE_2D_CASES, + TRANSPOSE_KFAC_REDUCE_2D_IDS, + TRANSPOSE_KFAC_REDUCE_3D_CASES, + TRANSPOSE_KFAC_REDUCE_3D_IDS, +) +from test.utils import DEVICE_IDS, DEVICES, SIMPLIFIES, SIMPLIFY_IDS, report_nonclose +from typing import Dict + +import unfoldNd +from einops import rearrange +from pytest import mark +from torch import device, einsum, manual_seed + +from einconv.expressions import conv_transposeNd_kfac_reduce + + +@mark.parametrize("simplify", SIMPLIFIES, ids=SIMPLIFY_IDS) +@mark.parametrize( + "case", + TRANSPOSE_KFAC_REDUCE_1D_CASES + + TRANSPOSE_KFAC_REDUCE_2D_CASES + + TRANSPOSE_KFAC_REDUCE_3D_CASES, + ids=TRANSPOSE_KFAC_REDUCE_1D_IDS + + TRANSPOSE_KFAC_REDUCE_2D_IDS + + TRANSPOSE_KFAC_REDUCE_3D_IDS, +) +@mark.parametrize("dev", DEVICES, ids=DEVICE_IDS) +def test_einsum_expression(case: Dict, dev: device, simplify: bool): + """Compare einsum expression of KFAC reduce with implementation via unfolding. + + Args: + case: Dictionary describing the test case. + dev: Device to execute the test on. + simplify: Whether to simplify the einsum expression. + """ + seed = case["seed"] + input_fn = case["input_fn"] + kernel_size = case["kernel_size"] + kwargs = case["kwargs"] + + manual_seed(seed) + x = input_fn().to(dev) + batch_size = x.shape[0] + + # ground truth + unfold_kwargs = {key: value for key, value in kwargs.items() if key != "groups"} + unfolded_x = unfoldNd.unfold_transposeNd(x, kernel_size, **unfold_kwargs) + avg_unfolded_x = unfolded_x.mean(dim=-1) + groups = kwargs.get("groups", 1) + avg_unfolded_x = rearrange(avg_unfolded_x, "n (g c_in_k) -> n g c_in_k", g=groups) + kfac_unfold = einsum("ngi,ngj->gij", avg_unfolded_x, avg_unfolded_x) / batch_size + + equation, operands, shape = conv_transposeNd_kfac_reduce.einsum_expression( + x, kernel_size, **kwargs, simplify=simplify + ) + kfac_einconv = einsum(equation, *operands).reshape(shape) + + report_nonclose(kfac_unfold, kfac_einconv) diff --git a/test/functionals/test_unfold_transpose.py b/test/functionals/test_unfold_transpose.py index e2778c2..9c15325 100644 --- a/test/functionals/test_unfold_transpose.py +++ b/test/functionals/test_unfold_transpose.py @@ -11,6 +11,7 @@ from test.utils import DEVICE_IDS, DEVICES, SIMPLIFIES, SIMPLIFY_IDS, report_nonclose from typing import Dict +import unfoldNd from einops import einsum, rearrange from pytest import mark from torch import ( @@ -34,6 +35,36 @@ ) @mark.parametrize("dev", DEVICES, ids=DEVICE_IDS) def test_unfoldNd_transpose(case: Dict, dev: device, simplify: bool): + """Compare input unfolding for transpose convolution with `unfoldNd` package. + + Args: + case: Dictionary describing the test case. + dev: Device to execute the test on. + simplify: Whether to use a simplified einsum expression. + """ + seed = case["seed"] + input_fn = case["input_fn"] + kernel_size = case["kernel_size"] + kwargs = case["kwargs"] + + manual_seed(seed) + inputs = input_fn().to(dev) + + result = unfoldNd.unfold_transposeNd(inputs, kernel_size, **kwargs) + einconv_result = unfoldNd_transpose( + inputs, kernel_size, **kwargs, simplify=simplify + ) + report_nonclose(result, einconv_result) + + +@mark.parametrize("simplify", SIMPLIFIES, ids=SIMPLIFY_IDS) +@mark.parametrize( + "case", + TRANSPOSE_UNFOLD_1D_CASES + TRANSPOSE_UNFOLD_2D_CASES + TRANSPOSE_UNFOLD_3D_CASES, + ids=TRANSPOSE_UNFOLD_1D_IDS + TRANSPOSE_UNFOLD_2D_IDS + TRANSPOSE_UNFOLD_3D_IDS, +) +@mark.parametrize("dev", DEVICES, ids=DEVICE_IDS) +def test_unfoldNd_transpose_via_conv_transpose(case: Dict, dev: device, simplify: bool): """Compare transpose convolution via input unfolding with built-in one. Args: