Skip to content

Commit 19cd842

Browse files
committed
Update fetch_series to return values in ascending order
1 parent 585c541 commit 19cd842

File tree

4 files changed

+117
-121
lines changed

4 files changed

+117
-121
lines changed

src/neptune_query/internal/composition/fetch_series.py

Lines changed: 55 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,10 @@
4949
util,
5050
)
5151
from ..retrieval.search import ContainerType
52+
from ..retrieval.series import SeriesValue
5253

5354
__all__ = ("fetch_series",)
5455

55-
from ..retrieval.series import SeriesValue
56-
5756

5857
def fetch_series(
5958
*,
@@ -113,6 +112,7 @@ def fetch_series(
113112

114113
return df
115114

115+
116116
def _fetch_series(
117117
filter_: Optional[_Filter],
118118
attributes: _BaseAttributeFilter,
@@ -125,69 +125,63 @@ def _fetch_series(
125125
tail_limit: Optional[int],
126126
container_type: ContainerType,
127127
) -> tuple[dict[identifiers.RunAttributeDefinition, list[SeriesValue]], dict[identifiers.SysId, str]]:
128-
sys_id_label_mapping: dict[identifiers.SysId, str] = {}
129-
130-
def go_fetch_sys_attrs() -> Generator[list[identifiers.SysId], None, None]:
131-
for page in search.fetch_sys_id_labels(container_type)(
132-
client=client,
133-
project_identifier=project_identifier,
134-
filter_=filter_,
135-
):
136-
sys_ids = []
137-
for item in page.items:
138-
sys_id_label_mapping[item.sys_id] = item.label
139-
sys_ids.append(item.sys_id)
140-
yield sys_ids
141-
142-
output = concurrency.generate_concurrently(
143-
items=go_fetch_sys_attrs(),
128+
sys_id_label_mapping: dict[identifiers.SysId, str] = {}
129+
130+
def go_fetch_sys_attrs() -> Generator[list[identifiers.SysId], None, None]:
131+
for page in search.fetch_sys_id_labels(container_type)(
132+
client=client,
133+
project_identifier=project_identifier,
134+
filter_=filter_,
135+
):
136+
sys_ids = []
137+
for item in page.items:
138+
sys_id_label_mapping[item.sys_id] = item.label
139+
sys_ids.append(item.sys_id)
140+
yield sys_ids
141+
142+
output = concurrency.generate_concurrently(
143+
items=go_fetch_sys_attrs(),
144+
executor=executor,
145+
downstream=lambda sys_ids: _components.fetch_attribute_definitions_split(
146+
client=client,
147+
project_identifier=project_identifier,
148+
attribute_filter=attributes,
144149
executor=executor,
145-
downstream=lambda sys_ids: _components.fetch_attribute_definitions_split(
146-
client=client,
147-
project_identifier=project_identifier,
148-
attribute_filter=attributes,
149-
executor=executor,
150-
fetch_attribute_definitions_executor=fetch_attribute_definitions_executor,
151-
sys_ids=sys_ids,
152-
downstream=lambda sys_ids_split, definitions_page: concurrency.generate_concurrently(
153-
items=split.split_series_attributes(
154-
items=(
155-
identifiers.RunAttributeDefinition(
156-
run_identifier=identifiers.RunIdentifier(project_identifier, sys_id),
157-
attribute_definition=definition,
158-
)
159-
for sys_id in sys_ids_split
160-
for definition in definitions_page.items
161-
),
150+
fetch_attribute_definitions_executor=fetch_attribute_definitions_executor,
151+
sys_ids=sys_ids,
152+
downstream=lambda sys_ids_split, definitions_page: concurrency.generate_concurrently(
153+
items=split.split_series_attributes(
154+
items=(
155+
identifiers.RunAttributeDefinition(
156+
run_identifier=identifiers.RunIdentifier(project_identifier, sys_id),
157+
attribute_definition=definition,
158+
)
159+
for sys_id in sys_ids_split
160+
for definition in definitions_page.items
162161
),
163-
executor=executor,
164-
downstream=lambda run_attribute_definitions_split: concurrency.generate_concurrently(
165-
items=series.fetch_series_values(
166-
client=client,
167-
run_attribute_definitions=run_attribute_definitions_split,
168-
include_inherited=lineage_to_the_root,
169-
container_type=container_type,
170-
step_range=step_range,
171-
tail_limit=tail_limit,
172-
),
173-
executor=executor,
174-
downstream=concurrency.return_value,
162+
),
163+
executor=executor,
164+
downstream=lambda run_attribute_definitions_split: concurrency.return_value(
165+
series.fetch_series_values(
166+
client=client,
167+
run_attribute_definitions=run_attribute_definitions_split,
168+
include_inherited=lineage_to_the_root,
169+
container_type=container_type,
170+
step_range=step_range,
171+
tail_limit=tail_limit,
175172
),
176173
),
177174
),
178-
)
179-
results: Generator[
180-
util.Page[tuple[identifiers.RunAttributeDefinition, list[series.SeriesValue]]], None, None
181-
] = concurrency.gather_results(output)
175+
),
176+
)
182177

183-
series_data: dict[identifiers.RunAttributeDefinition, list[series.SeriesValue]] = {}
184-
for result in results:
185-
for run_attribute_definition, series_values in result.items:
186-
series_data.setdefault(run_attribute_definition, []).extend(series_values)
178+
results: Generator[
179+
util.Page[tuple[identifiers.RunAttributeDefinition, list[series.SeriesValue]]], None, None
180+
] = concurrency.gather_results(output)
187181

188-
return create_series_dataframe(
189-
series_data,
190-
project_identifier,
191-
sys_id_label_mapping,
192-
index_column_name="experiment" if container_type == ContainerType.EXPERIMENT else "run",
193-
)
182+
series_data: dict[identifiers.RunAttributeDefinition, list[series.SeriesValue]] = {}
183+
for result in results:
184+
for run_attribute_definition, series_values in result.items:
185+
series_data.setdefault(run_attribute_definition, []).extend(series_values)
186+
187+
return series_data, sys_id_label_mapping

src/neptune_query/internal/output_format.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ def generate_categorized_rows() -> Generator[Tuple, None, None]:
296296

297297
def create_series_dataframe(
298298
series_data: dict[identifiers.RunAttributeDefinition, list[series.SeriesValue]],
299+
# TODO: PY-310 remove unused parameter project_identifier
299300
project_identifier: str,
300301
sys_id_label_mapping: dict[identifiers.SysId, str],
301302
index_column_name: str,

src/neptune_query/internal/retrieval/series.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import functools as ft
1616
from typing import (
1717
Any,
18-
Generator,
1918
Iterable,
2019
NamedTuple,
2120
Optional,
@@ -58,10 +57,9 @@ def fetch_series_values(
5857
container_type: ContainerType,
5958
step_range: Tuple[Union[float, None], Union[float, None]] = (None, None),
6059
tail_limit: Optional[int] = None,
61-
) -> Generator[util.Page[tuple[RunAttributeDefinition, list[SeriesValue]]], None, None]:
60+
) -> dict[RunAttributeDefinition, list[SeriesValue]]:
6261
if not run_attribute_definitions:
63-
yield from []
64-
return
62+
return {}
6563

6664
run_attribute_definitions = list(run_attribute_definitions)
6765
width = len(str(len(run_attribute_definitions) - 1))
@@ -86,20 +84,33 @@ def fetch_series_values(
8684
for request_id, run_definition in request_id_to_run_attr_definition.items()
8785
],
8886
"stepRange": {"from": step_range[0], "to": step_range[1]},
87+
# Fetch in descending order to enable efficient tail fetching
8988
"order": "descending",
9089
}
9190
if tail_limit is not None:
9291
params["perSeriesPointsLimit"] = tail_limit
9392

94-
yield from util.fetch_pages(
93+
results: dict[RunAttributeDefinition, list[SeriesValue]] = {
94+
run_attribute: [] for run_attribute in run_attribute_definitions
95+
}
96+
97+
for page_result in util.fetch_pages(
9598
client=client,
9699
fetch_page=_fetch_series_page,
97100
process_page=ft.partial(
98101
_process_series_page, request_id_to_run_attr_definition=request_id_to_run_attr_definition
99102
),
100103
make_new_page_params=_make_new_series_page_params,
101104
initial_params=params,
102-
)
105+
):
106+
for attribute, values in page_result.items:
107+
results[attribute].extend(values)
108+
109+
# Reverse the order of values to maintain ascending order
110+
for attribute in results:
111+
results[attribute].reverse()
112+
113+
return results
103114

104115

105116
def _fetch_series_page(

0 commit comments

Comments
 (0)