Commit 5fa7d2d

Merge branch 'develop' into claude/zenml-issue-4248-01FHgBjs7uLxfNF5oP7incqW

2 parents 6b41e90 + d8bd68a

36 files changed: +1334 −238 lines

.github/workflows/linting.yml

Lines changed: 1 addition & 0 deletions

```diff
@@ -75,6 +75,7 @@ jobs:
           remove-android: 'true'
           remove-haskell: 'true'
           build-mount-path: /var/lib/docker/
+        if: inputs.os == 'ubuntu-latest'
       - name: Checkout code
         uses: actions/[email protected]
         with:
```

.github/workflows/unit-test.yml

Lines changed: 1 addition & 0 deletions

```diff
@@ -86,6 +86,7 @@ jobs:
           remove-android: 'true'
           remove-haskell: 'true'
           build-mount-path: /var/lib/docker/
+        if: inputs.os == 'ubuntu-latest'
       - name: Checkout code
         uses: actions/[email protected]
         with:
```

docs/book/how-to/steps-pipelines/dynamic_pipelines.md

Lines changed: 89 additions & 0 deletions

@@ -96,6 +96,90 @@ Use `runtime="inline"` when you need:

- Shared resources with the orchestrator
- Sequential execution

### Map/Reduce over collections

Dynamic pipelines support a high-level map/reduce pattern over sequence-like step outputs. This lets you fan a step out across the items of a collection and then reduce the results, without writing manual loops or loading data in the orchestration environment.

```python
from zenml import pipeline, step

@step
def producer() -> list[int]:
    return [1, 2, 3]

@step
def worker(value: int) -> int:
    return value * 2

@step
def reducer(values: list[int]) -> int:
    return sum(values)

@pipeline(dynamic=True, enable_cache=False)
def map_reduce():
    values = producer()
    results = worker.map(values)  # fan out over the collection
    reducer(results)  # pass the list of artifacts directly
```

Key points:

- `step.map(...)` fans a step out over sequence-like inputs.
- Steps can accept lists of artifacts directly as inputs (useful for reducers).
- You can pass a mapped output directly to a downstream step without loading it in the orchestration environment.

#### Mapping semantics: map vs product

- `step.map(...)`: If multiple sequence-like inputs are provided, they must all have the same length `n`. ZenML creates `n` mapped steps, where the i-th step receives the i-th element of each input (see the sketch below).
- `step.product(...)`: Creates a mapped step for each combination of elements across all input sequences (cartesian product).
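
For `step.map(...)` with multiple same-length inputs, here is a minimal sketch; the step names (`xs`, `ys`, `add`) are illustrative and not from the original docs:

```python
from zenml import pipeline, step

@step
def xs() -> list[int]:
    return [1, 2, 3]

@step
def ys() -> list[int]:
    return [10, 20, 30]

@step
def add(a: int, b: int) -> int:
    return a + b

@pipeline(dynamic=True)
def same_length_map():
    # Both sequences have length 3, so three mapped steps are created:
    # add(1, 10), add(2, 20), add(3, 30).
    add.map(a=xs(), b=ys())
```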

Example (cartesian product):

```python
from zenml import pipeline, step

@step
def int_values() -> list[int]:
    return [1, 2]

@step
def str_values() -> list[str]:
    return ["a", "b", "c"]

@step
def do_something(a: int, b: str) -> int:
    ...

@pipeline(dynamic=True)
def cartesian_example():
    a = int_values()
    b = str_values()
    # Produces 2 * 3 = 6 mapped steps
    do_something.product(a, b)
```

#### Broadcasting inputs with unmapped(...)

If you want to pass a sequence-like artifact as a whole to each mapped invocation (i.e., avoid splitting it), wrap it with `unmapped(...)`:

```python
from zenml import pipeline, step, unmapped

@step
def producer(length: int) -> list[int]:
    return [1] * length

@step
def consumer(a: int, b: list[int]) -> None:
    # `b` is the full list for every mapped call
    ...

@pipeline(dynamic=True)
def unmapped_example():
    a = producer(length=3)  # list of 3 ints
    b = producer(length=4)  # list of 4 ints
    consumer.map(a=a, b=unmapped(b))
```

### Parallel Step Execution

Dynamic pipelines support true parallel execution using `step.submit()`. This method returns a `StepRunFuture` that you can use to wait for results or pass to downstream steps:
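
A minimal, hypothetical sketch that assumes only what the sentence above states — `step.submit()` returns a `StepRunFuture` that can be passed to downstream steps; the `worker` and `combine` steps are illustrative:

```python
from zenml import pipeline, step

@step
def worker(value: int) -> int:
    return value * 2

@step
def combine(a: int, b: int) -> int:
    return a + b

@pipeline(dynamic=True)
def parallel_example():
    # Each submit() call returns a StepRunFuture instead of blocking,
    # so both workers can run concurrently.
    future_a = worker.submit(value=1)
    future_b = worker.submit(value=2)
    # Futures can be handed directly to a downstream step.
    combine(a=future_a, b=future_b)
```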
@@ -205,6 +289,11 @@ def dynamic_pipeline():

When you call `.load()` on an artifact in a dynamic pipeline, it synchronously loads the data. For large artifacts, or when you want to maintain parallelism, consider passing the step outputs (future or artifact) directly to downstream steps instead of loading them.
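For instance, a hypothetical sketch contrasting the two approaches; the step names are illustrative, and the `.submit()`/`.load()` calls follow the descriptions above rather than a confirmed API surface:

```python
from zenml import pipeline, step

@step
def producer() -> list[int]:
    return [1, 2, 3]

@step
def consumer(values: list[int]) -> int:
    return sum(values)

@pipeline(dynamic=True)
def load_vs_pass_through():
    future = producer.submit()

    # Synchronous: blocks and materializes the data in the
    # orchestration environment.
    values = future.load()

    # Keeps parallelism: the downstream step receives the output
    # and loads it where it actually runs.
    consumer(future)
```
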
### Mapping Limitations

- Mapping is currently supported only over artifacts produced within the same pipeline run; mapping over raw data or external artifacts is not supported.
- The chunk size for mapped collection loading defaults to 1 and is not yet configurable.

## Best Practices

1. **Use `runtime="isolated"` for parallel steps**: This ensures better resource isolation and prevents interference between concurrent step executions.

src/zenml/__init__.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -61,7 +61,7 @@ def __getattr__(name: str) -> Any:
 from zenml.steps.utils import log_step_metadata
 from zenml.utils.metadata_utils import log_metadata, bulk_log_metadata
 from zenml.utils.tag_utils import Tag, add_tags, remove_tags
-
+from zenml.execution.pipeline.dynamic.utils import unmapped

 __all__ = [
     "add_tags",
@@ -84,4 +84,5 @@ def __getattr__(name: str) -> Any:
     "register_artifact",
     "show",
     "step",
+    "unmapped",
 ]
```

src/zenml/artifacts/utils.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -208,6 +208,7 @@ def _store_artifact_data_and_prepare_request(
         visualizations=visualizations,
         has_custom_name=has_custom_name,
         save_type=save_type,
+        item_count=materializer.get_item_count(data),
         metadata=validate_metadata(combined_metadata)
         if combined_metadata
         else None,
```

src/zenml/config/compiler.py

Lines changed: 10 additions & 5 deletions

```diff
@@ -468,11 +468,16 @@ def _get_step_spec(
             The step spec.
         """
         inputs = {
-            key: InputSpec(
-                step_name=artifact.invocation_id,
-                output_name=artifact.output_name,
-            )
-            for key, artifact in invocation.input_artifacts.items()
+            key: [
+                InputSpec(
+                    step_name=artifact.invocation_id,
+                    output_name=artifact.output_name,
+                    chunk_index=artifact.chunk_index,
+                    chunk_size=artifact.chunk_size,
+                )
+                for artifact in artifact_list
+            ]
+            for key, artifact_list in invocation.input_artifacts.items()
         }
         return StepSpec(
             source=invocation.step.resolve(),
```

src/zenml/config/step_configurations.py

Lines changed: 31 additions & 3 deletions

```diff
@@ -407,25 +407,53 @@ class InputSpec(FrozenBaseModel):

     step_name: str
     output_name: str
+    chunk_index: Optional[int] = None
+    chunk_size: Optional[int] = None


 class StepSpec(FrozenBaseModel):
     """Specification of a pipeline."""

     source: SourceWithValidator
     upstream_steps: List[str]
-    inputs: Dict[str, InputSpec] = {}
+    # TODO: This should be `Dict[str, List[InputSpec]]`, but that would break
+    # client-server compatibility. In the next major release, change this and
+    # uncomment the code that migrates legacy specs.
+    inputs: Dict[str, Union[List[InputSpec], InputSpec]] = {}
     invocation_id: str
     enable_heartbeat: bool = False

     @model_validator(mode="before")
     @classmethod
     @before_validator_handler
-    def _migrate_invocation_id(cls, data: Dict[str, Any]) -> Dict[str, Any]:
+    def _migrate_legacy_fields(cls, data: Dict[str, Any]) -> Dict[str, Any]:
         if "invocation_id" not in data:
             data["invocation_id"] = data.pop("pipeline_parameter_name", "")
+
+        # converted_inputs = {}
+        # for key, value in data.get("inputs", {}).items():
+        #     if isinstance(value, (InputSpec, dict)):
+        #         converted_inputs[key] = [value]
+        #     else:
+        #         converted_inputs[key] = value
+        # data["inputs"] = converted_inputs
+
         return data

+    # TODO: Remove this and use the `inputs` property once we change the type
+    # of the `inputs` field.
+    @property
+    def inputs_v2(self) -> Dict[str, List[InputSpec]]:
+        """Inputs of the step spec in v2 format.
+
+        Returns:
+            The inputs of the step spec in v2 format.
+        """
+        return {
+            key: [value] if isinstance(value, InputSpec) else value
+            for key, value in self.inputs.items()
+        }
+
     def __eq__(self, other: Any) -> bool:
         """Returns whether the other object is referring to the same step.

@@ -445,7 +473,7 @@ def __eq__(self, other: Any) -> bool:
         if self.upstream_steps != other.upstream_steps:
             return False

-        if self.inputs != other.inputs:
+        if self.inputs_v2 != other.inputs_v2:
             return False

         if self.invocation_id != other.invocation_id:
```

src/zenml/execution/pipeline/dynamic/outputs.py

Lines changed: 36 additions & 24 deletions

```diff
@@ -14,12 +14,10 @@
 """Dynamic pipeline execution outputs."""

 from concurrent.futures import Future
-from typing import Any, List, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union, overload

 from zenml.logger import get_logger
-from zenml.models import (
-    ArtifactVersionResponse,
-)
+from zenml.models import ArtifactVersionResponse

 logger = get_logger(__name__)

@@ -29,6 +27,8 @@ class OutputArtifact(ArtifactVersionResponse):

     output_name: str
     step_name: str
+    chunk_index: Optional[int] = None
+    chunk_size: Optional[int] = None


 StepRunOutputs = Union[None, OutputArtifact, Tuple[OutputArtifact, ...]]
@@ -191,34 +191,46 @@ def load(self, disable_cache: bool = False) -> Any:
         else:
             raise ValueError(f"Invalid step run output: {result}")

-    def __getitem__(self, key: Any) -> ArtifactFuture:
-        """Get an artifact future by key or index.
+    @overload
+    def __getitem__(self, key: int) -> ArtifactFuture: ...
+
+    @overload
+    def __getitem__(self, key: slice) -> Tuple[ArtifactFuture, ...]: ...
+
+    def __getitem__(
+        self, key: Union[int, slice]
+    ) -> Union[ArtifactFuture, Tuple[ArtifactFuture, ...]]:
+        """Get an artifact future.

         Args:
-            key: The key or index of the artifact future.
+            key: The index or slice of the artifact futures.

         Raises:
-            TypeError: If the key is not an integer.
-            IndexError: If the index is out of range.
+            TypeError: If the key is not an integer or slice.

         Returns:
-            The artifact future.
+            The artifact futures.
         """
-        if not isinstance(key, int):
-            raise TypeError(f"Invalid key type: {type(key)}")
+        if isinstance(key, int):
+            output_key = self._output_keys[key]

-        # Convert to positive index if necessary
-        if key < 0:
-            key += len(self._output_keys)
-
-        if key > len(self._output_keys):
-            raise IndexError(f"Index out of range: {key}")
-
-        return ArtifactFuture(
-            wrapped=self._wrapped,
-            invocation_id=self._invocation_id,
-            index=key,
-        )
+            return ArtifactFuture(
+                wrapped=self._wrapped,
+                invocation_id=self._invocation_id,
+                index=self._output_keys.index(output_key),
+            )
+        elif isinstance(key, slice):
+            output_keys = self._output_keys[key]
+            return tuple(
+                ArtifactFuture(
+                    wrapped=self._wrapped,
+                    invocation_id=self._invocation_id,
+                    index=self._output_keys.index(output_key),
+                )
+                for output_key in output_keys
+            )
+        else:
+            raise TypeError(f"Invalid key type: {type(key)}")

     def __iter__(self) -> Any:
         """Iterate over the artifact futures.
```

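A short, hypothetical usage sketch of the indexing behavior added above; `my_step` and its outputs are illustrative, not from the diff:

```python
# Inside a dynamic pipeline: submit a step that has several outputs.
outputs = my_step.submit()  # hypothetical StepRunFuture

first = outputs[0]     # int index  -> a single ArtifactFuture
rest = outputs[1:]     # slice      -> Tuple[ArtifactFuture, ...]
outputs["name"]        # any other key type raises TypeError
```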