Skip to content
This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Prototype for typed_tag. Solve #277

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 86 additions & 3 deletions hamilton/function_modifiers/metadata.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
from typing import Any, Callable, Dict
import sys
import typing
from typing import Any, Callable, Dict, Type

if sys.version_info < (3, 8):
from typing_extensions import TypedDict
else:
from typing import TypedDict

import typing_inspect

from hamilton import node
from hamilton.function_modifiers import base
Expand Down Expand Up @@ -45,13 +54,17 @@ class tag(base.NodeDecorator):
"module",
] # Anything that starts with any of these is banned, the framework reserves the right to manage it

def __init__(self, **tags: str):
def __init__(self, *, __validate_tag_types: bool = True, **tags: str):
"""Constructor for adding tag annotations to a function.

:param tags: the keys are always going to be strings, so the type annotation here means the values are strings.
Implicitly this is `Dict[str, str]` but the PEP guideline is to only annotate it with `str`.
:param __validate_tag_types: If true, we validate the types of the tags. This is called by the framework, and
should not be called by users. If you want to have more than just str valued tags, consider using typed tags
as specified below.
"""
self.tags = tags
self.__validate_tag_types = __validate_tag_types

def decorate_node(self, node_: node.Node) -> node.Node:
"""Decorates the nodes produced by this with the specified tags
Expand Down Expand Up @@ -105,8 +118,11 @@ def validate(self, fn: Callable):
"""
bad_tags = set()
for key, value in self.tags.items():
if (not tag._key_allowed(key)) or (not tag._value_allowed(value)):
if not tag._key_allowed(key):
bad_tags.add((key, value))
if not tag._value_allowed(value) and not self.__validate_tag_types:
bad_tags.add((key, value))

if bad_tags:
bad_tags_formatted = ",".join([f"{key}={value}" for key, value in bad_tags])
raise base.InvalidDecoratorException(
Expand All @@ -132,3 +148,70 @@ def decorate_node(self, node_: node.Node) -> node.Node:
new_tags = node_.tags.copy()
new_tags.update(self.tag_mapping.get(node_.name, {}))
return tag(**new_tags).decorate_node(node_)


# class TypedTagSet(TypedDict):
# """A typed tag set is a dictionary of tags that are typed. We do additional validation on this
# to ensure that the right types are created and that that ri"""


def _type_allowed(type: Type[Type], allow_lists: bool = True) -> bool:
"""Validates that a type is allowed. We only allow primitive types and lists of primitive types"""
if type in [int, float, str, bool]:
return True
if allow_lists:
if typing_inspect.is_generic_type(type):
if typing_inspect.get_origin(type) == list:
return _type_allowed(typing_inspect.get_args(type)[0], allow_lists=False)
return False


def _validate_spec(typed_dict_class: Type[TypedDict]):
invalid_types = []
for key, value in typing.get_type_hints(typed_dict_class).items():
if not _type_allowed(value, allow_lists=True):
invalid_types.append((key, value))
if invalid_types:
invalid_types_formatted = ",".join([f"{key}={value}" for key, value in invalid_types])
raise base.InvalidDecoratorException(
f"The following key/value pairs are invalid as types: {invalid_types_formatted} "
"Types can be any primitive type or a list of a primitive type."
)


def _type_matches(value: Any, type_: Type[Type]):
if type_ in [int, float, str, bool]:
return isinstance(value, type_)
if typing_inspect.is_generic_type(type_):
if typing_inspect.get_origin(type_) == list:
return isinstance(value, list) and all(
_type_matches(item, typing_inspect.get_args(type_)[0]) for item in value
)
return False


def _validate_values(typed_dict: dict, typed_dict_class: Type[TypedDict]):
invalid_pairs = []
for key, value in typed_dict.items():
if not _type_matches(value, typing.get_type_hints(typed_dict_class)[key]):
invalid_pairs.append((key, value))
if invalid_pairs:
invalid_pairs_formatted = ",".join([f"{key}={value}" for key, value in invalid_pairs])
raise base.InvalidDecoratorException(
f"The following key/value pairs are invalid as values: {invalid_pairs_formatted} "
"Values must match the specified type."
)


def validate_typed_dict(data: dict, typed_dict_class: TypedDict):
_validate_spec(typed_dict_class)
_validate_values(data, typed_dict_class)


class typed_tags:
def __init__(self, typed_tag_class: TypedDict):
self.tag_set_type = typed_tag_class

def __call__(self, **kwargs: Any):
validate_typed_dict(dict(**kwargs), self.tag_set_type)
return tag(**kwargs, __validate_tag_types=False) # types are already validated
89 changes: 89 additions & 0 deletions tests/function_modifiers/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import sys
from typing import Dict, List

if sys.version_info < (3, 8):
from typing_extensions import TypedDict
else:
from typing import TypedDict

import pandas as pd
import pytest

from hamilton import function_modifiers, node
from hamilton.function_modifiers.base import InvalidDecoratorException
from hamilton.function_modifiers.metadata import typed_tags


def test_tags():
Expand Down Expand Up @@ -108,3 +118,82 @@ def dummy_tagged_function() -> pd.DataFrame:
assert node_map["b"].tags["tag_b_gets"] == "tag_value_b_gets"
assert node_map["a"].tags["tag_key_everyone_gets"] == "tag_value_just_a_gets"
assert node_map["b"].tags["tag_key_everyone_gets"] == "tag_value_everyone_gets"


def test_typed_tags_success():
"""Tests the typed_tags decorator to ensure that it works in the basic case"""

class FooType(TypedDict):
foo: str
bar: int

foo = typed_tags(FooType)

def dummy_tagged_function() -> int:
"""dummy doc"""
return 1

node_ = foo(foo="foo", bar=1).decorate_node(node.Node.from_fn(dummy_tagged_function))
assert node_.tags["foo"] == "foo"
assert node_.tags["bar"] == 1


def test_typed_tags_wrong_type_failure():
"""Tests the typed_tags decorator to ensure that it breaks when the wrong types are passed"""

class FooType(TypedDict):
foo: str
bar: int

foo = typed_tags(FooType)

with pytest.raises(InvalidDecoratorException):

@foo(foo=1, bar="bar")
def dummy_tagged_function() -> int:
"""dummy doc"""
return 1


def test_typed_tags_illegal_types_failure():
"""Tests the typed_tags decorator to ensure that it breaks when illegal types are declared"""

class FooType(TypedDict):
foo: Dict[str, dict]
bar: List[List[int]]

foo = typed_tags(FooType)

with pytest.raises(InvalidDecoratorException):

@foo(foo=1, bar="bar")
def dummy_tagged_function() -> int:
"""dummy doc"""
return 1


def test_layered_tags_success():
"""Tests to ensure that layered tags are applied appropriately"""

class FooType(TypedDict):
foo: str
bar: int

class BarType(TypedDict):
bat: int
baz: List[int]

foo = typed_tags(FooType)
bar = typed_tags(BarType)

@foo(foo="foo", bar=1)
@bar(bat=2, baz=[1, 2, 3])
def dummy_tagged_function() -> int:
"""dummy doc"""
return 1

(node_,) = function_modifiers.base.resolve_nodes(dummy_tagged_function, {})
assert node_.tags["foo"] == "foo"
assert node_.tags["bar"] == 1
assert node_.tags["bat"] == 2
assert node_.tags["baz"] == [1, 2, 3]