Skip to content

Commit d10bd8c

Browse files
committed
fix #592
1 parent b8285c9 commit d10bd8c

2 files changed

Lines changed: 34 additions & 7 deletions

File tree

nequip/model/param_groups.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,24 @@
22
import torch
33

44

5+
def _normalize_weight_index_slices(weight_index_slices):
6+
normalized = []
7+
for entry in weight_index_slices:
8+
index_slice = getattr(entry, "slice_1D", None)
9+
shape_2d = getattr(entry, "shape_2D", None)
10+
if index_slice is None or shape_2d is None:
11+
index_slice, shape_2d = entry
12+
if isinstance(index_slice, slice):
13+
index_slice = (index_slice.start, index_slice.stop, index_slice.step)
14+
else:
15+
index_slice = tuple(index_slice)
16+
assert len(index_slice) == 3
17+
shape_2d = tuple(shape_2d)
18+
assert len(shape_2d) == 2
19+
normalized.append((index_slice, shape_2d))
20+
return normalized
21+
22+
523
def MuonParamGroups(
624
model: torch.nn.Module,
725
muon: dict,
@@ -57,13 +75,15 @@ def MuonParamGroups(
5775
module_name, _, _ = name.rpartition(".")
5876
module = modules[module_name]
5977

60-
# Attribute from e3nn giving the slices/shapes
61-
# of the corresponding linear weight
62-
index = len(muon_weights)
63-
slices = module.weight_index_slices
64-
65-
e3nn_reshaping[index] = slices
78+
# use Muon only when reshape metadata is available
79+
weight_index_slices = getattr(module, "weight_index_slices", None)
80+
if weight_index_slices is None:
81+
adam_weights.append(param)
82+
continue
6683

84+
# store plain tuples to keep optimizer state picklable
85+
index = len(muon_weights)
86+
e3nn_reshaping[index] = _normalize_weight_index_slices(weight_index_slices)
6787
muon_weights.append(param)
6888
continue
6989

nequip/train/muon.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,14 @@ def muon_update(
5757
# NequIP addon: handle reshaping for e3nn ``Linear`` layers
5858
if e3nn_reshaping is not None:
5959
update_list = []
60-
for index_slice, shape_2D in e3nn_reshaping: # square weight slices of updates
60+
for (
61+
index_slice_data,
62+
shape_2D,
63+
) in e3nn_reshaping: # square weight slices of updates
64+
if isinstance(index_slice_data, slice):
65+
index_slice = index_slice_data
66+
else:
67+
index_slice = slice(*index_slice_data)
6168
weight_slice = update[index_slice].reshape(shape_2D)
6269
grad_slice = grad[index_slice].reshape(shape_2D)
6370
update_weight_slice = zeropower_via_newtonschulz5(

0 commit comments

Comments
 (0)