Integrate Orbax's emergency checkpoint. #820
base: main
Conversation
del os.environ["JAX_PLATFORMS"]

class OrbaxEmergencyCheckpointer(BaseCheckpointer):
How does this overlap with OrbaxCheckpointer?

axlearn/axlearn/common/checkpointer_orbax.py, line 169 in 140a18f:
class OrbaxCheckpointer(BaseCheckpointer):
The regular Orbax checkpointer saves 1 GCS checkpoint for n slices per save. This checkpointer saves n-1 checkpoints to a local path (usually a ramdisk), and also 1 checkpoint to GCS.
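For context, a minimal sketch of how the two save cadences are configured with Orbax's emergency checkpoint manager. Only CheckpointManagerOptions and local= appear in this diff; the persistent= option and the field names (save_interval_steps, max_to_keep) are assumptions about the Orbax API, not this PR's code.

import orbax.checkpoint.experimental.emergency.checkpoint_manager as oecp

# Sketch only. Local checkpoints go to a node-local path (usually a ramdisk) and are
# saved frequently; the persistent checkpoint goes to GCS and is saved less often.
# The persistent= option and the exact field names are assumptions.
options = oecp.CheckpointManagerOptions(
    local=oecp.LocalCheckpointOptions(
        save_interval_steps=100,  # Frequent, cheap saves to the ramdisk.
        max_to_keep=2,
    ),
    persistent=oecp.PersistentCheckpointOptions(
        save_interval_steps=1000,  # One checkpoint per save goes to GCS.
        max_to_keep=4,
    ),
)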
Unfortunately it's not possible to share code between the two implementations.
Thanks for the clarification.
- How should users use them? Should they use both of them or one of them? How should they pick?
- Or can we replace OrbaxCheckpointer with this class?
Added a comment to clarify this:
This checkpointer is designed to improve the goodput of large multi-slice training jobs that
use data-parallelism across slices. At least two data-parallel slices are required. For other
use cases where this is not applicable or ultimate goodput is not required, please use
`OrbaxCheckpointer`.
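A hedged sketch of how a user might pick between the two checkpointers per this guidance; the emergency checkpointer's module path and the use of default_config() are assumptions for illustration, not the PR's actual wiring.

from axlearn.common.checkpointer_orbax import OrbaxCheckpointer
# Assumed module path, for illustration only.
from axlearn.common.checkpointer_orbax_emergency import OrbaxEmergencyCheckpointer


def checkpointer_config(num_slices: int):
    # Multi-slice jobs with data-parallelism across slices benefit from local
    # (ramdisk) checkpoints plus one persistent GCS checkpoint.
    if num_slices >= 2:
        return OrbaxEmergencyCheckpointer.default_config()
    # Otherwise fall back to the regular GCS-only checkpointer.
    return OrbaxCheckpointer.default_config()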
Not just data-parallelism, but data-parallel slices, i.e. it has to be multi-slice training.
IIUC, only the local checkpoints require multiple slices, since we will need to restore from another slice upon a slice restart. Could we disable local checkpoints when num_slices=1? This way we always use the emergency checkpointer consistently.
Could be an idea.
Still needs support from Orbax though.
I see. Right now it assumes local checkpoints are always used: https://github.com/google/orbax/blob/6e80ecc27581a413b1a481d4740e61df7316a4f4/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py#L695-L709.
Could you raise this request to the Orbax team and link to the issue?
Thanks for the explanation of the constraints. I wonder what the long-term plan is.
Is the emergency checkpointer a temporary solution that will eventually be dropped when the main Orbax checkpointer supports in-memory checkpoints?
Or will we keep maintaining two separate checkpointers, with potentially incompatible ckpt layouts?
I don't know if Google has such a plan. The Orbax in-memory checkpointer actually uses the regular Orbax checkpointer under the hood, which might be required by design or by the nature of the problem it solves.
Since the in-memory checkpointer uses the regular Orbax checkpointer under the hood, the tensor state in the persistent checkpoint (i.e. the one stored to GCS) can be loaded by the regular checkpointer.
I think in the long term, it's probably possible to unify the checkpoint structure between the two checkpointers (regular and in-memory), but it's unknown whether we can unify the codepath.
# Note that save() waits for prior serialization to finish.
self._non_tensor_manager.save(step=step, state=state)
self._get_tensor_manager(state_with_tensors).save(
    step=step, args=ocp.args.PyTreeSave(item=state_with_tensors)
)
How do we mark the completion of a checkpoint? It should happen only when both tensor and non-tensor states are saved. How is this ensured?
Please add a comment.
There's no special marker for the completion of both; there are only markers for each of them individually. So, during restore, we look for both of them and only load a specific step when both markers exist.
Thanks. Can you add a comment here and point to where "we look for both of them and only load a specific step when both markers exist"?
Do we have testing for incomplete checkpoints?
I can add a testcase.
@ruomingp I guess my question now is: what's the plan here? Should we wait for Orbax's support for non-tensor states and a unified checkpointer API? I personally don't see in-memory checkpointing as a life-changing feature, so waiting could be a viable option. Alternatively, we can proceed with this PR and make changes later.
There's no best solution. I see three possibilities: I do not think we want to maintain two Orbax checkpointers in the longer run, especially with incompatible layouts. WDYT?
global_mesh=thread_resources.env.physical_mesh,
abstract_state=self._get_abstract_state(state_with_tensors),
options=oecp.CheckpointManagerOptions(
    local=oecp.LocalCheckpointOptions(
Shall we expose oecp.LocalCheckpointOptions to users as Config.local_checkpoint_options? Users can set it to None to disable local checkpoints. We can provide a helper function for users to construct should_save_fn from their policy.
It's already exposed, via local_keep_last_n and local_save_policy.
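For illustration, a sketch of the kind of adapter suggested above, turning a save policy into a should_save_fn. Both callable signatures here are assumptions (an axlearn-style policy(step=..., evaler_summaries=...) and an Orbax-style (step, last_saved_step) callback), not the PR's actual code.

from typing import Callable, Optional


def policy_to_should_save_fn(
    policy: Callable[..., bool],
) -> Callable[[int, Optional[int]], bool]:
    # Sketch only: signatures are assumed, see the note above.
    def should_save_fn(step: int, last_saved_step: Optional[int] = None) -> bool:
        # The policy alone decides whether this step should be saved.
        del last_saved_step
        return policy(step=step, evaler_summaries={})

    return should_save_fn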
# Find the intersection of the checkpoint steps managed by tensor and non-tensor
# manager, and then use the latest step in the intersection for restore. `all_steps`
# from tensor manager contains both local and persistent checkpoints.
Consider refactoring this logic to a separate function so that it can be tested directly?
Done and added a test.
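Roughly, the refactored helper could look like the sketch below; the function and test names here are hypothetical illustrations, not the ones added in the PR.

from typing import Optional, Sequence


def latest_common_step(
    tensor_steps: Sequence[int], non_tensor_steps: Sequence[int]
) -> Optional[int]:
    # A step is restorable only if both the tensor and the non-tensor checkpoint exist.
    common = set(tensor_steps) & set(non_tensor_steps)
    return max(common) if common else None


def test_latest_common_step():
    assert latest_common_step([100, 200, 300], [100, 200]) == 200
    assert latest_common_step([100], []) is None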
I think waiting may not be the best idea, given that Orbax development is often delayed; it has been over a year since Anthropic last raised this in-memory checkpointing feature request. Both options below work for me.
Maybe let's merge it for testing (we need quite some time to test thoroughly), and then wait for a quarter to see if Orbax can really make it happen and make it available to our users? If so, we can discard the tf iterator handling and use the upstream unified one. If not, then we can commit to unification ourselves?
@kelvin-zou SGTM.