Restore from single slice #2084
Conversation
I also tried the original idea of simply restoring from a single replica, but I get an InvalidShardingError. The reason is obvious: the devices of a process are distributed across different replicas.
Devices from one process belonging to different replicas do violate a fundamental assumption. I suppose the validation just needs to be modified to account for this possibility, if we are saying that it is indeed a possibility.
Ideally unit testing could be improved, but it might be tricky to emulate this situation in a unit test. FWIW, here are some test cases for SingleReplicaArrayHandler; it's private only because it runs with an internal TPU-based test harness. https://gist.github.com/cpgaffney1/35161a6e6f6e1bc7bf2ffd3df543efe5
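To make the failure mode above concrete, here is a self-contained sketch (not Orbax code; the device/process layout is made up) of how a process's devices can either land in one replica or be split across replicas, depending on how the device grid maps onto the replica axis. The split case is what trips the single-replica validation.

```python
import numpy as np

# Hypothetical layout: 8 devices, 2 processes, 4 devices per process.
devices = np.arange(8).reshape(2, 4)   # replica axis first: 2 replicas of 4
process_of = lambda d: d // 4          # process 0 owns devices 0-3

replica0 = devices[0]                  # devices [0, 1, 2, 3]
# Each process maps cleanly onto one replica here:
print(sorted({process_of(d) for d in replica0}))      # [0]

# With a different layout (replica axis last), process 0's devices
# are split across both replicas -- the InvalidShardingError case:
devices_t = np.arange(8).reshape(4, 2)
replica0_t = devices_t[:, 0]           # devices [0, 2, 4, 6]
print(sorted({process_of(d) for d in replica0_t}))    # [0, 1]
```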
@@ -224,8 +225,9 @@ def broadcast_one_replica_to_all(
      pytree with broadcasted data
      number of broadcasts performed.
    """
    num_replicas = global_mesh.devices.shape[replica_axis_index]
    replica_axis_name = global_mesh.axis_names[replica_axis_index]
    # num_replicas = global_mesh.devices.shape[replica_axis_index]
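For reference, this is what the two lookups being commented out compute. A minimal sketch with a stand-in mesh (FakeMesh is hypothetical, not a JAX or Orbax class): the mesh is a named, shaped grid of devices, and replica_axis_index picks which named axis counts as the replica dimension.

```python
import numpy as np

class FakeMesh:
    """Stand-in for jax.sharding.Mesh: a shaped device grid with axis names."""
    def __init__(self, shape, axis_names):
        self.devices = np.arange(np.prod(shape)).reshape(shape)
        self.axis_names = axis_names

global_mesh = FakeMesh((4, 2), ("replica", "data"))
replica_axis_index = 0

# The two lines under discussion in the hunk above:
num_replicas = global_mesh.devices.shape[replica_axis_index]
replica_axis_name = global_mesh.axis_names[replica_axis_index]
print(num_replicas, replica_axis_name)  # 4 replica
```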
I don't quite understand why this was incorrect. Isn't the contract that the replica_axis_index-th dimension of the mesh should be the replica dimension?
Ah sorry, I realized the intention here. The idea is to always use a single slice to broadcast, even when n_replicas != n_slices?
Yes
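The distinction in this exchange can be made concrete with made-up numbers (these are illustrative only): a replica is a group of devices holding a full copy of the parameters, while a slice is a hardware unit, so the two counts need not match.

```python
# Hypothetical topology: 16 devices total.
n_devices = 16
devices_per_replica = 8   # a full parameter copy spans 8 devices -> 2 replicas
devices_per_slice = 4     # hardware slices of 4 devices each     -> 4 slices

n_replicas = n_devices // devices_per_replica
n_slices = n_devices // devices_per_slice

# n_replicas != n_slices, which is exactly the case where broadcasting
# from a single slice differs from broadcasting from a single replica.
print(n_replicas != n_slices, n_replicas, n_slices)  # True 2 4
```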
    replica_axis_name = global_mesh.axis_names[replica_axis_index]
    # num_replicas = global_mesh.devices.shape[replica_axis_index]
    # replica_axis_name = global_mesh.axis_names[replica_axis_index]
    replica_axis_name = global_mesh.axis_names[0]  # assuming pp dimension is never used
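The "assuming pp dimension is never used" caveat matters: hardcoding axis_names[0] only agrees with the replica_axis_index lookup when the first mesh axis really is the slice/replica axis. A tiny sketch (axis names are hypothetical) of a mesh where the assumption breaks:

```python
# A mesh that DOES have a leading pipeline-parallel ("pp") axis.
axis_names = ("pp", "replica", "data")
replica_axis_index = 1

# Hardcoded axis 0 no longer matches the configured replica axis:
print(axis_names[0] == axis_names[replica_axis_index])  # False
```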
Unused now?
Correct
    # Validate merged params.
    if enable_validation:
      await _validate_params(directory, ts_context, use_zarr3=use_zarr3)
    # # Validate merged params.
Is there also a problem with this check?
This is unrelated to this PR.
Actually, it may take a relatively long time (>500s) in some cases to finish the check, so I disabled it temporarily.
Ideally we'd also have a config option for it.
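To sketch the "config option" idea floated above (the option and function names here are made up, not Orbax's API): the expensive merged-params validation would be gated behind an opt-in flag instead of being commented out.

```python
import asyncio

async def _validate_params(directory):
    # Placeholder for the real, potentially >500s validation pass.
    pass

async def finalize(directory, *, enable_validation: bool = False):
    """Finalize a checkpoint, optionally running the slow validation."""
    if enable_validation:
        await _validate_params(directory)
    return "finalized"

# Default keeps the fast path; callers who want the check opt in explicitly.
print(asyncio.run(finalize("/tmp/ckpt")))  # finalized
```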
Thanks! Looks like you are using two processes for testing here. Could you also add the ...

This change modifies the slice_devices method (meaning it will return the devices from a single slice instead of a single replica) rather than using replicas to broadcast data. With the original approach I get an InvalidShardingError. The reason is obvious: the devices of a process are distributed across different replicas.

orbax/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py, lines 1468 to 1474 in 9fc3716

My questions: why the InvalidShardingError? (My initial thought is that the resharding should still work after the sum op, even though there's a mismatch.)
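The "sum op" broadcast referred to above can be illustrated with a small sketch (not Orbax's implementation; plain NumPy stands in for the cross-replica all-reduce): every replica contributes zeros except the source replica, so summing across replicas hands the source replica's values to everyone.

```python
import numpy as np

# Per-replica shards: only replica 0 holds real data, the rest hold zeros.
source_replica = 0
replica_values = [np.array([1.0, 2.0]), np.zeros(2), np.zeros(2)]

# Stand-in for an all-reduce sum across the replica axis: because all
# non-source contributions are zero, the sum equals the source's data,
# which is exactly a broadcast.
broadcast = sum(replica_values)
print(broadcast.tolist())  # [1.0, 2.0]
```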