Skip to content

Commit

Permalink
[components] Resolver -> Renderer
Browse files Browse the repository at this point in the history
  • Loading branch information
OwenKephart committed Dec 31, 2024
1 parent 174254c commit 59cb5ae
Show file tree
Hide file tree
Showing 11 changed files with 188 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from pydantic import TypeAdapter
from typing_extensions import Self

from dagster_components.core.component_rendering import TemplatedValueResolver
from dagster_components.core.component_rendering import TemplatedValueRenderer


class ComponentDeclNode: ...
Expand Down Expand Up @@ -212,7 +212,7 @@ class ComponentLoadContext:
resources: Mapping[str, object]
registry: ComponentTypeRegistry
decl_node: Optional[ComponentDeclNode]
templated_value_resolver: TemplatedValueResolver
templated_value_renderer: TemplatedValueRenderer

@staticmethod
def for_test(
Expand All @@ -225,7 +225,7 @@ def for_test(
resources=resources or {},
registry=registry or ComponentTypeRegistry.empty(),
decl_node=decl_node,
templated_value_resolver=TemplatedValueResolver.default(),
templated_value_renderer=TemplatedValueRenderer.default(),
)

@property
Expand All @@ -240,7 +240,7 @@ def path(self) -> Path:
def with_rendering_scope(self, rendering_scope: Mapping[str, Any]) -> "ComponentLoadContext":
return dataclasses.replace(
self,
templated_value_resolver=self.templated_value_resolver.with_context(**rendering_scope),
templated_value_renderer=self.templated_value_renderer.with_context(**rendering_scope),
)

def for_decl_node(self, decl_node: ComponentDeclNode) -> "ComponentLoadContext":
Expand All @@ -255,7 +255,7 @@ def _raw_params(self) -> Optional[Mapping[str, Any]]:

def load_params(self, params_schema: Type[T]) -> T:
with pushd(str(self.path)):
preprocessed_params = self.templated_value_resolver.render_params(
preprocessed_params = self.templated_value_renderer.render_params(
self._raw_params(), params_schema
)
return TypeAdapter(params_schema).validate_python(preprocessed_params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
Component,
ComponentLoadContext,
ComponentTypeRegistry,
TemplatedValueResolver,
TemplatedValueRenderer,
get_component_type_name,
is_registered_component_type,
)
Expand Down Expand Up @@ -103,7 +103,7 @@ def build_defs_from_component_path(
resources=resources,
registry=registry,
decl_node=decl_node,
templated_value_resolver=TemplatedValueResolver.default(),
templated_value_renderer=TemplatedValueRenderer.default(),
)
components = load_components_from_context(context)
return defs_from_components(resources=resources, context=context, components=components)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,18 @@ class RenderedModel(BaseModel):
model_config = ConfigDict(json_schema_extra={JSON_SCHEMA_EXTRA_KEY: True})

def _render_property(
self, key: str, raw_value: Any, value_resolver: "TemplatedValueResolver"
self, key: str, raw_value: Any, value_renderer: "TemplatedValueRenderer"
) -> Any:
return value_resolver.render_obj(raw_value)
return value_renderer.render_obj(raw_value)

def render_properties(self, value_resolver: "TemplatedValueResolver") -> Mapping[str, Any]:
def render_properties(self, value_renderer: "TemplatedValueRenderer") -> Mapping[str, Any]:
"""Returns a dictionary of rendered properties for this class."""
raw_properties = self.model_dump(exclude_unset=True)

# validate that the rendered properties match the output type
rendered_properties = {}
for k, v in raw_properties.items():
rendered = self._render_property(k, v, value_resolver)
rendered = self._render_property(k, v, value_renderer)
annotation = self.__annotations__[k]
expected_type = _get_expected_type(annotation)
if expected_type is not None:
Expand All @@ -106,17 +106,17 @@ def render_properties(self, value_resolver: "TemplatedValueResolver") -> Mapping


@record
class TemplatedValueResolver:
class TemplatedValueRenderer:
context: Mapping[str, Any]

@staticmethod
def default() -> "TemplatedValueResolver":
return TemplatedValueResolver(
def default() -> "TemplatedValueRenderer":
return TemplatedValueRenderer(
context={"env": _env, "automation_condition": automation_condition_scope()}
)

def with_context(self, **additional_context) -> "TemplatedValueResolver":
return TemplatedValueResolver(context={**self.context, **additional_context})
def with_context(self, **additional_context) -> "TemplatedValueRenderer":
return TemplatedValueRenderer(context={**self.context, **additional_context})

def _render_value(self, val: Any) -> Any:
"""Renders a single value, if it is a templated string."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from dagster_components.core.component_rendering import (
RenderedModel,
RenderingMetadata,
TemplatedValueResolver,
TemplatedValueRenderer,
)


Expand All @@ -42,8 +42,8 @@ class AssetAttributesModel(RenderedModel):
Optional[str], RenderingMetadata(output_type=Optional[AutomationCondition])
] = None

def _render_property(self, key, raw_value, value_resolver):
rendered = super()._render_property(key, raw_value, value_resolver)
def _render_property(self, key, raw_value, value_renderer):
rendered = super()._render_property(key, raw_value, value_renderer)
if key == "key":
# coerce the string asset key into an AssetKey object
return AssetKey.from_user_string(rendered) if rendered else None
Expand All @@ -62,24 +62,24 @@ def _apply_to_spec(self, spec: AssetSpec, attributes: Mapping[str, Any]) -> Asse
def apply_to_spec(
self,
spec: AssetSpec,
value_resolver: TemplatedValueResolver,
value_renderer: TemplatedValueRenderer,
target_keys: AbstractSet[AssetKey],
) -> AssetSpec:
if spec.key not in target_keys:
return spec

# add the original spec to the context and resolve values
return self._apply_to_spec(
spec, self.attributes.render_properties(value_resolver.with_context(asset=spec))
spec, self.attributes.render_properties(value_renderer.with_context(asset=spec))
)

def apply(self, defs: Definitions, value_resolver: TemplatedValueResolver) -> Definitions:
def apply(self, defs: Definitions, value_renderer: TemplatedValueRenderer) -> Definitions:
target_selection = AssetSelection.from_string(self.target, include_sources=True)
target_keys = target_selection.resolve(defs.get_asset_graph())

mappable = [d for d in defs.assets or [] if isinstance(d, (AssetsDefinition, AssetSpec))]
mapped_assets = map_asset_specs(
lambda spec: self.apply_to_spec(spec, value_resolver, target_keys),
lambda spec: self.apply_to_spec(spec, value_renderer, target_keys),
mappable,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dagster_components import Component, ComponentLoadContext
from dagster_components.core.component import (
ComponentGenerateRequest,
TemplatedValueResolver,
TemplatedValueRenderer,
component_type,
)
from dagster_components.core.dsl_schema import (
Expand Down Expand Up @@ -43,16 +43,16 @@ def __init__(
self,
*,
params: Optional[AssetAttributesModel],
value_resolver: TemplatedValueResolver,
value_renderer: TemplatedValueRenderer,
):
self.params = params or AssetAttributesModel()
self.value_resolver = value_resolver
self.value_renderer = value_renderer

def _get_rendered_attribute(
self, attribute: str, dbt_resource_props: Mapping[str, Any], default_method
) -> Any:
resolver = self.value_resolver.with_context(node=dbt_resource_props)
rendered_attribute = self.params.render_properties(resolver).get(attribute)
renderer = self.value_renderer.with_context(node=dbt_resource_props)
rendered_attribute = self.params.render_properties(renderer).get(attribute)
return (
rendered_attribute
if rendered_attribute is not None
Expand Down Expand Up @@ -102,7 +102,7 @@ def load(cls, context: ComponentLoadContext) -> Self:
op_spec=loaded_params.op,
dbt_translator=DbtProjectComponentTranslator(
params=loaded_params.translator,
value_resolver=context.templated_value_resolver,
value_renderer=context.templated_value_renderer,
),
asset_processors=loaded_params.asset_attributes or [],
)
Expand All @@ -123,7 +123,7 @@ def _fn(context: AssetExecutionContext):

defs = Definitions(assets=[_fn])
for transform in self.asset_processors:
defs = transform.apply(defs, context.templated_value_resolver)
defs = transform.apply(defs, context.templated_value_renderer)
return defs

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def load(cls, context: ComponentLoadContext) -> "PipesSubprocessScriptCollection
if not script_path.exists():
raise FileNotFoundError(f"Script {script_path} does not exist")
path_specs[script_path] = [
AssetSpec(**asset.render_properties(context.templated_value_resolver))
AssetSpec(**asset.render_properties(context.templated_value_renderer))
for asset in script.assets
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _fn(context: AssetExecutionContext, sling: SlingResource):

defs = Definitions(assets=[_fn], resources={"sling": self.resource})
for transform in self.asset_processors:
defs = transform.apply(defs, context.templated_value_resolver)
defs = transform.apply(defs, context.templated_value_renderer)
return defs

@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from pathlib import Path
from typing import Any, Iterator, Mapping, Optional, Sequence, Union

from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
from dagster._core.definitions.definitions_class import Definitions
from dagster._core.definitions.events import AssetMaterialization
from dagster._core.definitions.result import MaterializeResult
from dagster_embedded_elt.sling import DagsterSlingTranslator, SlingResource, sling_assets
from dagster_embedded_elt.sling.resources import AssetExecutionContext
from pydantic import BaseModel
from typing_extensions import Self

from dagster_components import Component, ComponentLoadContext
from dagster_components.core.component import (
ComponentGenerateRequest,
TemplatedValueRenderer,
component_type,
)
from dagster_components.core.dsl_schema import (
AssetAttributes,
AssetAttributesModel,
AssetSpecProcessor,
OpSpecBaseModel,
)
from dagster_components.generate import generate_component_yaml


class SlingReplicationParams(BaseModel):
path: str
op: Optional[OpSpecBaseModel] = None
translator: Optional[AssetAttributesModel]


class SlingReplicationCollectionParams(BaseModel):
sling: Optional[SlingResource] = None
replications: Sequence[SlingReplicationParams]
asset_attributes: Optional[AssetAttributes] = None


class SlingReplicationTranslator(DagsterSlingTranslator):
def __init__(
self,
*,
params: Optional[AssetAttributesModel],
value_renderer: TemplatedValueRenderer,
):
self.params = params or AssetAttributesModel()
self.value_renderer = value_renderer

def _get_rendered_attribute(
self, attribute: str, stream_definition: Mapping[str, Any], default_method
) -> Any:
renderer = self.value_renderer.with_context(stream_definition=stream_definition)
rendered_attribute = self.params.render_properties(renderer).get(attribute)
return (
rendered_attribute
if rendered_attribute is not None
else default_method(stream_definition)
)

def get_asset_key(self, stream_definition: Mapping[str, Any]) -> AssetKey:
return self._get_rendered_attribute("key", stream_definition, super().get_asset_key)

def get_group_name(self, stream_definition: Mapping[str, Any]) -> Optional[str]:
return self._get_rendered_attribute("group_name", stream_definition, super().get_group_name)

def get_tags(self, stream_definition: Mapping[str, Any]) -> Mapping[str, str]:
return self._get_rendered_attribute("tags", stream_definition, super().get_tags)

def get_metadata(self, stream_definition: Mapping[str, Any]) -> Mapping[str, Any]:
return self._get_rendered_attribute("metadata", stream_definition, super().get_metadata)

def get_auto_materialize_policy(
self, stream_definition: Mapping[str, Any]
) -> Optional[AutoMaterializePolicy]:
return self._get_rendered_attribute(
"auto_materialize_policy", stream_definition, super().get_auto_materialize_policy
)


@component_type(name="sling_replication_collection")
class SlingReplicationCollectionComponent(Component):
params_schema = SlingReplicationCollectionParams

def __init__(
self,
dirpath: Path,
resource: SlingResource,
sling_replications: Sequence[SlingReplicationParams],
asset_attributes: Sequence[AssetSpecProcessor],
):
self.dirpath = dirpath
self.resource = resource
self.sling_replications = sling_replications
self.asset_attributes = asset_attributes

@classmethod
def load(cls, context: ComponentLoadContext) -> Self:
loaded_params = context.load_params(cls.params_schema)
return cls(
dirpath=context.path,
resource=loaded_params.sling or SlingResource(),
sling_replications=loaded_params.replications,
asset_attributes=loaded_params.asset_attributes or [],
)

def build_replication_asset(
self, context: ComponentLoadContext, replication: SlingReplicationParams
) -> AssetsDefinition:
@sling_assets(
name=replication.op.name if replication.op else Path(replication.path).stem,
op_tags=replication.op.tags if replication.op else {},
replication_config=self.dirpath / replication.path,
dagster_sling_translator=SlingReplicationTranslator(
params=replication.translator,
value_renderer=context.templated_value_renderer,
),
)
def _asset(context: AssetExecutionContext):
yield from self.execute(context=context, sling=self.resource)

return _asset

def execute(
self, context: AssetExecutionContext, sling: SlingResource
) -> Iterator[Union[AssetMaterialization, MaterializeResult]]:
yield from sling.replicate(context=context)

def build_defs(self, context: ComponentLoadContext) -> Definitions:
defs = Definitions(
assets=[
self.build_replication_asset(context, replication)
for replication in self.sling_replications
],
)
for transform in self.asset_attributes:
defs = transform.apply(defs, context.templated_value_renderer)
return defs

@classmethod
def generate_files(cls, request: ComponentGenerateRequest, params: Any) -> None:
generate_component_yaml(request, params)
Loading

0 comments on commit 59cb5ae

Please sign in to comment.