diff --git a/docs/concepts/sharding.md b/docs/concepts/sharding.md index d146e4085..2fe56d3af 100644 --- a/docs/concepts/sharding.md +++ b/docs/concepts/sharding.md @@ -88,3 +88,14 @@ dcn_pipeline_parallelism: 1 dcn_expert_parallelism: 1 dcn_autoregressive_parallelism: 1 # never recommended ``` + +## Sharding implementation details + +You may think of the sharding in the maxtext codebase as split into three levels +1. The physical mesh where e.g. `ici_fsdp_parallelism` is used - see [`create_device_mesh`](https://github.com/AI-Hypercomputer/maxtext/blob/e7c4824ee9cc13fd6db863796bbe7696b03eb448/MaxText/max_utils.py#L363) +2. The logical names, with physical <-> logical mappings [here](https://github.com/AI-Hypercomputer/maxtext/blob/e7c4824ee9cc13fd6db863796bbe7696b03eb448/MaxText/configs/base.yml#L211-L248) +3. Individual tensors which will use logical names, here is one [example](https://github.com/AI-Hypercomputer/maxtext/blob/e7c4824ee9cc13fd6db863796bbe7696b03eb448/MaxText/layers/linears.py#L243) + +Following this example we see the first axis is sharded by logical name "embed". This logical name maps the physical names "fsdp, fsdp_transpose, sequence, expert", thus this axes will get sharded by the product of these specified parallelisms. E.g. if `ici_fsdp_parallelism=4` and `ici_sequence_parallelism=2` then this array axis will get sharded 8 ways. + +This example showed a "kernel_axes" which is used to define a weight matrix. For activations we use shardings hints for the compiler such as `nn.with_logical_constraint` (example [here](https://github.com/AI-Hypercomputer/maxtext/blob/e7c4824ee9cc13fd6db863796bbe7696b03eb448/MaxText/layers/linears.py#L261)). This will generally shard the activations according to these constraints, but the compiler occasionally chooses a different sharding other that what we specified for these activations.