megatron/core/dist_checkpointing/strategies: 1 file changed, +6 -2 lines changed

@@ -221,6 +221,7 @@ def sharded_tensor_to_torch_sharded_tensor(
         ]

     # Create a ShardedTensor without invoking communication. Determine global shards
+    world_size = torch.distributed.get_world_size()
     shard_metadata = []
     # NOTE: here we assume a regular grid of shards
     for fragment_offsets in itertools.product(*map(range, some_sh_ten.axis_fragmentations)):
@@ -244,13 +245,16 @@ def sharded_tensor_to_torch_sharded_tensor(

         else:
             # for shards from other ranks we provide simplistic data - this information will be discarded
-            # during TorchShardedTensor._init_from_local_shards_and_global_metadata call
+            # during TorchShardedTensor._init_from_local_shards_and_global_metadata call.
+            # Due to a bug in PyT 24.05 container we must specify some concrete rank within a world size.
+            # The exact rank doesn't matter as long as it's different than my rank - hence (rank + 1) % WS.
+            placement = f"rank:{(rank + 1) % world_size}/cuda"
             if has_flattened_range and not is_flattened_range_1d:
                 offset = offset + (0,)
                 size = (1,) * len(offsets_shape) + global_shape[-1:]
             else:
                 size = offsets_shape
-            shard_metadata.append(ShardMetadata(offset, size, "cuda"))
+            shard_metadata.append(ShardMetadata(offset, size, placement))

     tensor = some_sh_ten.data
     sharded_tensor_metadata = ShardedTensorMetadata(
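
For context on the change: a ShardMetadata placement can be given as a rank-qualified device string of the form "rank:<global rank>/<device>". The patch replaces the bare "cuda" placeholder used for non-local shards with such a rank-qualified string, so that (per the added comment) the placement names a concrete rank inside the world size, as the PyTorch build in the 24.05 container requires. Below is a minimal illustrative sketch, not part of the patch; the helper name remote_shard_placement is hypothetical, and the offsets/sizes are arbitrary example values.

from torch.distributed._shard.metadata import ShardMetadata

def remote_shard_placement(rank: int, world_size: int) -> str:
    # Pick any concrete rank different from the local one; (rank + 1) % world_size
    # is the simplest choice that always stays within [0, world_size).
    return f"rank:{(rank + 1) % world_size}/cuda"

# Example: on rank 0 of a 4-process job the placeholder points at rank 1.
rank, world_size = 0, 4
placement = remote_shard_placement(rank, world_size)  # -> "rank:1/cuda"
meta = ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4, 8], placement=placement)
print(meta.placement)  # rank:1/cuda

The exact remote rank chosen is irrelevant because, as the diff comment states, metadata for non-local shards is discarded during the TorchShardedTensor._init_from_local_shards_and_global_metadata call.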