Skip to content

Commit f7656be

Browse files
angel-coreOrbax Authors
authored andcommitted
Centralize array storage options and implement field-level merging in Orbax v1.
PiperOrigin-RevId: 889877635
1 parent cffb547 commit f7656be

File tree

10 files changed

+320
-114
lines changed

10 files changed

+320
-114
lines changed

checkpoint/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2222
- Enforce the array shape and type check during Array restoration when
2323
`ArrayRestoreArgs.strict` is set but shape/dtype is not provided.
2424
- On platforms where `uvloop` is not supported, fallback to `nest_asyncio`.
25+
- #v1 Centralize `StorageOptions` into `ArrayOptions` and implement field-level
26+
merging.
2527

2628
## [0.11.33] - 2026-02-17
2729

checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -184,27 +184,6 @@ class PyTreeOptions:
184184
185185
# TODO: Include an example of registering a custom LeafHandler.
186186
187-
Example:
188-
To save certain leaves in float16, while others in float32, we can use
189-
`create_array_storage_options_fn` like so::
190-
191-
import jax
192-
import jax.numpy as jnp
193-
from orbax.checkpoint.v1 import options as ocp_options
194-
195-
def create_opts_fn(keypath, value):
196-
if 'small' in jax.tree_util.keystr(keypath):
197-
return ocp_options.ArrayOptions.Saving.StorageOptions(
198-
dtype=jnp.float16
199-
)
200-
return ocp_options.ArrayOptions.Saving.StorageOptions(dtype=jnp.float32)
201-
202-
pytree_options = ocp_options.PyTreeOptions(
203-
saving=ocp_options.PyTreeOptions.Saving(
204-
create_array_storage_options_fn=create_opts_fn
205-
)
206-
)
207-
208187
Attributes:
209188
saving: Options for saving PyTrees.
210189
loading: Options for loading PyTrees.
@@ -216,25 +195,9 @@ def create_opts_fn(keypath, value):
216195
class Saving:
217196
"""Options for saving PyTrees.
218197
219-
create_array_storage_options_fn:
220-
A function that is called in order to create
221-
:py:class:`.ArrayOptions.Saving.StorageOptions` for each leaf in a PyTree,
222-
when it is
223-
being saved. It is called similar to:
224-
`jax.tree.map_with_path(create_array_storage_options_fn, pytree_to_save)`.
225-
If provided, it overrides any default settings in
226-
:py:class:`.ArrayOptions.Saving.StorageOptions`.
227198
pytree_metadata_options: Options for managing PyTree metadata.
228199
"""
229200

230-
class CreateArrayStorageOptionsFn(Protocol):
231-
232-
def __call__(
233-
self, key: tree_types.PyTreeKeyPath, value: Any
234-
) -> ArrayOptions.Saving.StorageOptions:
235-
...
236-
237-
create_array_storage_options_fn: CreateArrayStorageOptionsFn | None = None
238201
pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = (
239202
dataclasses.field(default_factory=tree_metadata.PyTreeMetadataOptions)
240203
)
@@ -265,7 +228,8 @@ class ArrayOptions:
265228
names during initialization.
266229
267230
Example:
268-
Configure array options with specific saving formats and loading behaviors::
231+
To configure array options with specific saving formats and loading
232+
behaviors we can do so like this::
269233
270234
from orbax.checkpoint.v1.options import ArrayOptions
271235
@@ -280,6 +244,30 @@ class ArrayOptions:
280244
)
281245
)
282246
247+
To save certain leaves in float16, while others in float32, we can use
248+
`scoped_storage_options_creator` like so::
249+
250+
import jax
251+
import jax.numpy as jnp
252+
from orbax.checkpoint.v1 import options as ocp_options
253+
254+
def create_opts_fn(keypath, value):
255+
if 'small' in jax.tree_util.keystr(keypath):
256+
return ocp_options.ArrayOptions.Saving.StorageOptions(
257+
dtype=jnp.float16
258+
)
259+
return None # Fall back to global `storage_options`
260+
261+
array_options = ocp_options.ArrayOptions(
262+
saving=ocp_options.ArrayOptions.Saving(
263+
storage_options=ocp_options.ArrayOptions.Saving.StorageOptions(
264+
dtype=jnp.float32
265+
),
266+
scoped_storage_options_creator=create_opts_fn
267+
)
268+
269+
)
270+
283271
Attributes:
284272
saving: Options for saving arrays.
285273
loading: Options for loading arrays.
@@ -322,8 +310,24 @@ class Saving:
322310
True.
323311
array_metadata_store: Store to manage per host ArrayMetadata. To disable
324312
ArrayMetadata persistence, set it to None.
313+
storage_options: Global default for array storage options.
314+
scoped_storage_options_creator: A function that, when dealing with
315+
PyTrees, is applied to every leaf. If it returns an
316+
:py:class:`ArrayOptions.Saving.StorageOptions`, its fields take
317+
precedence when merging if they are set to non-None or non-default
318+
values with respect to `storage_options`. If it returns `None`,
319+
`storage_options` is used as a default for all fields. It is called
320+
similar to: `jax.tree.map_with_path(scoped_storage_options_creator,
321+
pytree_to_save)`.
325322
"""
326323

324+
class ScopedStorageOptionsCreator(Protocol):
325+
326+
def __call__(
327+
self, key: tree_types.PyTreeKeyPath, value: Any
328+
) -> ArrayOptions.Saving.StorageOptions:
329+
...
330+
327331
@dataclasses.dataclass(frozen=True, kw_only=True)
328332
class StorageOptions:
329333
"""Options used to customize array storage behavior for individual leaves.
@@ -367,6 +371,7 @@ class StorageOptions:
367371
array_metadata_store: array_metadata_store_lib.Store | None = (
368372
array_metadata_store_lib.Store()
369373
)
374+
scoped_storage_options_creator: ScopedStorageOptionsCreator | None = None
370375

371376
@dataclasses.dataclass(frozen=True, kw_only=True)
372377
class Loading:

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@
2929
from orbax.checkpoint._src.futures import synchronization
3030
from orbax.checkpoint._src.handlers import base_pytree_checkpoint_handler
3131
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
32+
from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0
3233
from orbax.checkpoint._src.serialization import types as v0_serialization_types
3334
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
3435
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
3536
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
3637
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
3738
from orbax.checkpoint.experimental.v1._src.path import types as path_types
3839
from orbax.checkpoint.experimental.v1._src.serialization import compatibility
40+
from orbax.checkpoint.experimental.v1._src.serialization import options_resolution
3941
from orbax.checkpoint.experimental.v1._src.serialization import protocol_utils
4042
from orbax.checkpoint.experimental.v1._src.serialization import registry
4143
from orbax.checkpoint.experimental.v1._src.serialization import scalar_leaf_handler
@@ -69,32 +71,19 @@ def _get_remaining_timeout(
6971

7072
def _get_v0_save_args(
7173
checkpointable: PyTree,
72-
array_storage_options: options_lib.ArrayOptions.Saving.StorageOptions,
73-
create_array_storage_options_fn: (
74-
options_lib.PyTreeOptions.Saving.CreateArrayStorageOptionsFn | None
75-
),
74+
array_saving_options: options_lib.ArrayOptions.Saving,
7675
) -> PyTree:
7776
"""Returns save args that are compatible with the V0 API."""
78-
7977
def _leaf_get_v0_save_args(k, v):
80-
if create_array_storage_options_fn:
81-
individual_array_storage_options = create_array_storage_options_fn(k, v)
82-
save_dtype = (
83-
np.dtype(individual_array_storage_options.dtype)
84-
if individual_array_storage_options.dtype
85-
else None
86-
)
87-
return v0_serialization_types.SaveArgs(
88-
dtype=save_dtype,
89-
chunk_byte_size=individual_array_storage_options.chunk_byte_size,
90-
shard_axes=individual_array_storage_options.shard_axes,
91-
)
92-
return v0_serialization_types.SaveArgs(
93-
dtype=np.dtype(array_storage_options.dtype)
94-
if array_storage_options.dtype
78+
resolved_options = options_resolution.resolve_storage_options(
79+
k, v, array_saving_options
80+
)
81+
return type_handlers_v0.SaveArgs(
82+
dtype=np.dtype(resolved_options.dtype)
83+
if resolved_options.dtype is not None
9584
else None,
96-
chunk_byte_size=array_storage_options.chunk_byte_size,
97-
shard_axes=array_storage_options.shard_axes,
85+
chunk_byte_size=resolved_options.chunk_byte_size,
86+
shard_axes=resolved_options.shard_axes,
9887
)
9988

10089
return jax.tree.map_with_path(_leaf_get_v0_save_args, checkpointable)
@@ -135,8 +124,7 @@ def create_v0_save_args(
135124
item=checkpointable,
136125
save_args=_get_v0_save_args(
137126
checkpointable,
138-
context.array_options.saving.storage_options,
139-
context.pytree_options.saving.create_array_storage_options_fn,
127+
context.array_options.saving,
140128
),
141129
ocdbt_target_data_file_size=context.array_options.saving.ocdbt_target_data_file_size,
142130
)

checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from orbax.checkpoint._src.metadata import value as value_metadata
3333
from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0
3434
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
35+
import orbax.checkpoint.experimental.v1._src.context.options as options_lib
36+
from orbax.checkpoint.experimental.v1._src.serialization import options_resolution
3537
from orbax.checkpoint.experimental.v1._src.serialization import protocol_utils
3638
from orbax.checkpoint.experimental.v1._src.serialization import registration
3739
from orbax.checkpoint.experimental.v1._src.serialization import types
@@ -109,18 +111,18 @@ def _create_v0_saving_paraminfo(
109111

110112
def _create_v0_savearg(
111113
param: ArraySerializationParam,
112-
context: context_lib.Context,
114+
array_saving_options: options_lib.ArrayOptions.Saving,
113115
) -> type_handlers_v0.SaveArgs:
114-
"""Creates a V0 `SaveArgs` from V1 params and context for saving."""
115-
fn = context.pytree_options.saving.create_array_storage_options_fn
116-
if fn:
117-
storage_options = fn(param.keypath, param.value)
118-
else:
119-
storage_options = context.array_options.saving.storage_options
116+
"""Creates a V0 `SaveArgs` from V1 params and array options for saving."""
117+
resolved_options = options_resolution.resolve_storage_options(
118+
param.keypath, param.value, array_saving_options
119+
)
120120
return type_handlers_v0.SaveArgs(
121-
dtype=jnp.dtype(storage_options.dtype) if storage_options.dtype else None,
122-
chunk_byte_size=storage_options.chunk_byte_size,
123-
shard_axes=storage_options.shard_axes,
121+
dtype=jnp.dtype(resolved_options.dtype)
122+
if resolved_options.dtype is not None
123+
else None,
124+
chunk_byte_size=resolved_options.chunk_byte_size,
125+
shard_axes=resolved_options.shard_axes,
124126
)
125127

126128

@@ -223,7 +225,10 @@ async def serialize(
223225
_create_v0_saving_paraminfo(p, self._context, serialization_context)
224226
for p in params
225227
]
226-
saveargs = [_create_v0_savearg(p, self._context) for p in params]
228+
saveargs = [
229+
_create_v0_savearg(p, self._context.array_options.saving)
230+
for p in params
231+
]
227232

228233
commit_futures = await self._handler_impl.serialize(
229234
values, paraminfos, saveargs

checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from orbax.checkpoint._src.metadata import value as value_metadata
3131
from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0
3232
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
33+
import orbax.checkpoint.experimental.v1._src.context.options as options_lib
34+
from orbax.checkpoint.experimental.v1._src.serialization import options_resolution
3335
from orbax.checkpoint.experimental.v1._src.serialization import registration
3436
from orbax.checkpoint.experimental.v1._src.serialization import types
3537

@@ -96,18 +98,18 @@ def _create_v0_saving_paraminfo(
9698

9799
def _create_v0_savearg(
98100
param: NumpySerializationParam,
99-
context: context_lib.Context,
101+
array_saving_options: options_lib.ArrayOptions.Saving,
100102
) -> type_handlers_v0.SaveArgs:
101-
"""Creates a V0 `SaveArgs` from V1 params and context for saving."""
102-
fn = context.pytree_options.saving.create_array_storage_options_fn
103-
if fn:
104-
storage_options = fn(param.keypath, param.value)
105-
else:
106-
storage_options = context.array_options.saving.storage_options
103+
"""Creates a V0 `SaveArgs` from V1 params and array saving options."""
104+
resolved_options = options_resolution.resolve_storage_options(
105+
param.keypath, param.value, array_saving_options
106+
)
107107
return type_handlers_v0.SaveArgs(
108-
dtype=np.dtype(storage_options.dtype) if storage_options.dtype else None,
109-
chunk_byte_size=storage_options.chunk_byte_size,
110-
shard_axes=storage_options.shard_axes,
108+
dtype=np.dtype(resolved_options.dtype)
109+
if resolved_options.dtype is not None
110+
else None,
111+
chunk_byte_size=resolved_options.chunk_byte_size,
112+
shard_axes=resolved_options.shard_axes,
111113
)
112114

113115

@@ -188,7 +190,10 @@ async def serialize(
188190
_create_v0_saving_paraminfo(p, self._context, serialization_context)
189191
for p in params
190192
]
191-
saveargs = [_create_v0_savearg(p, self._context) for p in params]
193+
saveargs = [
194+
_create_v0_savearg(p, self._context.array_options.saving)
195+
for p in params
196+
]
192197

193198
commit_futures = await self._handler_impl.serialize(
194199
values, paraminfos, saveargs
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utility functions for serialization."""
16+
17+
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
18+
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
19+
20+
21+
def resolve_storage_options(
22+
keypath: tree_types.PyTreeKeyPath,
23+
value: tree_types.LeafType,
24+
array_saving_options: options_lib.ArrayOptions.Saving,
25+
) -> options_lib.ArrayOptions.Saving.StorageOptions:
26+
"""Resolves storage options using a global default and a per-leaf creator.
27+
28+
When dealing with PyTrees, `scoped_storage_options_creator` is applied to
29+
every leaf. Its fields take precedence when merging if they are set to
30+
non-None or non-default values with respect to the global `storage_options`.
31+
If the creator returns `None`, the global `storage_options` is used for all
32+
fields.
33+
34+
Args:
35+
keypath: The PyTree keypath of the array being saved.
36+
value: The PyTree leaf value (array) being saved.
37+
array_saving_options: The Orbax array saving options to use for resolution.
38+
39+
Returns:
40+
The resolved StorageOptions containing storage options.
41+
"""
42+
global_opts = array_saving_options.storage_options
43+
if global_opts is None:
44+
global_opts = options_lib.ArrayOptions.Saving.StorageOptions()
45+
46+
fn = array_saving_options.scoped_storage_options_creator
47+
individual_opts = None
48+
if fn is not None:
49+
individual_opts = fn(keypath, value)
50+
51+
if individual_opts is not None:
52+
resolved_dtype = (
53+
individual_opts.dtype
54+
if individual_opts.dtype is not None
55+
else global_opts.dtype
56+
)
57+
resolved_chunk_byte_size = (
58+
individual_opts.chunk_byte_size
59+
if individual_opts.chunk_byte_size is not None
60+
else global_opts.chunk_byte_size
61+
)
62+
resolved_shard_axes = (
63+
individual_opts.shard_axes
64+
if individual_opts.shard_axes
65+
else global_opts.shard_axes
66+
)
67+
else:
68+
resolved_dtype = global_opts.dtype
69+
resolved_chunk_byte_size = global_opts.chunk_byte_size
70+
resolved_shard_axes = global_opts.shard_axes
71+
72+
return options_lib.ArrayOptions.Saving.StorageOptions(
73+
dtype=resolved_dtype,
74+
chunk_byte_size=resolved_chunk_byte_size,
75+
shard_axes=resolved_shard_axes,
76+
)
77+

0 commit comments

Comments
 (0)