Skip to content

Commit e3c4045

Browse files
ChromeHearts authored and Orbax Authors committed
Enable OCDBT read coalescing for remote storage
PiperOrigin-RevId: 725709173
1 parent 387ff8a commit e3c4045

File tree

7 files changed

+73
-64
lines changed

7 files changed

+73
-64
lines changed

checkpoint/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88
## [Unreleased]
99

10+
## [0.11.5] - 2025-02-10
11+
12+
### Fixed
13+
14+
- Enable OCDBT read coalescing for remote storage
15+
1016
### Added
1117

1218
- `ocp.metadata.get_step_metadata(path)` to public api.

checkpoint/orbax/checkpoint/_src/serialization/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ py_test(
4949
name = "tensorstore_utils_test",
5050
srcs = ["tensorstore_utils_test.py"],
5151
deps = [
52+
":serialization",
5253
":tensorstore_utils",
5354
"//checkpoint/orbax/checkpoint/_src/arrays:subchunking",
5455
"//checkpoint/orbax/checkpoint/_src/arrays:types",

checkpoint/orbax/checkpoint/_src/serialization/serialization.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,6 @@
4040
TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}})
4141
_REMOVED_VALUE = 'Value removed'
4242
_CHECKPOINT_SUCCESS = 'checkpoint_write_success'
43-
_REMOTE_URL_PREFIXES = ['gs://', 's3://']
44-
_REMOTE_DRIVER_VALIDATIONS = [
45-
{'driver': 'gcs', 'path_regex': None},
46-
{'driver': 's3', 'path_regex': None},
47-
]
48-
4943

5044
Index = types.Index
5145
Layout = layout.Layout
@@ -110,42 +104,6 @@ def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
110104
return spec
111105

112106

113-
def is_remote_storage(tspec: Union[Dict[str, Any], str]) -> bool:
114-
"""Detect if user is using remote storages.
115-
116-
This can detect common defines and unable to detect some corner cases such as
117-
using gcsfuse.
118-
119-
Args:
120-
tspec: Tensorstore spec.
121-
122-
Returns:
123-
True if the spec is using remote storage.
124-
"""
125-
if isinstance(tspec, str):
126-
# KvStoreUrl
127-
if re.match(rf'^({"|".join(_REMOTE_URL_PREFIXES)})', tspec):
128-
return True
129-
else:
130-
return False
131-
132-
for key in ('base', 'kvstore'):
133-
if key in tspec:
134-
return is_remote_storage(tspec[key])
135-
136-
if 'driver' in tspec:
137-
for rule in _REMOTE_DRIVER_VALIDATIONS:
138-
if tspec['driver'] == rule['driver']:
139-
if rule['path_regex'] is None:
140-
return True
141-
142-
# check if path matches the regex.
143-
if re.match(rule['path_regex'], tspec['path']):
144-
return True
145-
146-
return False
147-
148-
149107
class ByteLimiter(Protocol):
150108

151109
async def wait_for_bytes(self, requested_bytes: int):

checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -472,25 +472,6 @@ def test_get_tensorstore_spec_not_absolute_path(self):
472472
):
473473
serialization.get_tensorstore_spec(path, ocdbt=True)
474474

475-
def test_maybe_cloud_storage(self):
476-
gs_path = 'gs://some-buck/path'
477-
gs_spec = serialization.get_tensorstore_spec(gs_path, ocdbt=True)
478-
self.assertTrue(serialization.is_remote_storage(gs_spec))
479-
480-
local_path = '/tmp/checkpoint'
481-
local_spec = serialization.get_tensorstore_spec(local_path, ocdbt=True)
482-
self.assertFalse(serialization.is_remote_storage(local_spec))
483-
484-
nested_tspec = {
485-
'driver': 'cast',
486-
'dtype': 'int32',
487-
'base': {
488-
'driver': 'zarr',
489-
'kvstore': {'driver': 'ocdbt', 'base': 's3://some-bucket/path'},
490-
},
491-
}
492-
self.assertTrue(serialization.is_remote_storage(nested_tspec))
493-
494475
def test_deserialization_with_int4(self):
495476
dtype = jnp.int4
496477
shape = (8, 2)

checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import math
1818
import os
1919
import re
20-
from typing import Any, TypeAlias
20+
from typing import Any, Dict, TypeAlias, Union
2121

2222
from absl import logging
2323
from jax import numpy as jnp
@@ -62,6 +62,13 @@
6262
**{'cache_pool#ocdbt': {'total_bytes_limit': 100000000}},
6363
}
6464

65+
_REMOTE_URL_PREFIXES = ['gs://', 's3://']
66+
_REMOTE_DRIVER_VALIDATIONS = [
67+
{'driver': 'gcs', 'path_regex': None},
68+
{'driver': 's3', 'path_regex': None},
69+
]
70+
71+
6572

6673
def get_ts_context(*, use_ocdbt: bool = True) -> ts.Context:
6774
"""Creates a TensorStore context object.
@@ -155,7 +162,7 @@ def build_kvstore_tspec(
155162
'cache_pool': 'cache_pool#ocdbt',
156163
})
157164

158-
if default_driver != FILE_DRIVER:
165+
if is_remote_storage(kv_spec):
159166
kv_spec.update({ # pytype: disable=attribute-error
160167
# Enable read coalescing. This feature merges adjacent read_ops into
161168
# one, which could reduce I/O ops by a factor of 10. This is
@@ -415,3 +422,39 @@ def json(self) -> JsonSpec:
415422
def metadata(self) -> ArrayMetadata:
416423
"""Checkpoint-relevant TensorStore metadata of the array."""
417424
return self._metadata
425+
426+
427+
def is_remote_storage(tspec: Union[Dict[str, Any], str]) -> bool:
428+
"""Detect if user is using remote storages.
429+
430+
This can detect common defines and unable to detect some corner cases such as
431+
using gcsfuse.
432+
433+
Args:
434+
tspec: Tensorstore spec.
435+
436+
Returns:
437+
True if the spec is using remote storage.
438+
"""
439+
if isinstance(tspec, str):
440+
# KvStoreUrl
441+
if re.match(rf'^({"|".join(_REMOTE_URL_PREFIXES)})', tspec):
442+
return True
443+
else:
444+
return False
445+
446+
for key in ('base', 'kvstore'):
447+
if key in tspec:
448+
return is_remote_storage(tspec[key])
449+
450+
if 'driver' in tspec:
451+
for rule in _REMOTE_DRIVER_VALIDATIONS:
452+
if tspec['driver'] == rule['driver']:
453+
if rule['path_regex'] is None:
454+
return True
455+
456+
# check if path matches the regex.
457+
if re.match(rule['path_regex'], tspec['path']):
458+
return True
459+
460+
return False

checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import numpy as np
2222
from orbax.checkpoint._src.arrays import subchunking
2323
from orbax.checkpoint._src.arrays import types
24+
from orbax.checkpoint._src.serialization import serialization
2425
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
2526

2627

@@ -599,6 +600,25 @@ def test_chunk_byte_size_is_adjusted_for_target_data_file_size(
599600
expected_chunk_byte_size_limit,
600601
)
601602

603+
def test_maybe_cloud_storage(self):
604+
gs_path = 'gs://some-buck/path'
605+
gs_spec = serialization.get_tensorstore_spec(gs_path, ocdbt=True)
606+
self.assertTrue(ts_utils.is_remote_storage(gs_spec))
607+
608+
local_path = '/tmp/checkpoint'
609+
local_spec = serialization.get_tensorstore_spec(local_path, ocdbt=True)
610+
self.assertFalse(ts_utils.is_remote_storage(local_spec))
611+
612+
nested_tspec = {
613+
'driver': 'cast',
614+
'dtype': 'int32',
615+
'base': {
616+
'driver': 'zarr',
617+
'kvstore': {'driver': 'ocdbt', 'base': 's3://some-bucket/path'},
618+
},
619+
}
620+
self.assertTrue(ts_utils.is_remote_storage(nested_tspec))
621+
602622

603623
if __name__ == '__main__':
604624
absltest.main()

checkpoint/orbax/checkpoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
# A new PyPI release will be pushed everytime `__version__` is increased.
1818
# Also modify version and date in CHANGELOG.
19-
__version__ = '0.11.4'
19+
__version__ = '0.11.5'
2020

2121

2222
# TODO: b/362813406 - Add latest change timestamp and commit number.

0 commit comments

Comments
 (0)