
Commit 7218079

piterclgabrys and Piotr Gabryjeluk authored
feat: support >1k series in fetch_metric_buckets (#107)
* feat: support >1k series in fetch_metric_buckets
* Un-skip tests for nan/inf
* Comment-out unused fields in TimeseriesBucket to save memory
* Remove TODO
* Fix fuzzy test after comment-out
* Restore constant
* More readable _update_range

  Co-authored-by: Piotr Gabryjeluk <[email protected]>

* Fixes
* Fix min/max in update_range

---------

Co-authored-by: Piotr Gabryjeluk <[email protected]>
1 parent 6027db3 commit 7218079

File tree

10 files changed: +261 −90 lines changed
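The diffs below implement the change by splitting the requested series into chunks of at most `MAX_SERIES_PER_REQUEST` and merging the per-chunk results, with an extra cheap pass to pin down a shared x-range first. A minimal, self-contained sketch of the chunk-and-merge idea (the names `fake_fetch` and `fetch_in_chunks_demo` are illustrative stand-ins, not the library's API):

```python
# Illustrative sketch of the chunk-and-merge pattern introduced by this commit.
# fake_fetch stands in for metric_buckets.fetch_time_series_buckets.
MAX_SERIES_PER_REQUEST = 1000  # same value as the constant added in metric_buckets.py


def fake_fetch(series_chunk: list[str]) -> dict[str, list]:
    # The real function issues one backend request per chunk and returns buckets keyed by series.
    return {series: [] for series in series_chunk}


def fetch_in_chunks_demo(series: list[str]) -> dict[str, list]:
    merged: dict[str, list] = {}
    for offset in range(0, len(series), MAX_SERIES_PER_REQUEST):
        chunk = series[offset : offset + MAX_SERIES_PER_REQUEST]
        merged.update(fake_fetch(chunk))  # results are keyed per series, so update() merges cleanly
    return merged


result = fetch_in_chunks_demo([f"series-{i}" for i in range(2500)])
assert len(result) == 2500  # served by three requests: 1000 + 1000 + 500 series
```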

src/neptune_query/internal/composition/fetch_metric_buckets.py
Lines changed: 98 additions & 9 deletions

```diff
@@ -18,9 +18,11 @@
     Generator,
     Literal,
     Optional,
+    Protocol,
     Union,
 )
 
+import numpy as np
 import pandas as pd
 from neptune_api.client import AuthenticatedClient
 
@@ -48,13 +50,23 @@
     util,
 )
 from ..retrieval.attribute_values import AttributeValue
-from ..retrieval.metric_buckets import TimeseriesBucket
+from ..retrieval.metric_buckets import (
+    MAX_SERIES_PER_REQUEST,
+    TimeseriesBucket,
+)
 from ..retrieval.search import ContainerType
 from .attribute_components import fetch_attribute_values_by_filter_split
 
 __all__ = ("fetch_metric_buckets",)
 
 
+class _FetchInChunksProtocol(Protocol):
+    def __call__(
+        self, x_range: Optional[tuple[float, float]], bucket_limit: int
+    ) -> dict[RunAttributeDefinition, list[TimeseriesBucket]]:
+        ...
+
+
 def fetch_metric_buckets(
     *,
     project_identifier: identifiers.ProjectIdentifier,
@@ -152,25 +164,102 @@ def go_fetch_sys_attrs() -> Generator[list[identifiers.SysId], None, None]:
         ),
     )
 
-    results: Generator[util.Page[AttributeValue], None, None] = concurrency.gather_results(output)
+    attribute_value_pages: Generator[util.Page[AttributeValue], None, None] = concurrency.gather_results(output)
 
     run_attribute_definitions = []
-    for page in results:
+    for page in attribute_value_pages:
         for value in page.items:
             run_attribute_definition = RunAttributeDefinition(
                 run_identifier=value.run_identifier, attribute_definition=value.attribute_definition
             )
             run_attribute_definitions.append(run_attribute_definition)
 
-    buckets_data = metric_buckets.fetch_time_series_buckets(
+    if not run_attribute_definitions:
+        return {}, sys_id_label_mapping
+
+    fetch_in_chunks: _FetchInChunksProtocol = lambda x_range, bucket_limit: _fetch_in_chunks(
         client=client,
-        x=x,
         run_attribute_definitions=run_attribute_definitions,
-        lineage_to_the_root=lineage_to_the_root,
+        x=x,
         include_point_previews=include_point_previews,
-        limit=limit,
         container_type=container_type,
-        x_range=None,
+        x_range=x_range,
+        lineage_to_the_root=lineage_to_the_root,
+        limit=bucket_limit,
+        executor=executor,
+    )
+
+    # if len(run_attribute_definitions) <= MAX_SERIES_PER_REQUEST:
+    #     return fetch_in_chunks(x_range=None, bucket_limit=limit), sys_id_label_mapping
+
+    global_x_range = _compute_global_x_range(fetch_in_chunks=fetch_in_chunks)
+    if global_x_range is None:
+        # No finite points / bucket bounds found
+        return {}, sys_id_label_mapping
+
+    return fetch_in_chunks(x_range=global_x_range, bucket_limit=limit), sys_id_label_mapping
+
+
+def _fetch_in_chunks(
+    client: AuthenticatedClient,
+    run_attribute_definitions: list[RunAttributeDefinition],
+    x: Literal["step"],
+    include_point_previews: bool,
+    container_type: ContainerType,
+    x_range: Optional[tuple[float, float]],
+    lineage_to_the_root: bool,
+    limit: int,
+    executor: Executor,
+) -> dict[RunAttributeDefinition, list[TimeseriesBucket]]:
+    chunks = (
+        run_attribute_definitions[offset : offset + MAX_SERIES_PER_REQUEST]
+        for offset in range(0, len(run_attribute_definitions), MAX_SERIES_PER_REQUEST)
     )
 
-    return buckets_data, sys_id_label_mapping
+    output = concurrency.generate_concurrently(
+        items=chunks,
+        executor=executor,
+        downstream=lambda chunk: concurrency.return_value(
+            metric_buckets.fetch_time_series_buckets(
+                client=client,
+                x=x,
+                run_attribute_definitions=chunk,
+                lineage_to_the_root=lineage_to_the_root,
+                include_point_previews=include_point_previews,
+                limit=limit,
+                container_type=container_type,
+                x_range=x_range,
+            )
+        ),
+    )
+
+    merged: dict[RunAttributeDefinition, list[TimeseriesBucket]] = {}
+    for chunk_data in concurrency.gather_results(output):  # type: dict[RunAttributeDefinition, list[TimeseriesBucket]]
+        merged.update(chunk_data)
+    return merged
+
+
+def _compute_global_x_range(fetch_in_chunks: _FetchInChunksProtocol) -> Optional[tuple[float, float]]:
+    x_range: tuple[Optional[float], Optional[float]] = (None, None)
+    # We only need the minimal number of buckets to determine min/max x
+    for buckets in fetch_in_chunks(x_range=None, bucket_limit=2).values():
+        for bucket in buckets:
+            x_range = _update_range(x_range, bucket)
+
+    if x_range[0] is None or x_range[1] is None:
+        return None
+    return x_range[0], x_range[1]
+
+
+def _update_range(
+    current_range: tuple[Optional[float], Optional[float]], bucket: TimeseriesBucket
+) -> tuple[Optional[float], Optional[float]]:
+    # We're including from_x and to_x because some buckets might hold only non-finite points,
+    # in which case first_x and last_x are None.
+    candidates = [bucket.first_x, bucket.last_x, bucket.from_x, bucket.to_x] + list(current_range)
+    finite_candidates = [x for x in candidates if x is not None and np.isfinite(x)]
+
+    if len(finite_candidates):
+        return min(finite_candidates), max(finite_candidates)
+    else:
+        return None, None
```
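Reading note on the new helpers above: `_compute_global_x_range` runs a cheap first pass with `bucket_limit=2` only to learn the overall min/max x, and the final fetch then passes that range to every chunk so all series are bucketed over the same interval. A self-contained sketch of the `_update_range` rule, using a simplified stand-in for `TimeseriesBucket` rather than the library type:

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Bucket:
    # Simplified stand-in for TimeseriesBucket: only the x-related fields used by _update_range.
    from_x: Optional[float]
    to_x: Optional[float]
    first_x: Optional[float]
    last_x: Optional[float]


def update_range(current: tuple, bucket: Bucket) -> tuple:
    # Same rule as _update_range: track the min/max over all finite candidates.
    # from_x/to_x are included because a bucket holding only non-finite points
    # has first_x/last_x set to None but still carries its bounds.
    candidates = [bucket.first_x, bucket.last_x, bucket.from_x, bucket.to_x, *current]
    finite = [x for x in candidates if x is not None and math.isfinite(x)]
    return (min(finite), max(finite)) if finite else (None, None)


r = update_range((None, None), Bucket(from_x=0.0, to_x=10.0, first_x=None, last_x=None))
r = update_range(r, Bucket(from_x=10.0, to_x=20.0, first_x=12.0, last_x=18.0))
assert r == (0.0, 20.0)  # the global x-range spans both buckets
```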

src/neptune_query/internal/retrieval/metric_buckets.py
Lines changed: 22 additions & 19 deletions

```diff
@@ -44,6 +44,9 @@
 logger = get_logger()
 
 
+MAX_SERIES_PER_REQUEST = 1000
+
+
 @dataclass(frozen=True)
 class TimeseriesBucket:
     index: int
@@ -55,13 +58,13 @@ class TimeseriesBucket:
     last_y: Optional[float]
 
     # statistics:
-    y_min: Optional[float]
-    y_max: Optional[float]
-    finite_point_count: int
-    nan_count: int
-    positive_inf_count: int
-    negative_inf_count: int
-    finite_points_sum: Optional[float]
+    # y_min: Optional[float]
+    # y_max: Optional[float]
+    # finite_point_count: int
+    # nan_count: int
+    # positive_inf_count: int
+    # negative_inf_count: int
+    # finite_points_sum: Optional[float]
 
 
 # Build once at module import
@@ -112,8 +115,8 @@ def fetch_time_series_buckets(
     expressions = {}
     request_id_to_request_mapping = {}
     for num, run_attribute_definition in enumerate(run_attribute_definitions):
-        if num >= 1000:
-            raise ValueError("Cannot fetch more than 1000 time series at once")
+        if num >= MAX_SERIES_PER_REQUEST:
+            raise ValueError(f"Cannot fetch more than {MAX_SERIES_PER_REQUEST} time series at once")
 
         request_id = int_to_uuid(num)
         request_id_to_request_mapping[request_id] = run_attribute_definition
@@ -163,10 +166,10 @@ def fetch_time_series_buckets(
     for entry in result_object.entries:
         request = request_id_to_request_mapping.get(entry.requestId, None)
         if request is None:
-            raise RuntimeError(f"Received unknown requestId from the server: {request_id}")
+            raise RuntimeError(f"Received unknown requestId from the server: {entry.requestId}")
 
         if request in out:
-            raise RuntimeError(f"Received duplicate requestId from the server: {request_id}")
+            raise RuntimeError(f"Received duplicate requestId from the server: {entry.requestId}")
 
         out[request] = [
             TimeseriesBucket(
@@ -177,19 +180,19 @@ def fetch_time_series_buckets(
                 first_y=bucket.first.y if bucket.HasField("first") else None,
                 last_x=bucket.last.x if bucket.HasField("last") else None,
                 last_y=bucket.last.y if bucket.HasField("last") else None,
-                y_min=bucket.localMin if bucket.HasField("localMin") else None,
-                y_max=bucket.localMax if bucket.HasField("localMax") else None,
-                finite_point_count=bucket.finitePointCount,
-                nan_count=bucket.nanCount,
-                positive_inf_count=bucket.positiveInfCount,
-                negative_inf_count=bucket.negativeInfCount,
-                finite_points_sum=bucket.localSum if bucket.HasField("localSum") else None,
+                # y_min=bucket.localMin if bucket.HasField("localMin") else None,
+                # y_max=bucket.localMax if bucket.HasField("localMax") else None,
+                # finite_point_count=bucket.finitePointCount,
+                # nan_count=bucket.nanCount,
+                # positive_inf_count=bucket.positiveInfCount,
+                # negative_inf_count=bucket.negativeInfCount,
+                # finite_points_sum=bucket.localSum if bucket.HasField("localSum") else None,
             )
             for bucket in entry.bucket
         ]
 
     for request in run_attribute_definitions:
         if request not in out:
-            raise RuntimeError("Didn't get data for all the requests from the server. " f"Missing request {request_id}")
+            raise RuntimeError("Didn't get data for all the requests from the server. " f"Missing request {request}")
 
     return out
```
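For context on the guard above: a single call to `fetch_time_series_buckets` still refuses more than `MAX_SERIES_PER_REQUEST` series, which is why the composition layer chunks before calling it. A rough sketch of that contract (the `fetch_one_request` helper below is a hypothetical stand-in, not the real function):

```python
MAX_SERIES_PER_REQUEST = 1000


def fetch_one_request(run_attribute_definitions: list) -> dict:
    # Mirrors the guard in fetch_time_series_buckets: at most MAX_SERIES_PER_REQUEST per request.
    for num, _definition in enumerate(run_attribute_definitions):
        if num >= MAX_SERIES_PER_REQUEST:
            raise ValueError(f"Cannot fetch more than {MAX_SERIES_PER_REQUEST} time series at once")
    return {definition: [] for definition in run_attribute_definitions}


try:
    fetch_one_request(list(range(1500)))  # too many series for one request
except ValueError as exc:
    print(exc)  # the caller must chunk first, as the new _fetch_in_chunks helper does
```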

tests/e2e/metric_buckets.py
Lines changed: 7 additions & 7 deletions

```diff
@@ -71,13 +71,13 @@ def aggregate_metric_buckets(
            first_y=ys[0] if ys else float("nan"),
            last_x=xs[-1] if xs else float("nan"),
            last_y=ys[-1] if ys else float("nan"),
-            y_min=float(np.min(ys)) if ys else float("nan"),
-            y_max=float(np.max(ys)) if ys else float("nan"),
-            finite_point_count=len(ys),
-            nan_count=nan_count,
-            positive_inf_count=positive_inf_count,
-            negative_inf_count=negative_inf_count,
-            finite_points_sum=float(np.sum(ys)) if ys else 0.0,
+            # y_min=float(np.min(ys)) if ys else float("nan"),
+            # y_max=float(np.max(ys)) if ys else float("nan"),
+            # finite_point_count=len(ys),
+            # nan_count=nan_count,
+            # positive_inf_count=positive_inf_count,
+            # negative_inf_count=negative_inf_count,
+            # finite_points_sum=float(np.sum(ys)) if ys else 0.0,
         )
         buckets.append(bucket)
     return buckets
```

tests/e2e/v1/test_experiments.py
Lines changed: 0 additions & 1 deletion

```diff
@@ -321,7 +321,6 @@ def test__fetch_experiments_table_with_attributes_regex_filter_for_metrics(
     assert df[expected.columns].columns.equals(expected.columns)
 
 
-@pytest.mark.skip(reason="Skipped until inf/nan handling is enabled in the backend")
 def test__fetch_experiments_table_nan_inf(new_project_id):
     df = fetch_experiments_table(
         project=new_project_id,
```

tests/e2e/v1/test_fetch_metric_buckets.py
Lines changed: 81 additions & 1 deletion

```diff
@@ -1,3 +1,4 @@
+import threading
 from typing import (
     Iterable,
     Literal,
@@ -21,6 +22,7 @@
     SysId,
 )
 from neptune_query.internal.output_format import create_metric_buckets_dataframe
+from neptune_query.internal.retrieval import metric_buckets
 from neptune_query.internal.retrieval.metric_buckets import TimeseriesBucket
 from tests.e2e.data import (
     NUMBER_OF_STEPS,
@@ -296,6 +298,85 @@ def test__fetch_metric_buckets__handles_misaligned_steps_in_metrics(
     pd.testing.assert_frame_equal(result_df, expected_df)
 
 
+@pytest.mark.parametrize(
+    "attribute_filter, expected_attributes",
+    [
+        (
+            AttributeFilter(name=r"series-.*", type=["float_series"]),
+            [
+                "series-containing-inf",
+                "series-containing-nan",
+                "series-ending-with-inf",
+                "series-ending-with-nan",
+            ],
+        ),
+        (
+            r"series-ending-.*",
+            ["series-ending-with-inf", "series-ending-with-nan"],
+        ),
+    ],
+)
+def test__fetch_metric_buckets__over_1k_series(
+    new_project_id,
+    monkeypatch,
+    attribute_filter,
+    expected_attributes,
+):
+    """
+    This test verifies that when fetching metric buckets for a run with over 1000 series,
+    the function correctly splits the requests into multiple chunks to avoid exceeding the limit.
+
+    It does so by monkeypatching the actual limit to 1 and capturing the calls made to fetch_time_series_buckets.
+    """
+    original_fetch = metric_buckets.fetch_time_series_buckets
+    call_chunks: list[list[RunAttributeDefinition]] = []
+    lock = threading.Lock()
+
+    def capture_fetch(*args, **kwargs):
+        with lock:
+            call_chunks.append(kwargs["run_attribute_definitions"])
+        return original_fetch(*args, **kwargs)
+
+    monkeypatch.setattr(metric_buckets, "fetch_time_series_buckets", capture_fetch)
+
+    forced_limit = 1
+    monkeypatch.setattr("neptune_query.internal.retrieval.metric_buckets.MAX_SERIES_PER_REQUEST", forced_limit)
+    monkeypatch.setattr("neptune_query.internal.composition.fetch_metric_buckets.MAX_SERIES_PER_REQUEST", forced_limit)
+
+    experiment_name = EXP_NAME_INF_NAN_RUN
+    run_id = RUN_ID_INF_NAN_RUN
+
+    result_df = fetch_metric_buckets(
+        project=new_project_id,
+        experiments=[experiment_name],
+        x="step",
+        y=attribute_filter,
+        limit=5,
+        include_point_previews=False,
+        lineage_to_the_root=True,
+    )
+
+    expected_data = {
+        experiment_name: {
+            attribute_name: RUN_BY_ID[run_id].metrics_values(attribute_name) for attribute_name in expected_attributes
+        }
+    }
+    expected_df = _create_expected_data_metric_buckets_dataframe(
+        data=expected_data,
+        project_identifier=new_project_id,
+        x="step",
+        limit=5,
+        include_point_previews=False,
+    )
+
+    pd.testing.assert_frame_equal(result_df, expected_df)
+
+    assert len(call_chunks) > 1
+    total_series = sum(len(chunk) for chunk in call_chunks)
+    assert total_series > forced_limit
+    assert all(len(chunk) <= forced_limit for chunk in call_chunks)
+
+
 @pytest.mark.parametrize(
     "arg_experiments,run_id,y",
     [
@@ -317,7 +398,6 @@ def test__fetch_metric_buckets__handles_misaligned_steps_in_metrics(
     "include_point_previews",
     [True],
 )
-@pytest.mark.skip(reason="Skipped until inf/nan handling is enabled in the backend")
 def test__fetch_metric_buckets__inf_nan(
     new_project_id,
     arg_experiments,
```

tests/e2e/v1/test_fetch_metrics.py
Lines changed: 0 additions & 1 deletion

```diff
@@ -405,7 +405,6 @@ def test__fetch_metrics__lineage(new_project_id, lineage_to_the_root, expected_v
         ("series-containing-nan", RUN_BY_ID[RUN_ID_INF_NAN_RUN].metrics_values("series-containing-nan")),
     ],
 )
-@pytest.mark.skip(reason="Skipped until inf/nan handling is enabled in the backend")
 def test__fetch_metrics_nan_inf(new_project_id, series_name, expected_values):
     df = fetch_metrics(
         project=new_project_id,
```
