Skip to content

Commit 3795042

Browse files
cpgaffney1copybara-github
authored andcommitted
Add general method to re-register TypeHandlers with options. Push new version.
PiperOrigin-RevId: 516919811
1 parent 670ef8e commit 3795042

File tree

4 files changed

+19
-11
lines changed

4 files changed

+19
-11
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [0.1.4] - 2022-03-15
11+
1012
### Added
11-
- Add support for Tensorstore OCDBT option.
1213
- Support for generic transformation function in PyTreeCheckpointHandler.
1314
- Support n-digit checkpoint step format.
1415

orbax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@
1515
"""Orbax API."""
1616

1717
# A new PyPI release will be pushed everytime `__version__` is increased.
18-
__version__ = '0.1.3'
18+
__version__ = '0.1.4'

orbax/checkpoint/pytree_checkpoint_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def __init__(
222222
self._aggregate_filename = aggregate_filename
223223
self._concurrent_gb = concurrent_gb
224224
if use_ocdbt:
225-
type_handlers.register_ocdbt_handlers()
225+
type_handlers.register_standard_handlers_with_options(use_ocdbt=use_ocdbt)
226226

227227
def _get_param_names(self, item: PyTree) -> PyTree:
228228
"""Gets parameter names for PyTree elements."""

orbax/checkpoint/type_handlers.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -450,15 +450,22 @@ def has_type_handler(ty: Any) -> bool:
450450
return False
451451

452452

453-
def register_ocdbt_handlers():
454-
"""Re-registers select TypeHanders to use Tensorstore OCDBT driver."""
455-
register_type_handler(int, ScalarHandler(use_ocdbt=True), override=True)
456-
register_type_handler(float, ScalarHandler(use_ocdbt=True), override=True)
457-
register_type_handler(np.number, ScalarHandler(use_ocdbt=True), override=True)
458-
register_type_handler(np.ndarray, NumpyHandler(use_ocdbt=True), override=True)
453+
def register_standard_handlers_with_options(**kwargs):
454+
"""Re-registers a select set of handlers with the given options."""
455+
register_type_handler(int, ScalarHandler(**kwargs), override=True)
456+
register_type_handler(float, ScalarHandler(**kwargs), override=True)
457+
register_type_handler(
458+
np.number,
459+
ScalarHandler(**kwargs),
460+
override=True,
461+
)
462+
register_type_handler(
463+
np.ndarray,
464+
NumpyHandler(**kwargs),
465+
override=True,
466+
)
459467
register_type_handler(
460468
jax.Array,
461-
ArrayHandler(use_ocdbt=True),
462-
func=lambda ty: issubclass(ty, jax.Array) and jax.config.jax_array,
469+
ArrayHandler(**kwargs),
463470
override=True,
464471
)

0 commit comments

Comments
 (0)