Skip to content

Commit

Permalink
[components] Templating for asset_attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
OwenKephart committed Dec 23, 2024
1 parent b96403d commit 7eb1d13
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from pydantic import TypeAdapter
from typing_extensions import Self

from dagster_components.core.component_rendering import TemplatedValueResolver, preprocess_value
from dagster_components.core.component_rendering import TemplatedValueResolver


class ComponentDeclNode: ...
Expand Down Expand Up @@ -254,8 +254,8 @@ def _raw_params(self) -> Optional[Mapping[str, Any]]:

def load_params(self, params_schema: Type[T]) -> T:
with pushd(str(self.path)):
preprocessed_params = preprocess_value(
self.templated_value_resolver, self._raw_params(), params_schema
preprocessed_params = self.templated_value_resolver.resolve_params(
self._raw_params(), params_schema
)
return TypeAdapter(params_schema).validate_python(preprocessed_params)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import json
import os
from typing import AbstractSet, Any, Mapping, Optional, Sequence, Type, TypeVar, Union
from typing import AbstractSet, Any, Callable, Mapping, Optional, Sequence, Type, TypeVar, Union

import dagster._check as check
from dagster._record import record
Expand Down Expand Up @@ -40,6 +41,9 @@ def _env(key: str) -> Optional[str]:
return os.environ.get(key)


ShouldRenderFn = Callable[[Sequence[Union[str, int]]], bool]


@record
class TemplatedValueResolver:
context: Mapping[str, Any]
Expand All @@ -51,11 +55,49 @@ def default() -> "TemplatedValueResolver":
def with_context(self, **additional_context) -> "TemplatedValueResolver":
    """Return a new resolver whose context is extended with ``additional_context``.

    Keys passed here take precedence over keys already in ``self.context``;
    a fresh ``TemplatedValueResolver`` is constructed rather than updating
    this one.
    """
    merged_context = {**self.context, **additional_context}
    return TemplatedValueResolver(context=merged_context)

def resolve(self, val: Any) -> Any:
def _resolve_value(self, val: Any) -> Any:
    """Render a single leaf value.

    Strings are rendered as a ``NativeTemplate`` with ``self.context``
    supplied as keyword arguments; any non-string value is returned as-is.
    """
    if not isinstance(val, str):
        return val
    return NativeTemplate(val).render(**self.context)

def _resolve(
    self,
    val: Any,
    valpath: Optional[Sequence[Union[str, int]]],
    should_render: Callable[[Sequence[Union[str, int]]], bool],
) -> Any:
    """Recursively render ``val``, consulting ``should_render`` per path.

    Args:
        val: The value to process; dicts and lists are walked recursively.
        valpath: Path (dict keys / list indices) leading to ``val``. When it
            is ``None`` no path is tracked and ``should_render`` is never
            called.
        should_render: Predicate called with the current path; when it
            returns a falsy result, ``val`` is returned untouched and its
            children are not visited.

    Returns:
        A new dict/list with resolved children, or the result of
        ``_resolve_value`` for any other value.
    """
    tracking_path = valpath is not None

    # A path the predicate rejects is returned verbatim, children included.
    if tracking_path and not should_render(valpath):
        return val

    def child_path(step: Union[str, int]) -> Optional[Sequence[Union[str, int]]]:
        # Only extend the path when one is being tracked at all.
        return [*valpath, step] if tracking_path else None

    if isinstance(val, dict):
        return {
            key: self._resolve(item, child_path(key), should_render)
            for key, item in val.items()
        }
    if isinstance(val, list):
        return [
            self._resolve(item, child_path(index), should_render)
            for index, item in enumerate(val)
        ]
    return self._resolve_value(val)

def _should_render(
def resolve(self, val: Any) -> Any:
    """Render every templated value found in ``val``.

    Delegates to ``_resolve`` with no path tracking (``valpath=None``) and a
    predicate that always allows rendering.
    """

    def always_render(_path) -> bool:
        return True

    return self._resolve(val, None, always_render)

def resolve_params(self, val: T, target_type: Type) -> T:
    """Render templated values in ``val``, except those the target schema defers.

    Args:
        val: The raw value to preprocess.
        target_type: The type the value is destined for. When it is a
            pydantic ``BaseModel`` subclass, its JSON schema is passed to
            ``has_rendering_scope`` to decide, path by path, whether a value
            should be rendered now. For any other target everything is
            rendered.

    Returns:
        ``val`` with all eligible templated values rendered.
    """
    try:
        is_model = issubclass(target_type, BaseModel)
    except TypeError:
        # issubclass() rejects anything that is not a class (for example a
        # typing generic). Such a target has no model schema to consult, so
        # fall through to the render-everything path instead of crashing.
        is_model = False

    if not is_model:
        return self._resolve(val, [], should_render=lambda _: True)

    json_schema = target_type.model_json_schema()
    # The schema serves both as the root (json_schema) and as the starting
    # subschema for the path walk.
    should_render = functools.partial(
        has_rendering_scope, json_schema=json_schema, subschema=json_schema
    )
    return self._resolve(val, [], should_render=should_render)


def has_rendering_scope(
valpath: Sequence[Union[str, int]], json_schema: Mapping[str, Any], subschema: Mapping[str, Any]
) -> bool:
# List[ComplexType] (e.g.) will contain a reference to the complex type schema in the
Expand All @@ -70,7 +112,7 @@ def _should_render(

# Optional[ComplexType] (e.g.) will contain multiple schemas in the "anyOf" field
if "anyOf" in subschema:
return all(_should_render(valpath, json_schema, inner) for inner in subschema["anyOf"])
return all(has_rendering_scope(valpath, json_schema, inner) for inner in subschema["anyOf"])

el = valpath[0]
if isinstance(el, str):
Expand All @@ -89,30 +131,4 @@ def _should_render(
return subschema.get("additionalProperties", True)

_, *rest = valpath
return _should_render(rest, json_schema, inner)


def _render_values(
    value_resolver: TemplatedValueResolver,
    val: Any,
    valpath: Sequence[Union[str, int]],
    json_schema: Optional[Mapping[str, Any]],
) -> Any:
    """Recursively render ``val`` with ``value_resolver``, honoring the schema.

    When a truthy ``json_schema`` is given and ``_should_render`` rejects the
    current ``valpath``, ``val`` is returned untouched. Otherwise dicts and
    lists are rebuilt with each child rendered under an extended path, and
    any other value is passed to ``value_resolver.resolve``.
    """
    if json_schema and not _should_render(valpath, json_schema, json_schema):
        return val

    if isinstance(val, dict):
        rendered_items = {}
        for key, item in val.items():
            rendered_items[key] = _render_values(
                value_resolver, item, [*valpath, key], json_schema
            )
        return rendered_items

    if isinstance(val, list):
        return [
            _render_values(value_resolver, item, [*valpath, index], json_schema)
            for index, item in enumerate(val)
        ]

    return value_resolver.resolve(val)


def preprocess_value(renderer: TemplatedValueResolver, val: T, target_type: Type) -> T:
    """Render templated values in ``val`` unless the target schema defers them.

    The JSON schema of ``target_type`` is used only when it is a ``BaseModel``
    subclass; otherwise no schema is passed and ``_render_values`` renders
    without consulting one.
    """
    if issubclass(target_type, BaseModel):
        json_schema = target_type.model_json_schema()
    else:
        json_schema = None
    return _render_values(renderer, val, [], json_schema)
return has_rendering_scope(rest, json_schema, inner)
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Annotated, Any, Dict, Literal, Mapping, Optional, Sequence, Union
from abc import ABC
from typing import AbstractSet, Annotated, Any, Dict, Literal, Mapping, Optional, Sequence, Union

from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.asset_selection import AssetSelection
from dagster._core.definitions.asset_spec import AssetSpec, map_asset_specs
from dagster._core.definitions.assets import AssetsDefinition
Expand All @@ -11,6 +12,8 @@
from dagster._record import replace
from pydantic import BaseModel, Field

from dagster_components.core.component_rendering import RenderingScope, TemplatedValueResolver


class OpSpecBaseModel(BaseModel):
name: Optional[str] = None
Expand All @@ -33,26 +36,31 @@ class AssetSpecProcessor(ABC, BaseModel):
tags: Optional[Mapping[str, str]] = None
automation_condition: Optional[AutomationConditionModel] = None

def _attributes(self) -> Mapping[str, Any]:
    """Collect the attributes this processor applies to a spec.

    Starts from the explicitly-set model fields (excluding ``target`` and
    ``operation``) and always sets ``automation_condition``: the result of
    ``to_automation_condition()`` when a condition model is present,
    otherwise ``None``.
    """
    attributes = self.model_dump(exclude={"target", "operation"}, exclude_unset=True)
    if self.automation_condition:
        attributes["automation_condition"] = (
            self.automation_condition.to_automation_condition()
        )
    else:
        attributes["automation_condition"] = None
    return attributes
# Hook that applies already-resolved `attributes` to `spec` and returns the
# resulting AssetSpec. The body here is a placeholder (`...`); the subclasses
# visible in this file (MergeAttributes, ReplaceAttributes) define their own
# versions.
def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec: ...

def apply_to_spec(
self,
spec: AssetSpec,
value_resolver: TemplatedValueResolver,
target_keys: AbstractSet[AssetKey],
) -> AssetSpec:
if spec.key not in target_keys:
return spec

@abstractmethod
def _apply_to_spec(self, spec: AssetSpec) -> AssetSpec: ...
# add the original spec to the context and resolve values
attributes = value_resolver.with_context(asset=spec).resolve(
self.model_dump(exclude={"target", "operation"}, exclude_unset=True)
)
return self._apply_to_spec(spec, attributes)

def apply(self, defs: Definitions) -> Definitions:
def apply(self, defs: Definitions, value_resolver: TemplatedValueResolver) -> Definitions:
target_selection = AssetSelection.from_string(self.target, include_sources=True)
target_keys = target_selection.resolve(defs.get_asset_graph())

mappable = [d for d in defs.assets or [] if isinstance(d, (AssetsDefinition, AssetSpec))]
mapped_assets = map_asset_specs(
lambda spec: self._apply_to_spec(spec) if spec.key in target_keys else spec, mappable
lambda spec: self.apply_to_spec(spec, value_resolver, target_keys),
mappable,
)

assets = [
Expand All @@ -66,8 +74,7 @@ class MergeAttributes(AssetSpecProcessor):
# default operation is "merge"
operation: Literal["merge"] = "merge"

def _apply_to_spec(self, spec: AssetSpec) -> AssetSpec:
attributes = self._attributes()
def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec:
mergeable_attributes = {"metadata", "tags"}
merge_attributes = {k: v for k, v in attributes.items() if k in mergeable_attributes}
replace_attributes = {k: v for k, v in attributes.items() if k not in mergeable_attributes}
Expand All @@ -78,13 +85,13 @@ class ReplaceAttributes(AssetSpecProcessor):
# operation must be set explicitly
operation: Literal["replace"]

def _apply_to_spec(self, spec: AssetSpec) -> AssetSpec:
return spec.replace_attributes(**self._attributes())
def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> AssetSpec:
return spec.replace_attributes(**attributes)


AssetAttributes = Sequence[
Annotated[
Union[MergeAttributes, ReplaceAttributes],
Field(union_mode="left_to_right"),
RenderingScope(Field(union_mode="left_to_right"), required_scope={"asset"}),
]
]
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _fn(context: AssetExecutionContext):

defs = Definitions(assets=[_fn])
for transform in self.asset_processors:
defs = transform.apply(defs)
defs = transform.apply(defs, context.templated_value_resolver)
return defs

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _fn(context: AssetExecutionContext, sling: SlingResource):

defs = Definitions(assets=[_fn], resources={"sling": self.resource})
for transform in self.asset_processors:
defs = transform.apply(defs)
defs = transform.apply(defs, context.templated_value_resolver)
return defs

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from dagster_components.core.component_rendering import (
RenderingScope,
TemplatedValueResolver,
_should_render,
preprocess_value,
has_rendering_scope,
)
from pydantic import BaseModel, Field, TypeAdapter

Expand Down Expand Up @@ -44,7 +43,9 @@ class Outer(BaseModel):
],
)
def test_should_render(path, expected: bool) -> None:
assert _should_render(path, Outer.model_json_schema(), Outer.model_json_schema()) == expected
assert (
has_rendering_scope(path, Outer.model_json_schema(), Outer.model_json_schema()) == expected
)


def test_render() -> None:
Expand All @@ -61,7 +62,7 @@ def test_render() -> None:
}

renderer = TemplatedValueResolver(context={"foo_val": "foo", "bar_val": "bar"})
rendered_data = preprocess_value(renderer, data, Outer)
rendered_data = renderer.resolve_params(data, Outer)

assert rendered_data == {
"a": "foo",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import pytest
from dagster import AssetKey, AssetSpec, Definitions
from dagster_components.core.dsl_schema import AssetAttributes, MergeAttributes, ReplaceAttributes
from dagster_components.core.dsl_schema import (
AssetAttributes,
MergeAttributes,
ReplaceAttributes,
TemplatedValueResolver,
)
from pydantic import BaseModel, TypeAdapter


Expand All @@ -20,7 +25,7 @@ class M(BaseModel):
def test_replace_attributes() -> None:
op = ReplaceAttributes(operation="replace", target="group:g2", tags={"newtag": "newval"})

newdefs = op.apply(defs)
newdefs = op.apply(defs, TemplatedValueResolver.default())
asset_graph = newdefs.get_asset_graph()
assert asset_graph.get(AssetKey("a")).tags == {}
assert asset_graph.get(AssetKey("b")).tags == {"newtag": "newval"}
Expand All @@ -30,13 +35,35 @@ def test_replace_attributes() -> None:
def test_merge_attributes() -> None:
op = MergeAttributes(operation="merge", target="group:g2", tags={"newtag": "newval"})

newdefs = op.apply(defs)
newdefs = op.apply(defs, TemplatedValueResolver.default())
asset_graph = newdefs.get_asset_graph()
assert asset_graph.get(AssetKey("a")).tags == {}
assert asset_graph.get(AssetKey("b")).tags == {"newtag": "newval"}
assert asset_graph.get(AssetKey("c")).tags == {"tag": "val", "newtag": "newval"}


def test_render_attributes_asset_context() -> None:
    """Templates in merged tags can reference the asset being processed.

    Each asset's ``group_name`` is expected to be rendered into its new tag,
    and merged alongside any tags the asset already had.
    """
    processor = MergeAttributes(tags={"group_name_tag": "group__{{ asset.group_name }}"})
    resolver = TemplatedValueResolver.default().with_context(foo="theval")

    graph = processor.apply(defs, resolver).get_asset_graph()

    tags_for_a = graph.get(AssetKey("a")).tags
    tags_for_b = graph.get(AssetKey("b")).tags
    tags_for_c = graph.get(AssetKey("c")).tags
    assert tags_for_a == {"group_name_tag": "group__g1"}
    assert tags_for_b == {"group_name_tag": "group__g2"}
    assert tags_for_c == {"tag": "val", "group_name_tag": "group__g2"}


def test_render_attributes_custom_context() -> None:
    """Templates in replaced tags can reference caller-supplied context.

    Only assets selected by ``group:g2`` get the rendered tags; asset ``a``
    is expected to keep its empty tags.
    """
    templated_tags = {"a": "{{ foo }}", "b": "prefix_{{ foo }}"}
    processor = ReplaceAttributes(operation="replace", target="group:g2", tags=templated_tags)
    resolver = TemplatedValueResolver.default().with_context(foo="theval")

    graph = processor.apply(defs, resolver).get_asset_graph()

    rendered_tags = {"a": "theval", "b": "prefix_theval"}
    assert graph.get(AssetKey("a")).tags == {}
    assert graph.get(AssetKey("b")).tags == rendered_tags
    assert graph.get(AssetKey("c")).tags == rendered_tags


@pytest.mark.parametrize(
"python,expected",
[
Expand Down

0 comments on commit 7eb1d13

Please sign in to comment.