Skip to content

Commit 69b6b80

Browse files
bcdurakAlexejPenneravishniakovactions-userschustmi
authored
New log_metadata function, new oneof filtering, additional run_metadata filtering (#3182)
* Initial commit, nuking all metadata responses and seeing what breaks * Removed last remnant of LazyLoader * Reintroducing the lazy loaders. * Add LazyRunMetadataResponse to EntrypointFunctionDefinition * Test for lazy loaders works now * Fixed tests, reformatted * Use updated template * Auto-update of Starter template * Updated more templates * Fixed failing test * Fixed step run schemas * Auto-update of E2E template * Auto-update of NLP template * Fixed tests, removed additional .value access * Further fixing * Fixed linting issues * Reformatted * Linted, formatted and tested again * Typing * Maybe fix everything * Apply some feedback * new operation * new log_metadata function * changes to the base filters * new filters * adding log_metadata to __all__ * checkpoint with float casting * adding tests * final touches and formatting * formatting * moved the utils * modified log metadata function * checkpoint * deprecating the old functions * linting and final fixes * better error message * fixing the client method * better error message * consistent creation\ * adjusting tests * linting * changes for step metadata * more test adjustments * testing unit tests * linting * fixing more tests * fixing more tests * more test fixes * fixing the test * fixing per comments * added validation, constant error message * linting --------- Co-authored-by: AlexejPenner <[email protected]> Co-authored-by: Andrei Vishniakov <[email protected]> Co-authored-by: GitHub Actions <[email protected]> Co-authored-by: Michael Schuster <[email protected]>
1 parent a624ab8 commit 69b6b80

File tree

20 files changed

+709
-68
lines changed

20 files changed

+709
-68
lines changed

src/zenml/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from zenml.pipelines import get_pipeline_context, pipeline
4949
from zenml.steps import step, get_step_context
5050
from zenml.steps.utils import log_step_metadata
51+
from zenml.utils.metadata_utils import log_metadata
5152
from zenml.entrypoints import entrypoint
5253

5354
__all__ = [
@@ -56,6 +57,7 @@
5657
"get_pipeline_context",
5758
"get_step_context",
5859
"load_artifact",
60+
"log_metadata",
5961
"log_artifact_metadata",
6062
"log_model_metadata",
6163
"log_step_metadata",

src/zenml/artifacts/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ def log_artifact_metadata(
408408
not provided, when being called inside a step that produces an
409409
artifact named `artifact_name`, the metadata will be associated to
410410
the corresponding newly created artifact. Or, if not provided when
411-
being called outside of a step, or in a step that does not produce
411+
being called outside a step, or in a step that does not produce
412412
any artifact named `artifact_name`, the metadata will be associated
413413
to the latest version of that artifact.
414414
@@ -417,6 +417,10 @@ def log_artifact_metadata(
417417
called inside a step with a single output, or, if neither an
418418
artifact nor an output with the given name exists.
419419
"""
420+
logger.warning(
421+
"The `log_artifact_metadata` function is deprecated and will soon be "
422+
"removed. Please use `log_metadata` instead."
423+
)
420424
try:
421425
step_context = get_step_context()
422426
in_step_outputs = (artifact_name in step_context._outputs) or (

src/zenml/client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3796,6 +3796,7 @@ def list_pipeline_runs(
37963796
templatable: Optional[bool] = None,
37973797
tag: Optional[str] = None,
37983798
user: Optional[Union[UUID, str]] = None,
3799+
run_metadata: Optional[Dict[str, str]] = None,
37993800
pipeline: Optional[Union[UUID, str]] = None,
38003801
code_repository: Optional[Union[UUID, str]] = None,
38013802
model: Optional[Union[UUID, str]] = None,
@@ -3835,6 +3836,7 @@ def list_pipeline_runs(
38353836
templatable: If the runs should be templatable or not.
38363837
tag: Tag to filter by.
38373838
user: The name/ID of the user to filter by.
3839+
run_metadata: The run_metadata of the run to filter by.
38383840
pipeline: The name/ID of the pipeline to filter by.
38393841
code_repository: Filter by code repository name/ID.
38403842
model: Filter by model name/ID.
@@ -3874,6 +3876,7 @@ def list_pipeline_runs(
38743876
tag=tag,
38753877
unlisted=unlisted,
38763878
user=user,
3879+
run_metadata=run_metadata,
38773880
pipeline=pipeline,
38783881
code_repository=code_repository,
38793882
stack=stack,
@@ -4194,7 +4197,7 @@ def get_artifact_version(
41944197
),
41954198
)
41964199
except RuntimeError:
4197-
pass # Cannot link to step run if called outside of a step
4200+
pass # Cannot link to step run if called outside a step
41984201
return artifact
41994202

42004203
def list_artifact_versions(
@@ -4222,6 +4225,7 @@ def list_artifact_versions(
42224225
user: Optional[Union[UUID, str]] = None,
42234226
model: Optional[Union[UUID, str]] = None,
42244227
pipeline_run: Optional[Union[UUID, str]] = None,
4228+
run_metadata: Optional[Dict[str, str]] = None,
42254229
tag: Optional[str] = None,
42264230
hydrate: bool = False,
42274231
) -> Page[ArtifactVersionResponse]:
@@ -4253,6 +4257,7 @@ def list_artifact_versions(
42534257
user: Filter by user name or ID.
42544258
model: Filter by model name or ID.
42554259
pipeline_run: Filter by pipeline run name or ID.
4260+
run_metadata: Filter by run metadata.
42564261
hydrate: Flag deciding whether to hydrate the output model(s)
42574262
by including metadata fields in the response.
42584263

src/zenml/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ class GenericFilterOps(StrEnum):
253253
CONTAINS = "contains"
254254
STARTSWITH = "startswith"
255255
ENDSWITH = "endswith"
256+
ONEOF = "oneof"
256257
GTE = "gte"
257258
GT = "gt"
258259
LTE = "lte"

src/zenml/model/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ def log_model_metadata(
5656
ValueError: If no model name/version is provided and the function is not
5757
called inside a step with configured `model` in decorator.
5858
"""
59+
logger.warning(
60+
"The `log_model_metadata` function is deprecated and will soon be "
61+
"removed. Please use `log_metadata` instead."
62+
)
63+
5964
if model_name and model_version:
6065
from zenml import Model
6166

src/zenml/models/v2/base/filter.py

Lines changed: 121 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# permissions and limitations under the License.
1414
"""Base filter model definitions."""
1515

16+
import json
1617
from abc import ABC, abstractmethod
1718
from datetime import datetime
1819
from typing import (
@@ -36,7 +37,7 @@
3637
field_validator,
3738
model_validator,
3839
)
39-
from sqlalchemy import asc, desc
40+
from sqlalchemy import Float, and_, asc, cast, desc
4041
from sqlmodel import SQLModel
4142

4243
from zenml.constants import (
@@ -63,6 +64,11 @@
6364

6465
AnyQuery = TypeVar("AnyQuery", bound=Any)
6566

67+
ONEOF_ERROR = (
68+
"When you are using the 'oneof:' filtering make sure that the "
69+
"provided value is a json formatted list."
70+
)
71+
6672

6773
class Filter(BaseModel, ABC):
6874
"""Filter for all fields.
@@ -171,8 +177,28 @@ class StrFilter(Filter):
171177
GenericFilterOps.STARTSWITH,
172178
GenericFilterOps.CONTAINS,
173179
GenericFilterOps.ENDSWITH,
180+
GenericFilterOps.ONEOF,
181+
GenericFilterOps.GT,
182+
GenericFilterOps.GTE,
183+
GenericFilterOps.LT,
184+
GenericFilterOps.LTE,
174185
]
175186

187+
@model_validator(mode="after")
188+
def check_value_if_operation_oneof(self) -> "StrFilter":
189+
"""Validator to check if value is a list if oneof operation is used.
190+
191+
Raises:
192+
ValueError: If the value is not a list
193+
194+
Returns:
195+
self
196+
"""
197+
if self.operation == GenericFilterOps.ONEOF:
198+
if not isinstance(self.value, list):
199+
raise ValueError(ONEOF_ERROR)
200+
return self
201+
176202
def generate_query_conditions_from_column(self, column: Any) -> Any:
177203
"""Generate query conditions for a string column.
178204
@@ -181,6 +207,9 @@ def generate_query_conditions_from_column(self, column: Any) -> Any:
181207
182208
Returns:
183209
A list of query conditions.
210+
211+
Raises:
212+
ValueError: the comparison of the column to a numeric value fails.
184213
"""
185214
if self.operation == GenericFilterOps.CONTAINS:
186215
return column.like(f"%{self.value}%")
@@ -190,6 +219,40 @@ def generate_query_conditions_from_column(self, column: Any) -> Any:
190219
return column.endswith(f"{self.value}")
191220
if self.operation == GenericFilterOps.NOT_EQUALS:
192221
return column != self.value
222+
if self.operation == GenericFilterOps.ONEOF:
223+
return column.in_(self.value)
224+
if self.operation in {
225+
GenericFilterOps.GT,
226+
GenericFilterOps.LT,
227+
GenericFilterOps.GTE,
228+
GenericFilterOps.LTE,
229+
}:
230+
try:
231+
numeric_column = cast(column, Float)
232+
233+
assert self.value is not None
234+
235+
if self.operation == GenericFilterOps.GT:
236+
return and_(
237+
numeric_column, numeric_column > float(self.value)
238+
)
239+
if self.operation == GenericFilterOps.LT:
240+
return and_(
241+
numeric_column, numeric_column < float(self.value)
242+
)
243+
if self.operation == GenericFilterOps.GTE:
244+
return and_(
245+
numeric_column, numeric_column >= float(self.value)
246+
)
247+
if self.operation == GenericFilterOps.LTE:
248+
return and_(
249+
numeric_column, numeric_column <= float(self.value)
250+
)
251+
except Exception as e:
252+
raise ValueError(
253+
f"Failed to compare the column '{column}' to the "
254+
f"value '{self.value}' (must be numeric): {e}"
255+
)
193256

194257
return column == self.value
195258

@@ -211,6 +274,9 @@ def _remove_hyphens_from_value(cls, value: Any) -> Any:
211274
if isinstance(value, str):
212275
return value.replace("-", "")
213276

277+
if isinstance(value, list):
278+
return [str(v).replace("-", "") for v in value]
279+
214280
return value
215281

216282
def generate_query_conditions_from_column(self, column: Any) -> Any:
@@ -588,6 +654,10 @@ def _resolve_operator(value: Any) -> Tuple[Any, GenericFilterOps]:
588654
589655
Returns:
590656
A tuple of the filter value and the operator.
657+
658+
Raises:
659+
ValueError: when we try to use the `oneof` operator with the wrong
660+
value.
591661
"""
592662
operator = GenericFilterOps.EQUALS # Default operator
593663
if isinstance(value, str):
@@ -598,6 +668,15 @@ def _resolve_operator(value: Any) -> Tuple[Any, GenericFilterOps]:
598668
):
599669
value = split_value[1]
600670
operator = GenericFilterOps(split_value[0])
671+
672+
if operator == operator.ONEOF:
673+
try:
674+
value = json.loads(value)
675+
if not isinstance(value, list):
676+
raise ValueError
677+
except ValueError:
678+
raise ValueError(ONEOF_ERROR)
679+
601680
return value, operator
602681

603682
def generate_name_or_id_query_conditions(
@@ -648,8 +727,8 @@ def generate_name_or_id_query_conditions(
648727

649728
return or_(*conditions)
650729

730+
@staticmethod
651731
def generate_custom_query_conditions_for_column(
652-
self,
653732
value: Any,
654733
table: Type[SQLModel],
655734
column: str,
@@ -833,16 +912,17 @@ def define_filter(
833912

834913
# Create str filters
835914
if self.is_str_field(column):
836-
return StrFilter(
837-
operation=GenericFilterOps(operator),
915+
return self._define_str_filter(
916+
operator=GenericFilterOps(operator),
838917
column=column,
839918
value=value,
840919
)
841920

842921
# Handle unsupported datatypes
843922
logger.warning(
844-
f"The Datatype {self._model_class.model_fields[column].annotation} might "
845-
"not be supported for filtering. Defaulting to a string filter."
923+
f"The Datatype {self._model_class.model_fields[column].annotation} "
924+
"might not be supported for filtering. Defaulting to a string "
925+
"filter."
846926
)
847927
return StrFilter(
848928
operation=GenericFilterOps(operator),
@@ -1032,8 +1112,9 @@ def _define_uuid_filter(
10321112
"Invalid value passed as UUID query parameter."
10331113
) from e
10341114

1035-
# Cast the value to string for further comparisons.
1036-
value = str(value)
1115+
# For equality checks, ensure that the value is a valid UUID.
1116+
if operator == GenericFilterOps.ONEOF and not isinstance(value, list):
1117+
raise ValueError(ONEOF_ERROR)
10371118

10381119
# Generate the filter.
10391120
uuid_filter = UUIDFilter(
@@ -1043,6 +1124,38 @@ def _define_uuid_filter(
10431124
)
10441125
return uuid_filter
10451126

1127+
@staticmethod
1128+
def _define_str_filter(
1129+
column: str, value: Any, operator: GenericFilterOps
1130+
) -> StrFilter:
1131+
"""Define a str filter for a given column.
1132+
1133+
Args:
1134+
column: The column to filter on.
1135+
value: The UUID value by which to filter.
1136+
operator: The operator to use for filtering.
1137+
1138+
Returns:
1139+
A Filter object.
1140+
1141+
Raises:
1142+
ValueError: If the value is not a proper value.
1143+
"""
1144+
# For equality checks, ensure that the value is a valid UUID.
1145+
if operator == GenericFilterOps.ONEOF and not isinstance(value, list):
1146+
raise ValueError(
1147+
"If you are using `oneof:` as a filtering op, the value needs "
1148+
"to be a json formatted list string."
1149+
)
1150+
1151+
# Generate the filter.
1152+
str_filter = StrFilter(
1153+
operation=GenericFilterOps(operator),
1154+
column=column,
1155+
value=value,
1156+
)
1157+
return str_filter
1158+
10461159
@staticmethod
10471160
def _define_bool_filter(
10481161
column: str, value: Any, operator: GenericFilterOps

src/zenml/models/v2/core/artifact_version.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,7 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter):
474474
"user",
475475
"model",
476476
"pipeline_run",
477+
"run_metadata",
477478
]
478479
artifact_id: Optional[Union[UUID, str]] = Field(
479480
default=None,
@@ -545,6 +546,10 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter):
545546
description="Name/ID of a pipeline run that is associated with this "
546547
"artifact version.",
547548
)
549+
run_metadata: Optional[Dict[str, str]] = Field(
550+
default=None,
551+
description="The run_metadata to filter the artifact versions by.",
552+
)
548553

549554
model_config = ConfigDict(protected_namespaces=())
550555

@@ -564,6 +569,7 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]:
564569
ModelSchema,
565570
ModelVersionArtifactSchema,
566571
PipelineRunSchema,
572+
RunMetadataSchema,
567573
StepRunInputArtifactSchema,
568574
StepRunOutputArtifactSchema,
569575
StepRunSchema,
@@ -645,6 +651,23 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]:
645651
)
646652
custom_filters.append(pipeline_run_filter)
647653

654+
if self.run_metadata is not None:
655+
from zenml.enums import MetadataResourceTypes
656+
657+
for key, value in self.run_metadata.items():
658+
additional_filter = and_(
659+
RunMetadataSchema.resource_id == ArtifactVersionSchema.id,
660+
RunMetadataSchema.resource_type
661+
== MetadataResourceTypes.ARTIFACT_VERSION,
662+
RunMetadataSchema.key == key,
663+
self.generate_custom_query_conditions_for_column(
664+
value=value,
665+
table=RunMetadataSchema,
666+
column="value",
667+
),
668+
)
669+
custom_filters.append(additional_filter)
670+
648671
return custom_filters
649672

650673

0 commit comments

Comments
 (0)