|
2 | 2 | import torch |
3 | 3 |
|
4 | 4 |
|
| 5 | +def _normalize_weight_index_slices(weight_index_slices): |
| 6 | + normalized = [] |
| 7 | + for entry in weight_index_slices: |
| 8 | + index_slice = getattr(entry, "slice_1D", None) |
| 9 | + shape_2d = getattr(entry, "shape_2D", None) |
| 10 | + if index_slice is None or shape_2d is None: |
| 11 | + index_slice, shape_2d = entry |
| 12 | + if isinstance(index_slice, slice): |
| 13 | + index_slice = (index_slice.start, index_slice.stop, index_slice.step) |
| 14 | + else: |
| 15 | + index_slice = tuple(index_slice) |
| 16 | + assert len(index_slice) == 3 |
| 17 | + shape_2d = tuple(shape_2d) |
| 18 | + assert len(shape_2d) == 2 |
| 19 | + normalized.append((index_slice, shape_2d)) |
| 20 | + return normalized |
| 21 | + |
| 22 | + |
5 | 23 | def MuonParamGroups( |
6 | 24 | model: torch.nn.Module, |
7 | 25 | muon: dict, |
@@ -57,13 +75,15 @@ def MuonParamGroups( |
57 | 75 | module_name, _, _ = name.rpartition(".") |
58 | 76 | module = modules[module_name] |
59 | 77 |
|
60 | | - # Attribute from e3nn giving the slices/shapes |
61 | | - # of the corresponding linear weight |
62 | | - index = len(muon_weights) |
63 | | - slices = module.weight_index_slices |
64 | | - |
65 | | - e3nn_reshaping[index] = slices |
| 78 | + # use Muon only when reshape metadata is available |
| 79 | + weight_index_slices = getattr(module, "weight_index_slices", None) |
| 80 | + if weight_index_slices is None: |
| 81 | + adam_weights.append(param) |
| 82 | + continue |
66 | 83 |
|
| 84 | + # store plain tuples to keep optimizer state picklable |
| 85 | + index = len(muon_weights) |
| 86 | + e3nn_reshaping[index] = _normalize_weight_index_slices(weight_index_slices) |
67 | 87 | muon_weights.append(param) |
68 | 88 | continue |
69 | 89 |
|
|
0 commit comments