
Integrate Orbax's emergency checkpoint. #820

Open
hanzhi713 wants to merge 11 commits into main from in-mem-ckpt

Conversation

hanzhi713
Member

No description provided.

@hanzhi713 force-pushed the in-mem-ckpt branch 2 times, most recently from b4a00eb to 0e74dc8 on November 12, 2024 19:49
@hanzhi713 force-pushed the in-mem-ckpt branch 2 times, most recently from cf43485 to 56f51de on November 19, 2024 00:37
@hanzhi713 force-pushed the in-mem-ckpt branch 2 times, most recently from d5a3e0f to bea8b71 on January 8, 2025 22:45
@hanzhi713 marked this pull request as ready for review January 30, 2025 23:13
@hanzhi713 requested review from ruomingp, markblee and a team as code owners January 30, 2025 23:13
del os.environ["JAX_PLATFORMS"]


class OrbaxEmergencyCheckpointer(BaseCheckpointer):
Contributor

How does this overlap with `class OrbaxCheckpointer(BaseCheckpointer):`?

Member Author

The regular Orbax checkpointer saves one GCS checkpoint for n slices per save. This checkpointer saves n-1 checkpoints to a local path (usually a ramdisk), plus one checkpoint to GCS.
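
For intuition, a minimal sketch of the save fan-out described above (illustrative arithmetic only; an assumption about how saves are distributed, not the PR's actual scheduling code):

    # Illustrative only: with n data-parallel slices, each save yields one
    # persistent GCS checkpoint plus n - 1 local (e.g. ramdisk) copies that
    # other slices can restore from after a slice restart.
    num_slices = 4
    persistent_ckpts_per_save = 1             # written to GCS
    local_ckpts_per_save = num_slices - 1     # kept on local ramdisks
    assert persistent_ckpts_per_save + local_ckpts_per_save == num_slices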

Member Author

Unfortunately it's not possible to share code between the two implementations.

Contributor

Thanks for the clarification.

  • How should users use them? Should they use both of them or one of them? How should they pick?
  • Or can we replace OrbaxCheckpointer with this class?

Member Author

Added a comment to clarify this:

    This checkpointer is designed to improve the goodput of large multi-slice training jobs that
    use data-parallelism across slices. At least two data-parallel slices are required. For other
    use cases where this is not applicable or ultimate goodput is not required, please use
    `OrbaxCheckpointer`.

Member Author

Not just data-parallelism, but data-parallel slices, i.e. it has to be multi-slice training.

Contributor

IIUC, only the local checkpoints require multiple slices, since we will need to restore from another slice upon a slice restart. Could we disable local checkpoints when num_slices=1? This way we always use the emergency checkpointer consistently.

Member Author

Could be an idea.

Member Author

Still needs support from orbax though.

Contributor

I see. Right now it assumes local checkpoints are always used: https://github.com/google/orbax/blob/6e80ecc27581a413b1a481d4740e61df7316a4f4/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py#L695-L709.

Could you raise this request to the Orbax team and link to the issue?

@ruomingp self-assigned this Jan 31, 2025
@hanzhi713 requested a review from ruomingp January 31, 2025 21:35
@ruomingp removed their request for review February 1, 2025 02:51
Contributor

@ruomingp left a comment

Thanks for the explanation of the constraints. I wonder what the long term plan is.

Is the emergency checkpointer a temporary solution that will eventually be dropped when the main Orbax checkpointer supports in-memory checkpoints?

Or will we keep maintaining two separate checkpointers, with potentially incompatible ckpt layouts?

@hanzhi713
Member Author

Is the emergency checkpointer a temporary solution that will eventually be dropped when the main Orbax checkpointer supports in-memory checkpoints?

I don't know if Google has such a plan. The Orbax in-memory checkpointer actually uses the regular Orbax checkpointer under the hood, which might be required by design or by the nature of the problem it solves.

Or will we keep maintaining two separate checkpointers, with potentially incompatible ckpt layouts?

Since the in-memory checkpointer uses the regular Orbax checkpointer under the hood, the tensor state in the persistent checkpoint (i.e. the one stored to GCS) can be loaded by OrbaxStateBuilder (see #866). Therefore, we can say that the checkpoints are compatible for eval and inference purposes. It's just that the training checkpoint will be incompatible, meaning that OrbaxEmergencyCheckpointer's checkpoint cannot be loaded by OrbaxCheckpointer.

@hanzhi713
Member Author

I think in the long term, it's probably possible to unify the checkpoint structure between the two checkpointers (regular and in-memory), but it's unknown whether we can unify the codepath.


Comment on lines 744 to 748
# Note that save() waits for prior serialization to finish.
self._non_tensor_manager.save(step=step, state=state)
self._get_tensor_manager(state_with_tensors).save(
    step=step, args=ocp.args.PyTreeSave(item=state_with_tensors)
)
Contributor

How do we mark the completion of a checkpoint? It should happen only when both tensor and non-tensor states are saved. How is this ensured?

Please add a comment.

Member Author

There's no special marker for the completion of both; there are only markers for each of them individually. So, during restore, we look for both of them and only load a specific step when both markers exist.
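
For illustration, a minimal sketch of that restore-side check (assuming hypothetical tensor_mgr / non_tensor_mgr objects with an Orbax-style all_steps() method; not the PR's actual code):

    def latest_common_step(tensor_mgr, non_tensor_mgr):
        """Returns the latest step committed by both managers, or None."""
        # A step is only restorable when both per-manager completion
        # markers exist, i.e. both managers report the step as saved.
        common = set(tensor_mgr.all_steps()) & set(non_tensor_mgr.all_steps())
        return max(common) if common else None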

Contributor

Thanks. Can you add a comment here and point to where "we look for both of them and only load a specific step when both markers exist"?

Do we have testing for incomplete checkpoints?

Member Author

I can add a test case.
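
A hypothetical shape for such a test, reusing the latest_common_step sketch above (names are illustrative, not the PR's actual test):

    class FakeManager:
        """Minimal stand-in exposing all_steps(), for this sketch only."""

        def __init__(self, steps):
            self._steps = steps

        def all_steps(self):
            return list(self._steps)

    def test_incomplete_checkpoint_is_skipped():
        # Step 10 was committed by both managers; step 20 only by the tensor
        # manager (e.g. the job died before the non-tensor save finished),
        # so restore should fall back to step 10.
        tensor_mgr = FakeManager(steps=[10, 20])
        non_tensor_mgr = FakeManager(steps=[10])
        assert latest_common_step(tensor_mgr, non_tensor_mgr) == 10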

@hanzhi713
Member Author

@ruomingp I guess my question now is what's the plan here. Should we wait for Orbax's support for non-tensor states and a unified checkpointer API?

I personally don't see in-mem ckpt as a life-changing feature, so waiting could be a viable option. Alternatively, we can proceed with this PR and make changes later.

@ruomingp
Contributor

ruomingp commented Feb 2, 2025

@ruomingp I guess my question now is what's the plan here. Should we wait for Orbax's support for non-tensor states and a unified checkpointer API?

I personally don't see in-mem ckpt as a life-changing feature, so waiting could be a viable option. Alternatively, we can proceed with this PR and make changes later.

There's no best solution. I see three possibilities:

  1. Wait until Orbax supports a Checkpointer that works for both single-slice and multi-slice settings with a unified persistent layout;
  2. Merge this PR for testing, but be willing to discard checkpoints if we decide to change the persistent layout later;
  3. Merge this PR and commit to unify the checkpointers and build tools to convert checkpoints for users.

I do not think we want to maintain two Orbax checkpointers in the longer run, especially with incompatible layouts.

WDYT?


global_mesh=thread_resources.env.physical_mesh,
abstract_state=self._get_abstract_state(state_with_tensors),
options=oecp.CheckpointManagerOptions(
    local=oecp.LocalCheckpointOptions(
Contributor

Shall we expose oecp.LocalCheckpointOptions to users as Config.local_checkpoint_options? Users could then set it to None to disable local checkpoints.

We can provide a helper function for users to construct should_save_fn from their policy.

Member Author

It's already exposed via local_keep_last_n and local_save_policy.
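
A hedged usage sketch, assuming the field names from the comment above (local_keep_last_n, local_save_policy) and AXLearn-style config helpers (an assumption based on the class names in this PR); exact names and signatures may differ:

    # Hypothetical configuration sketch, not the PR's verified API.
    from axlearn.common.checkpointer import every_n_steps_policy
    from axlearn.common.config import config_for_function

    cfg = OrbaxEmergencyCheckpointer.default_config()
    # Keep only the most recent local (ramdisk) checkpoint.
    cfg.local_keep_last_n = 1
    # Save a local checkpoint every 100 steps.
    cfg.local_save_policy = config_for_function(every_n_steps_policy).set(n=100)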

Comment on lines 770 to 772
# Find the intersection of the checkpoint steps managed by tensor and non-tensor
# manager, and then use the latest step in the intersection for restore. `all_steps`
# from tensor manager contains both local and persistent checkpoints.
Contributor

Consider refactoring this logic to a separate function so that it can be tested directly?

Member Author

Done and added a test.


@kelvin-zou
Contributor

@ruomingp I guess my question now is what's the plan here. Should we wait for Orbax's support for non-tensor states and a unified checkpointer API?
I personally don't see in-mem ckpt as a life-changing feature, so waiting could be a viable option. Alternatively, we can proceed with this PR and make changes later.

There's no best solution. I see three possibilities: (1) Wait until Orbax supports a Checkpointer that works for both single-slice and multi-slice settings with a unified persistent layout; (2) Merge this PR for testing, but be willing to discard checkpoints if we decide to change the persistent layout later; (3) Merge this PR and commit to unify the checkpointers and build tools to convert checkpoints for users.

I do not think we want to maintain two Orbax checkpointers in the longer run, especially with incompatible layouts.

WDYT?

I think waiting may not be the best idea: Orbax development is often delayed, and it has been over a year since Anthropic last raised this in-mem checkpointing feature request.

I think both options below work for me:

  1. Merge this PR for testing, but be willing to discard checkpoints if we decide to change the persistent layout later.
  2. Merge this PR and commit to unify the checkpointers and build tools to convert checkpoints for users.

Maybe let's merge it for testing (we need quite some time to test thoroughly), then wait a quarter and see if Orbax can really make it happen and make it available to our users. If so, we can discard the tf iterator handling and use the upstream unified one. If not, we can commit to unification ourselves.

@hanzhi713
Member Author

@kelvin-zou SGTM.

@hanzhi713 requested a review from ruomingp February 3, 2025 18:32