
Restore from single slice #2084


Draft · wants to merge 11 commits into main

Conversation

@findmyway (Contributor) commented Jul 9, 2025:

  1. Note that I'm still using the old (v0.11.15) version of the slice_devices method, which means it returns the devices from a single slice instead of a single replica.
  2. The basic idea is to keep the original implementation almost unchanged; I added another replica dimension along which to broadcast the data.
  3. I also tried the original idea of simply restoring from a single replica, but I get the InvalidShardingError below. The reason is clear: the devices of a single process are distributed across different replicas (see the sketch after the snippet).

if primary_replica_ids != expected_primary_replica_ids:
  raise InvalidShardingError(
      'The provided sharding is not valid. The primary replica has the'
      f' following devices: {primary_replica_ids}, but process indices'
      ' associated with primary replica devices are expected to be:'
      f' {primary_replica_pids}.'
  )
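
For reference, here is a rough way to see the situation (illustrative only; the (2, 4) mesh shape, the axis names, and the assumption that the replica axis is mesh dimension 0 are mine, this is not Orbax code):

import collections

from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Build a (replica, data) mesh the same way production code does.
devices = mesh_utils.create_device_mesh((2, 4))
mesh = Mesh(devices, axis_names=('replica', 'data'))

# Map each process index to the set of replica rows its devices land in,
# assuming the replica axis is mesh dimension 0.
replicas_per_process = collections.defaultdict(set)
for replica_row in range(mesh.devices.shape[0]):
  for d in mesh.devices[replica_row].flatten():
    replicas_per_process[d.process_index].add(replica_row)

# Any process that appears in more than one replica breaks the
# "one process belongs to one replica" assumption behind the check above.
spanning = {p: rows for p, rows in replicas_per_process.items() if len(rows) > 1}
print('processes whose devices span multiple replicas:', spanning)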

My questions:

  1. Are there any obvious errors or potential improvements in my current implementation?
    • One data point from my latest test: ~30s for deserialization plus ~60s for broadcasting (only one broadcast in total).
  2. Any idea how to address the InvalidShardingError above? (My initial thought is that resharding should still work after the sum op even though there's a mismatch; see the sketch after this list.)
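
To make question 2 concrete, this is roughly the zero-then-sum broadcast pattern I have in mind (a minimal sketch with assumed shapes and axis names, running on 8 devices; it is not the actual broadcast_one_replica_to_all implementation):

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = mesh_utils.create_device_mesh((2, 4))
mesh = Mesh(devices, axis_names=('replica', 'data'))

# Stack the restored values along a leading replica dimension: replica 0
# holds the real data, every other replica holds zeros.
x = jnp.stack([jnp.arange(32.0).reshape(4, 8), jnp.zeros((4, 8))])
x = jax.device_put(x, NamedSharding(mesh, P('replica', 'data', None)))

# Summing over the replica dimension leaves every replica with replica 0's
# values, because all other replicas only contribute zeros.
broadcast = jax.jit(
    lambda a: jnp.sum(a, axis=0),
    out_shardings=NamedSharding(mesh, P('data', None)),
)
y = broadcast(x)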

@cpgaffney1 (Collaborator) left a comment:

> I also tried the original idea of simply restoring from a single replica, but I get the InvalidShardingError. The reason is clear: the devices of a single process are distributed across different replicas.

Devices from one process belonging to different replicas do violate a fundamental assumption. I suppose the validation just needs to be modified to account for this possibility, if we are saying that it is indeed a possibility.

Ideally the unit testing could be improved, but it might be tricky to emulate this situation in a unit test. FWIW, here are some test cases for SingleReplicaArrayHandler - they are private only because they run with an internal TPU-based test harness: https://gist.github.com/cpgaffney1/35161a6e6f6e1bc7bf2ffd3df543efe5
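
Purely as a sketch of one possible relaxation (variable names follow the snippet quoted above; I haven't vetted this against the real SingleReplicaArrayHandler, so treat it as illustrative): instead of requiring exact equality, only require that every primary-replica device is owned by one of the participating processes.

def check_primary_replica_devices(primary_replica_ids, expected_primary_replica_ids):
  # Hypothetical relaxation: a process whose devices also appear in other
  # replicas no longer trips the error, as long as the primary replica's
  # devices are all covered by the participating processes.
  if not set(primary_replica_ids).issubset(expected_primary_replica_ids):
    raise ValueError(  # stand-in for InvalidShardingError
        'The provided sharding is not valid. The primary replica has devices'
        f' {sorted(primary_replica_ids)}, which are not all owned by the'
        f' expected processes (devices: {sorted(expected_primary_replica_ids)}).'
    )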

@@ -224,8 +225,9 @@ def broadcast_one_replica_to_all(
    - pytree with broadcasted data
    - number of broadcasts performed.
  """
  num_replicas = global_mesh.devices.shape[replica_axis_index]
  replica_axis_name = global_mesh.axis_names[replica_axis_index]
  # num_replicas = global_mesh.devices.shape[replica_axis_index]
cpgaffney1 (Collaborator):

I don't quite understand why this was incorrect. Isn't the contract that the replica_axis_index-th dimension of the mesh should be the replica dimension?

cpgaffney1 (Collaborator):

Ah sorry, I realize the intention now. The idea is to always use a single slice to broadcast, even when n_replicas != n_slices?

findmyway (Contributor Author):

Yes
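
For context, the two counts can differ like this (a rough sketch; it assumes TPU devices, which expose slice_index, and that the replica axis is the first mesh dimension):

def count_slices_and_replicas(global_mesh, replica_axis_index=0):
  # Physical slices, from the device attribute (TPU-only in this sketch).
  num_slices = len({d.slice_index for d in global_mesh.devices.flatten()})
  # Logical replicas, from the mesh shape.
  num_replicas = global_mesh.devices.shape[replica_axis_index]
  return num_slices, num_replicas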

  replica_axis_name = global_mesh.axis_names[replica_axis_index]
  # num_replicas = global_mesh.devices.shape[replica_axis_index]
  # replica_axis_name = global_mesh.axis_names[replica_axis_index]
  replica_axis_name = global_mesh.axis_names[0]  # assuming pp dimension is never used
cpgaffney1 (Collaborator):

Unused now?

findmyway (Contributor Author):

Correct

  # Validate merged params.
  if enable_validation:
    await _validate_params(directory, ts_context, use_zarr3=use_zarr3)
  # # Validate merged params.
cpgaffney1 (Collaborator):

Is there also a problem with this check?

findmyway (Contributor Author):

This is unrelated to this PR.

Actually, the check may take a relatively long time (>500s) to finish in some cases, so I disabled it temporarily.

Ideally we'd also have a config option for it.
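
Something like this is what I have in mind (hypothetical names, not an existing Orbax option):

import dataclasses

@dataclasses.dataclass
class MergeOptions:
  # Allow opting out of the post-merge check when it is too slow (>500s here).
  validate_merged_params: bool = True

async def _finalize_merge(directory, ts_context, use_zarr3, options: MergeOptions):
  if options.validate_merged_params:
    # _validate_params is the existing check from the diff above.
    await _validate_params(directory, ts_context, use_zarr3=use_zarr3)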

@findmyway (Contributor Author) commented:
Thanks!

It looks like you are using two processes for testing here.

Could you also add the (2, 4) mesh shape to the test below? Please make sure the mesh is created with jax.experimental.mesh_utils.create_device_mesh so that my assumption can be validated.

https://gist.github.com/cpgaffney1/35161a6e6f6e1bc7bf2ffd3df543efe5#file-type_handlers_test-L291-L295
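
Something along these lines for the mesh construction (hypothetical helper, just to show what I mean):

from jax.experimental import mesh_utils
from jax.sharding import Mesh

def make_mesh(shape=(2, 4), axis_names=('replica', 'data')):
  # Build the mesh the same way production code does, rather than reshaping
  # jax.devices() by hand, so the device-to-replica assignment matches reality.
  devices = mesh_utils.create_device_mesh(shape)
  return Mesh(devices, axis_names=axis_names)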
