
Commit 203b463

Merge branch 'mlaz/fix-24.05-pyt-dist' into 'main'

PyT Dist fix for 24.05 container

See merge request ADLR/megatron-lm!1823

2 parents: 1bb6337 + 58a8a62

File tree: 1 file changed (+6, −2)

  • megatron/core/dist_checkpointing/strategies/torch.py


megatron/core/dist_checkpointing/strategies/torch.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -221,6 +221,7 @@ def sharded_tensor_to_torch_sharded_tensor(
     ]
 
     # Create a ShardedTensor without invoking communication. Determine global shards
+    world_size = torch.distributed.get_world_size()
     shard_metadata = []
     # NOTE: here we assume a regular grid of shards
     for fragment_offsets in itertools.product(*map(range, some_sh_ten.axis_fragmentations)):
@@ -244,13 +245,16 @@ def sharded_tensor_to_torch_sharded_tensor(
 
         else:
             # for shards from other ranks we provide simplistic data - this information will be discarded
-            # during TorchShardedTensor._init_from_local_shards_and_global_metadata call
+            # during TorchShardedTensor._init_from_local_shards_and_global_metadata call.
+            # Due to a bug in PyT 24.05 container we must specify some concrete rank within a world size.
+            # The exact rank doesn't matter as long as it's different than my rank - hence (rank + 1) % WS.
+            placement = f"rank:{(rank + 1) % world_size}/cuda"
             if has_flattened_range and not is_flattened_range_1d:
                 offset = offset + (0,)
                 size = (1,) * len(offsets_shape) + global_shape[-1:]
             else:
                 size = offsets_shape
-            shard_metadata.append(ShardMetadata(offset, size, "cuda"))
+            shard_metadata.append(ShardMetadata(offset, size, placement))
 
     tensor = some_sh_ten.data
     sharded_tensor_metadata = ShardedTensorMetadata(
```
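The gist of the change: instead of passing the bare device string "cuda" as the placement for shards owned by other ranks, the metadata now names a concrete rank inside the world size. Below is a minimal, standalone sketch of that placement-string pattern, not the Megatron-LM code itself; the `rank` and `world_size` values are hardcoded stand-ins for `torch.distributed.get_rank()` and `torch.distributed.get_world_size()`, and the shard offsets/sizes are hypothetical.

```python
# Minimal sketch of the placement pattern used in the commit (assumptions:
# hardcoded rank/world_size instead of torch.distributed calls, made-up shard geometry).
from torch.distributed._shard.metadata import ShardMetadata

rank = 3        # stand-in for torch.distributed.get_rank()
world_size = 4  # stand-in for torch.distributed.get_world_size()

# A shard held locally is placed on this rank's GPU.
local_placement = f"rank:{rank}/cuda"

# A shard held by another rank: its metadata is discarded later, but the 24.05
# container's PyTorch requires a concrete rank within the world size, so any rank
# other than ours works -- (rank + 1) % world_size is always in range and never ours.
remote_placement = f"rank:{(rank + 1) % world_size}/cuda"

remote_shard = ShardMetadata(
    shard_offsets=[0, 0],   # hypothetical offsets of a shard owned by another rank
    shard_sizes=[16, 16],   # hypothetical shard size
    placement=remote_placement,
)
print(local_placement)   # "rank:3/cuda"
print(remote_placement)  # "rank:0/cuda" for rank=3, world_size=4
```

As the diff comments note, the exact placeholder rank has no effect beyond satisfying the stricter validation: `TorchShardedTensor._init_from_local_shards_and_global_metadata` discards this metadata for non-local shards, so any in-range rank other than the local one passes the check.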
