Skip to content

Commit

Permalink
[REF] Extract get_letters utility
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Apr 28, 2023
1 parent 8cc036e commit 8939770
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 27 deletions.
11 changes: 2 additions & 9 deletions einconv/einconvnd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch.nn import Conv1d, Conv2d, Conv3d, Module, Parameter, init

from einconv.index_pattern import conv_index_pattern
from einconv.utils import _tuple, sync_parameters
from einconv.utils import _tuple, get_letters, sync_parameters


class EinconvNd(Module):
Expand Down Expand Up @@ -310,14 +310,7 @@ def _conv_einsum_equation(N: int) -> str:
kernel_str = ""

# requires 4 + 3 * N letters
# einsum can deal with the 26 lowercase letters of the alphabet
max_letters, required_letters = 26, 4 + 3 * N
if required_letters > max_letters:
raise ValueError(
f"Cannot form einsum equation. Need {required_letters} letters."
+ f" But einsum only supports {max_letters}."
)
letters = [chr(ord("a") + i) for i in range(required_letters)]
letters = get_letters(4 + 3 * N)

# batch dimension
batch_letter = letters.pop()
Expand Down
11 changes: 2 additions & 9 deletions einconv/kfc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch import Tensor, einsum

from einconv.index_pattern import conv_index_pattern
from einconv.utils import _tuple
from einconv.utils import _tuple, get_letters


def kfc_factor(
Expand Down Expand Up @@ -75,14 +75,7 @@ def _kfc_factor_einsum_equation(N: int) -> str:
output_str = ""

# requires 3 + 5 * N letters
# einsum can deal with the 26 lowercase letters of the alphabet
max_letters, required_letters = 26, 3 + 5 * N
if required_letters > max_letters:
raise ValueError(
f"Cannot form einsum equation. Need {required_letters} letters."
+ f" But einsum only supports {max_letters}."
)
letters = [chr(ord("a") + i) for i in range(required_letters)]
letters = get_letters(3 + 5 * N)

# batch dimension
batch_letter = letters.pop()
Expand Down
11 changes: 2 additions & 9 deletions einconv/unfoldnd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.nn import Module

from einconv.index_pattern import conv_index_pattern
from einconv.utils import _tuple
from einconv.utils import _tuple, get_letters


class UnfoldNd(Module):
Expand Down Expand Up @@ -110,14 +110,7 @@ def _unfold_einsum_equation(N: int) -> str:
pattern_strs: List[str] = []

# requires 2 + 3 * N letters
# einsum can deal with the 26 lowercase letters of the alphabet
max_letters, required_letters = 26, 2 + 3 * N
if required_letters > max_letters:
raise ValueError(
f"Cannot form einsum equation. Need {required_letters} letters."
+ f" But einsum only supports {max_letters}."
)
letters = [chr(ord("a") + i) for i in range(required_letters)]
letters = get_letters(2 + 3 * N)

# batch dimension
batch_letter = letters.pop()
Expand Down
20 changes: 20 additions & 0 deletions einconv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,23 @@ def get_conv_output_size(
)
/ stride
)


def get_letters(num_letters: int) -> List[str]:
"""Return a list of ``num_letters`` unique letters for an einsum equation.
Args:
num_letters: Number of letters to return.
Returns:
List of ``num_letters`` unique letters.
Raises:
ValueError: If ``num_letters`` is larger than the maximum number of letters.
"""
max_letters = 26
if num_letters > max_letters:
raise ValueError(
f"einsum supports {max_letters} letters. Requested {num_letters}."
)
return [chr(ord("a") + i) for i in range(num_letters)]

0 comments on commit 8939770

Please sign in to comment.