Skip to content

Commit c1832a3

Browse files
yashk2810Flax Authors
authored andcommitted
Canonicalize PartitionSpec so that we can delete ParsedPartitionSpec
* `_partitions` is now canonicalized and only contains `tuples`, `None` or `UNCONSTRAINED`. * Cache the creating of sharding on ShapedArray since it's expensive to do it a lot of times * Change the `__hash__` and `__eq__` of `NamedSharding` to depend on `self.spec` instead of `self._parsed_pspec`. PiperOrigin-RevId: 730436599
1 parent 5413850 commit c1832a3

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

flax/linen/spmd.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,14 @@ def _logical_to_mesh_axes(
105105
# We assign mesh axes using a priority based ruleset over logical axis names.
106106
result: list[_UnassignedAxis | None | str | tuple[str, ...]]
107107
result = [
108-
(_unassigned_axis if isinstance(name, str) else name)
108+
(
109+
_unassigned_axis
110+
if (
111+
isinstance(name, str)
112+
or (isinstance(name, tuple) and len(name) == 1)
113+
)
114+
else name
115+
)
109116
for name in array_dim_names
110117
]
111118
for rule_model_name, rule_mesh_names in rules:

0 commit comments

Comments
 (0)