Skip to content

Commit 85bc8bb

Browse files
angel-core Orbax Authors
authored and committed
No public description
PiperOrigin-RevId: 889877635
1 parent f643ea0 commit 85bc8bb

File tree

9 files changed

+255
-101
lines changed

9 files changed

+255
-101
lines changed

checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ class PyTreeOptions:
186186
187187
Example:
188188
To save certain leaves in float16, while others in float32, we can use
189-
`create_array_storage_options_fn` like so::
189+
`scoped_storage_options_creator` like so::
190190
191191
import jax
192192
import jax.numpy as jnp
@@ -197,11 +197,14 @@ def create_opts_fn(keypath, value):
197197
return ocp_options.ArrayOptions.Saving.StorageOptions(
198198
dtype=jnp.float16
199199
)
200-
return ocp_options.ArrayOptions.Saving.StorageOptions(dtype=jnp.float32)
201-
202-
pytree_options = ocp_options.PyTreeOptions(
203-
saving=ocp_options.PyTreeOptions.Saving(
204-
create_array_storage_options_fn=create_opts_fn
200+
return None # Fall back to global `storage_options`
201+
202+
array_options = ocp_options.ArrayOptions(
203+
saving=ocp_options.ArrayOptions.Saving(
204+
storage_options=ocp_options.ArrayOptions.Saving.StorageOptions(
205+
dtype=jnp.float32
206+
),
207+
scoped_storage_options_creator=create_opts_fn
205208
)
206209
)
207210
@@ -216,25 +219,9 @@ def create_opts_fn(keypath, value):
216219
class Saving:
217220
"""Options for saving PyTrees.
218221
219-
create_array_storage_options_fn:
220-
A function that is called in order to create
221-
:py:class:`.ArrayOptions.Saving.StorageOptions` for each leaf in a PyTree,
222-
when it is
223-
being saved. It is called similar to:
224-
`jax.tree.map_with_path(create_array_storage_options_fn, pytree_to_save)`.
225-
If provided, it overrides any default settings in
226-
:py:class:`.ArrayOptions.Saving.StorageOptions`.
227222
pytree_metadata_options: Options for managing PyTree metadata.
228223
"""
229224

230-
class CreateArrayStorageOptionsFn(Protocol):
231-
232-
def __call__(
233-
self, key: tree_types.PyTreeKeyPath, value: Any
234-
) -> ArrayOptions.Saving.StorageOptions:
235-
...
236-
237-
create_array_storage_options_fn: CreateArrayStorageOptionsFn | None = None
238225
pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = (
239226
dataclasses.field(default_factory=tree_metadata.PyTreeMetadataOptions)
240227
)
@@ -322,8 +309,24 @@ class Saving:
322309
True.
323310
array_metadata_store: Store to manage per host ArrayMetadata. To disable
324311
ArrayMetadata persistence, set it to None.
312+
storage_options: Global default for array storage options.
313+
scoped_storage_options_creator: A function that, when dealing with
314+
PyTrees, is applied to every leaf. If it returns an
315+
:py:class:`ArrayOptions.Saving.StorageOptions`, its fields take
316+
precedence when merging if they are set to non-None or non-default
317+
values with respect to `storage_options`. If it returns `None`,
318+
`storage_options` is used as a default for all fields. It is called
319+
similar to: `jax.tree.map_with_path(scoped_storage_options_creator,
320+
pytree_to_save)`.
325321
"""
326322

323+
class ScopedStorageOptionsCreator(Protocol):
  """Callable that yields per-leaf storage options when saving a PyTree.

  The callable is invoked once per leaf with the leaf's keypath and value,
  similar to `jax.tree.map_with_path(creator, pytree_to_save)`.

  It may return `None` to signal that the global `storage_options` should be
  used unchanged for that leaf; otherwise the returned options are merged on
  top of the global `storage_options`.
  """

  def __call__(
      self, key: tree_types.PyTreeKeyPath, value: Any
  ) -> ArrayOptions.Saving.StorageOptions | None:
    # `| None` matters: returning None is the documented way to fall back to
    # the global defaults, so the annotation must permit it.
    ...
329+
327330
@dataclasses.dataclass(frozen=True, kw_only=True)
328331
class StorageOptions:
329332
"""Options used to customize array storage behavior for individual leaves.
@@ -367,6 +370,7 @@ class StorageOptions:
367370
array_metadata_store: array_metadata_store_lib.Store | None = (
368371
array_metadata_store_lib.Store()
369372
)
373+
scoped_storage_options_creator: ScopedStorageOptionsCreator | None = None
370374

371375
@dataclasses.dataclass(frozen=True, kw_only=True)
372376
class Loading:

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
3232
from orbax.checkpoint._src.serialization import types as v0_serialization_types
3333
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
34-
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
3534
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
3635
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
3736
from orbax.checkpoint.experimental.v1._src.path import types as path_types
@@ -40,6 +39,7 @@
4039
from orbax.checkpoint.experimental.v1._src.serialization import registry
4140
from orbax.checkpoint.experimental.v1._src.serialization import scalar_leaf_handler
4241
from orbax.checkpoint.experimental.v1._src.serialization import types as serialization_types
42+
from orbax.checkpoint.experimental.v1._src.serialization import utils
4343
from orbax.checkpoint.experimental.v1._src.synchronization import multihost
4444
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
4545

@@ -69,34 +69,11 @@ def _get_remaining_timeout(
6969

7070
def _get_v0_save_args(
7171
checkpointable: PyTree,
72-
array_storage_options: options_lib.ArrayOptions.Saving.StorageOptions,
73-
create_array_storage_options_fn: (
74-
options_lib.PyTreeOptions.Saving.CreateArrayStorageOptionsFn | None
75-
),
72+
context: context_lib.Context,
7673
) -> PyTree:
7774
"""Returns save args that are compatible with the V0 API."""
78-
7975
def _leaf_get_v0_save_args(k, v):
80-
if create_array_storage_options_fn:
81-
individual_array_storage_options = create_array_storage_options_fn(k, v)
82-
save_dtype = (
83-
np.dtype(individual_array_storage_options.dtype)
84-
if individual_array_storage_options.dtype
85-
else None
86-
)
87-
return v0_serialization_types.SaveArgs(
88-
dtype=save_dtype,
89-
chunk_byte_size=individual_array_storage_options.chunk_byte_size,
90-
shard_axes=individual_array_storage_options.shard_axes,
91-
)
92-
return v0_serialization_types.SaveArgs(
93-
dtype=np.dtype(array_storage_options.dtype)
94-
if array_storage_options.dtype
95-
else None,
96-
chunk_byte_size=array_storage_options.chunk_byte_size,
97-
shard_axes=array_storage_options.shard_axes,
98-
)
99-
76+
return utils.resolve_storage_options(k, v, context)
10077
return jax.tree.map_with_path(_leaf_get_v0_save_args, checkpointable)
10178

10279

@@ -133,11 +110,7 @@ def create_v0_save_args(
133110
"""Creates v0 CheckpointArgs for saving."""
134111
return base_pytree_checkpoint_handler.BasePyTreeSaveArgs(
135112
item=checkpointable,
136-
save_args=_get_v0_save_args(
137-
checkpointable,
138-
context.array_options.saving.storage_options,
139-
context.pytree_options.saving.create_array_storage_options_fn,
140-
),
113+
save_args=_get_v0_save_args(checkpointable, context),
141114
ocdbt_target_data_file_size=context.array_options.saving.ocdbt_target_data_file_size,
142115
)
143116

checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from orbax.checkpoint.experimental.v1._src.serialization import protocol_utils
3636
from orbax.checkpoint.experimental.v1._src.serialization import registration
3737
from orbax.checkpoint.experimental.v1._src.serialization import types
38+
from orbax.checkpoint.experimental.v1._src.serialization import utils
3839

3940

4041
Shape = arrays_types_v0.Shape
@@ -112,15 +113,8 @@ def _create_v0_savearg(
112113
context: context_lib.Context,
113114
) -> type_handlers_v0.SaveArgs:
114115
"""Creates a V0 `SaveArgs` from V1 params and context for saving."""
115-
fn = context.pytree_options.saving.create_array_storage_options_fn
116-
if fn:
117-
storage_options = fn(param.keypath, param.value)
118-
else:
119-
storage_options = context.array_options.saving.storage_options
120-
return type_handlers_v0.SaveArgs(
121-
dtype=jnp.dtype(storage_options.dtype) if storage_options.dtype else None,
122-
chunk_byte_size=storage_options.chunk_byte_size,
123-
shard_axes=storage_options.shard_axes,
116+
return utils.resolve_storage_options(
117+
param.keypath, param.value, context, dtype_converter=jnp.dtype
124118
)
125119

126120

checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
3333
from orbax.checkpoint.experimental.v1._src.serialization import registration
3434
from orbax.checkpoint.experimental.v1._src.serialization import types
35+
from orbax.checkpoint.experimental.v1._src.serialization import utils
3536

3637

3738
NumpySerializationParam = types.SerializationParam[np.ndarray]
@@ -99,16 +100,7 @@ def _create_v0_savearg(
99100
context: context_lib.Context,
100101
) -> type_handlers_v0.SaveArgs:
101102
"""Creates a V0 `SaveArgs` from V1 params and context for saving."""
102-
fn = context.pytree_options.saving.create_array_storage_options_fn
103-
if fn:
104-
storage_options = fn(param.keypath, param.value)
105-
else:
106-
storage_options = context.array_options.saving.storage_options
107-
return type_handlers_v0.SaveArgs(
108-
dtype=np.dtype(storage_options.dtype) if storage_options.dtype else None,
109-
chunk_byte_size=storage_options.chunk_byte_size,
110-
shard_axes=storage_options.shard_axes,
111-
)
103+
return utils.resolve_storage_options(param.keypath, param.value, context)
112104

113105

114106
def _create_v0_restore_paraminfo(

checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
3030
from orbax.checkpoint.experimental.v1._src.serialization import registration
3131
from orbax.checkpoint.experimental.v1._src.serialization import types
32+
from orbax.checkpoint.experimental.v1._src.serialization import utils
3233

3334
Scalar = types.Scalar
3435
AbstractScalar = types.AbstractScalar
@@ -70,16 +71,7 @@ def _create_v0_savearg(
7071
context: context_lib.Context,
7172
) -> type_handlers_v0.SaveArgs:
7273
"""Creates a V0 SaveArgs from V1 params and context for saving."""
73-
fn = context.pytree_options.saving.create_array_storage_options_fn
74-
if fn:
75-
storage_options = fn(param.keypath, param.value)
76-
else:
77-
storage_options = context.array_options.saving.storage_options
78-
return type_handlers_v0.SaveArgs(
79-
dtype=np.dtype(storage_options.dtype) if storage_options.dtype else None,
80-
chunk_byte_size=storage_options.chunk_byte_size,
81-
shard_axes=storage_options.shard_axes,
82-
)
74+
return utils.resolve_storage_options(param.keypath, param.value, context)
8375

8476

8577
def _create_v0_restore_paraminfo(
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utility functions for serialization."""
16+
17+
from collections.abc import Callable
18+
from typing import Any
19+
20+
import numpy as np
21+
from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0
22+
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
23+
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
24+
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
25+
26+
27+
def resolve_storage_options(
    keypath: tree_types.PyTreeKeyPath,
    value: tree_types.LeafType,
    context: context_lib.Context,
    *,
    dtype_converter: Callable[[Any], Any] = np.dtype,
) -> type_handlers_v0.SaveArgs:
  """Builds V0 `SaveArgs` for one leaf from global and per-leaf options.

  The global `storage_options` from the context supply the baseline values.
  If a `scoped_storage_options_creator` is configured, it is called with
  `(keypath, value)`; any field it sets then overrides the corresponding
  global field. A creator result of `None` leaves the global values untouched.

  Args:
    keypath: The PyTree keypath of the leaf being saved.
    value: The PyTree leaf value being saved.
    context: The Orbax context holding the array saving options.
    dtype_converter: Callable applied to the resolved dtype when that dtype is
      not `None`.

  Returns:
    A `SaveArgs` carrying the resolved dtype, chunk byte size and shard axes.
  """
  saving_options = context.array_options.saving

  # Baseline: the global storage options (or a fresh default instance).
  defaults = saving_options.storage_options
  if defaults is None:
    defaults = options_lib.ArrayOptions.Saving.StorageOptions()
  dtype = defaults.dtype
  chunk_byte_size = defaults.chunk_byte_size
  shard_axes = defaults.shard_axes

  # Per-leaf overrides, if a creator is configured and returns something.
  creator = saving_options.scoped_storage_options_creator
  overrides = creator(keypath, value) if creator is not None else None
  if overrides is not None:
    dtype = dtype if overrides.dtype is None else overrides.dtype
    chunk_byte_size = (
        chunk_byte_size
        if overrides.chunk_byte_size is None
        else overrides.chunk_byte_size
    )
    # Unlike the fields above, a falsy `shard_axes` counts as "unset"
    # (presumably its default is an empty tuple -- confirm in options.py).
    shard_axes = overrides.shard_axes or shard_axes

  if dtype is not None:
    dtype = dtype_converter(dtype)
  return type_handlers_v0.SaveArgs(
      dtype=dtype,
      chunk_byte_size=chunk_byte_size,
      shard_axes=shard_axes,
  )

0 commit comments

Comments
 (0)