Skip to content

Commit d349fad

Browse files
committed
Rewrite internal/retrieval/test_metric_buckets.py
1 parent 6a80b90 commit d349fad

File tree

1 file changed

+75
-18
lines changed

1 file changed

+75
-18
lines changed

tests/e2e/internal/retrieval/test_metric_buckets.py

Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,88 @@
11
import pytest
22

3+
from neptune_query.internal.filters import _Filter
34
from neptune_query.internal.identifiers import (
45
AttributeDefinition,
6+
ProjectIdentifier,
57
RunAttributeDefinition,
8+
RunIdentifier,
9+
SysId,
610
)
11+
from neptune_query.internal.retrieval import search
712
from neptune_query.internal.retrieval.metric_buckets import (
813
TimeseriesBucket,
914
fetch_time_series_buckets,
1015
)
1116
from neptune_query.internal.retrieval.search import ContainerType
12-
from tests.e2e.data import (
13-
FLOAT_SERIES_PATHS,
14-
PATH,
15-
TEST_DATA,
17+
from tests.e2e.data_ingestion import (
18+
IngestedProjectData,
19+
ProjectData,
20+
RunData,
1621
)
1722
from tests.e2e.metric_buckets import (
1823
aggregate_metric_buckets,
1924
calculate_global_range,
2025
calculate_metric_bucket_ranges,
2126
)
2227

23-
EXPERIMENT = TEST_DATA.experiments[0]
28+
29+
@pytest.fixture(scope="module")
30+
def project(ensure_project, test_execution_id) -> IngestedProjectData:
31+
project_data = ProjectData(
32+
project_name_base="metric-buckets-project",
33+
runs=[
34+
RunData(
35+
experiment_name_base="metric-buckets-experiment",
36+
run_id_base="metric-buckets-run-id",
37+
configs={
38+
"configs/int-value": 7,
39+
"configs/string-value": "example-config",
40+
},
41+
float_series={
42+
"metrics/float-series-value_0": {
43+
0.0: 0.5,
44+
1.0: 1.5,
45+
2.0: 2.5,
46+
3.0: 3.5,
47+
},
48+
"metrics/float-series-value_1": {
49+
0.0: 10.0,
50+
1.0: 11.0,
51+
2.0: 12.0,
52+
3.0: 13.0,
53+
},
54+
"metrics/step": {
55+
0.0: 0.0,
56+
1.0: 1.0,
57+
2.0: 2.0,
58+
3.0: 3.0,
59+
},
60+
},
61+
)
62+
],
63+
)
64+
65+
unique_key = f"{test_execution_id}__metric_buckets"
66+
return ensure_project(project_data, unique_key=unique_key)
67+
68+
69+
@pytest.fixture(scope="module")
70+
def experiment_identifier(client, project) -> RunIdentifier:
71+
project_identifier = ProjectIdentifier(project.project_identifier)
72+
experiment_name = project.ingested_runs[0].experiment_name
73+
74+
sys_ids: list[SysId] = []
75+
for page in search.fetch_experiment_sys_ids(
76+
client=client,
77+
project_identifier=project_identifier,
78+
filter_=_Filter.name_eq(experiment_name),
79+
):
80+
sys_ids.extend(page.items)
81+
82+
if len(sys_ids) != 1:
83+
raise RuntimeError(f"Expected to fetch exactly one sys_id for {experiment_name}, got {sys_ids}")
84+
85+
return RunIdentifier(project_identifier=project_identifier, sys_id=SysId(sys_ids[0]))
2486

2587

2688
def test_fetch_time_series_buckets_does_not_exist(client, project, experiment_identifier):
@@ -43,15 +105,6 @@ def test_fetch_time_series_buckets_does_not_exist(client, project, experiment_id
43105
assert result == {run_definition: []}
44106

45107

46-
@pytest.mark.parametrize(
47-
"attribute_name, expected_values",
48-
[
49-
(
50-
FLOAT_SERIES_PATHS[0],
51-
list(zip(EXPERIMENT.float_series[f"{PATH}/metrics/step"], EXPERIMENT.float_series[FLOAT_SERIES_PATHS[0]])),
52-
),
53-
],
54-
)
55108
@pytest.mark.parametrize(
56109
"limit",
57110
[2, 10, 100],
@@ -60,11 +113,15 @@ def test_fetch_time_series_buckets_does_not_exist(client, project, experiment_id
60113
"x_range",
61114
[None, (1, 2), (-100, 100)],
62115
)
63-
def test_fetch_time_series_buckets_single_series(
64-
client, project, experiment_identifier, attribute_name, expected_values, limit, x_range
65-
):
116+
def test_fetch_time_series_buckets_single_series(client, project, experiment_identifier, limit, x_range):
66117
# given
67-
run_definition = RunAttributeDefinition(experiment_identifier, AttributeDefinition(attribute_name, "float-series"))
118+
run_definition = RunAttributeDefinition(
119+
experiment_identifier, AttributeDefinition("metrics/float-series-value_0", "float-series")
120+
)
121+
ingested_run = project.ingested_runs[0]
122+
step_values = ingested_run.float_series["metrics/step"]
123+
series_values = ingested_run.float_series["metrics/float-series-value_0"]
124+
expected_values = [(step_values[step], series_values[step]) for step in sorted(series_values.keys())]
68125

69126
# when
70127
result = fetch_time_series_buckets(

0 commit comments

Comments
 (0)