Skip to content

Commit

Permalink
[ADD] Implement index grouping symbolically
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Nov 3, 2023
1 parent 052ab46 commit bf27057
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 0 deletions.
68 changes: 68 additions & 0 deletions einconv/simplifications/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Callable, List, Optional, Tuple

import torch
from einops import rearrange
from torch import Tensor, get_default_dtype

from einconv.utils import cpu
Expand Down Expand Up @@ -102,3 +103,70 @@ def _check_transformed_tensor(self, tensor: Tensor, stage_idx: int = 0):
f"Tensor after stage {stage_idx} ({stage}) expected to have shape "
+ f"{shape} (got {tuple(tensor.shape)}) with axes named {indices}."
)

def group(self, indices: Tuple[str, ...]):
    """Combine multiple consecutive indices into a single index.

    The new index is named by joining the grouped names with spaces and
    wrapping them in parentheses, e.g. grouping ``("a", "b")`` creates the
    index ``"(a b)"``. The symbolic shape, the index names, the history, and
    the list of tensor transforms are all updated.

    Args:
        indices: Indices to group. Must be non-empty and occur consecutively,
            in this order, among the current indices. Any sequence of names is
            accepted and converted to a tuple.

    Raises:
        NotImplementedError: If the indices are not consecutive.
        ValueError: If ``indices`` is empty, if its first entry is not a
            current index, or if the new index name already exists.
    """
    # Normalize so that e.g. a list compares equal to the tuple slice below
    # instead of being rejected as "not consecutive".
    indices = tuple(indices)
    if not indices:
        raise ValueError("Grouping requires at least one index.")
    if indices[0] not in self.indices:
        raise ValueError(
            f"Index {indices[0]!r} does not exist. Axes are {self.indices}."
        )

    start = self.indices.index(indices[0])
    stop = start + len(indices)
    if indices != self.indices[start:stop]:
        raise NotImplementedError(
            f"Only consecutive indices can be grouped. Got {indices} but axes "
            + f"are {self.indices}."
        )

    group_name = "(" + " ".join(indices) + ")"
    if group_name in self.indices:
        raise ValueError(f"Index {group_name} already exists.")

    # determine dimension and indices of grouped tensor
    group_dim = 1
    for dim in self.shape[start:stop]:
        group_dim *= dim

    new_indices = self.indices[:start] + (group_name,) + self.indices[stop:]
    new_shape = self.shape[:start] + (group_dim,) + self.shape[stop:]

    # Rank the incoming tensor must have for the positional flatten to hit
    # the intended axes.
    expected_rank = len(self.indices)

    def apply_grouping(tensor: Tensor) -> Tensor:
        """Group the specified axes into a single one.

        Args:
            tensor: Tensor to group axes of.

        Returns:
            Tensor with grouped axes.

        Raises:
            RuntimeError: If the tensor does not have the expected number of
                axes.
        """
        if tensor.dim() != expected_rank:
            raise RuntimeError(
                f"Expected a tensor with {expected_rank} axes, "
                + f"got {tensor.dim()}."
            )
        # Flatten by position rather than through a name-based pattern:
        # index names that are themselves groups (contain parentheses) do
        # not form a valid pattern once they are grouped again.
        return tensor.flatten(start_dim=start, end_dim=stop - 1)

    # update internal state
    self.history.append(
        (f"group {indices} into {group_name!r}", new_shape, new_indices)
    )
    self.transforms.append(apply_grouping)
    self.indices = new_indices
    self.shape = new_shape

def __repr__(self) -> str:
    """Describe the symbolic tensor and the transformations applied to it.

    Returns:
        A string with the tensor's name, shape, and indices on the first line,
        followed by a ``Transformations:`` line and one line per history entry
        giving its position, description, resulting shape, and indices.
    """
    pieces = [
        f"SymbolicTensor({self.name!r}, {self.shape}, {self.indices})",
        "\nTransformations:",
    ]
    for step, (description, shape, indices) in enumerate(self.history):
        pieces.append("\n\t- ")
        pieces.append(f"({step}) {description}: shape {shape}, indices {indices}")
    return "".join(pieces)
36 changes: 36 additions & 0 deletions test/simplifications/test_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,39 @@ def test_instantiate_no_history(device: torch.device, dtype: torch.dtype):
assert weight_tensor.dtype == dtype
assert weight_tensor.device == device
report_nonclose(tensor.to(device, dtype), weight_tensor)


@mark.parametrize("dtype", DTYPES, ids=DTYPE_IDS)
@mark.parametrize("device", DEVICES, ids=DEVICE_IDS)
def test_group(device: torch.device, dtype: torch.dtype):
    """Check input validation and tensor transformation of index grouping.

    Args:
        device: Device to instantiate on after grouping.
        dtype: Data type to instantiate with after grouping.
    """
    manual_seed(0)

    name, shape = "weight", (4, 3, 5, 2)
    axes = ("c_out", "c_in", "k1", "k2")
    symbolic = SymbolicTensor(name, shape, axes)
    dense = rand(*shape, device=device, dtype=dtype)

    # indices that are not neighbours cannot be grouped
    with raises(NotImplementedError):
        symbolic.group(("c_out", "k1"))

    # the name of the grouped axis must not clash with an existing axis
    clashing_axes = ("(c_out c_in)", "c_out", "c_in", "a")
    clashing = SymbolicTensor(name, shape, clashing_axes)
    with raises(ValueError):
        clashing.group(("c_out", "c_in"))

    # a valid grouping updates the symbolic state ...
    symbolic.group(("c_out", "c_in"))
    assert symbolic.indices == ("(c_out c_in)", "k1", "k2")
    assert symbolic.shape == (12, 5, 2)

    # ... and flattens the leading two axes when instantiating
    grouped = symbolic.instantiate(dense)
    report_nonclose(dense.flatten(end_dim=1), grouped)

0 comments on commit bf27057

Please sign in to comment.