Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a method for connecting providers to different inputs #192

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/sciline/data_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,29 @@ def visualize_data_graph(self, **kwargs: Any) -> graphviz.Digraph: # type: igno
dot.edge(str(edge[0]), str(edge[1]), label=label)
return dot

def override_input(self, original: Key, replacement: Key) -> DataGraph:
dg = self.copy()
dg._override_input_in_place(original, replacement)
return dg

def _override_input_in_place(self, original: Key, replacement: Key) -> None:
formal_args = getattr(replacement, '__args__', ())
assert getattr(original, '__args__', ()) == formal_args
for formal in formal_args:
# If `formal` is a type var, plug in all possible concrete types:
for actual in getattr(formal, '__constraints__', ()):
self._override_input_in_place(original[actual], replacement[actual])

_move_out_edges(self, original, replacement)


def _move_out_edges(graph: DataGraph, original: Key, replacement: Key) -> None:
"""Change all out edges from ``original`` to start from ``replacement``."""
g = graph.underlying_graph
to_replace = [edge for edge in g.edges(data=True) if edge[0] == original]
g.remove_edges_from(to_replace)
g.add_edges_from((replacement, *edge[1:]) for edge in to_replace)


_no_value = object()

Expand Down
262 changes: 262 additions & 0 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,3 +1440,265 @@ def make_str(x: int) -> str:
getattr(pl, get_method)(int)
for key in pl.output_keys():
assert key_name(key) in info.value.args[0]


def test_override_input_disconnects_from_existing_parent() -> None:
def foo(x: int) -> str:
return str(x)

pipeline = sl.Pipeline([foo], params={int: 3})
new = pipeline.override_input(int, float)

with pytest.raises(sl.UnsatisfiedRequirement, match="No provider.*float"):
new.compute(str)


def test_override_input_disconnects_from_existing_parent_generic() -> None:
T1 = NewType('T1', int)
T2 = NewType('T2', int)
T = TypeVar('T', T1, T2)

class A(sl.Scope[T, int], int): ...

class B(sl.Scope[T, int], int): ...

class C(sl.Scope[T, int], int): ...

def foo(a: A[T]) -> B[T]:
return B[T](a + 2)

pipeline = sl.Pipeline([foo], params={A[T1]: A[T1](3), A[T2]: A[T2](4)})
new = pipeline.override_input(A[T], C[T])

with pytest.raises(sl.UnsatisfiedRequirement, match="No provider.*C"):
new.compute(B[T1])


def test_override_input_connects_to_new_parent() -> None:
def foo(x: int) -> str:
return str(x)

original = sl.Pipeline([foo], params={float: 6.5})
new = original.override_input(int, float)

assert new.compute(str) == "6.5"


def test_override_input_connects_to_new_parent_generic() -> None:
T1 = NewType('T1', int)
T2 = NewType('T2', int)
T = TypeVar('T', T1, T2)

class A(sl.Scope[T, int], int): ...

class B(sl.Scope[T, int], int): ...

class C(sl.Scope[T, int], int): ...

def foo(a: A[T]) -> B[T]:
return B[T](a + 2)

pipeline = sl.Pipeline([foo], params={C[T1]: C[T1](3), C[T2]: C[T2](4)})
new = pipeline.override_input(A[T], C[T])

assert new.compute(B[T1]) == B[T1](5)
assert new.compute(B[T2]) == B[T2](6)


def test_override_input_connects_to_new_parent_from_existing() -> None:
def foo(x: int) -> str:
return str(x)

original = sl.Pipeline([foo], params={int: 7, float: -0.3})
new = original.override_input(int, float)

assert new.compute(str) == "-0.3"


def test_override_input_connects_to_new_parent_from_existing_generic() -> None:
T1 = NewType('T1', int)
T2 = NewType('T2', int)
T = TypeVar('T', T1, T2)

class A(sl.Scope[T, int], int): ...

class B(sl.Scope[T, int], int): ...

class C(sl.Scope[T, int], int): ...

def foo(a: A[T]) -> B[T]:
return B[T](a + 2)

pipeline = sl.Pipeline(
[foo],
params={A[T1]: A[T1](3), A[T2]: A[T2](4), C[T1]: C[T1](-3), C[T2]: C[T2](-4)},
)
new = pipeline.override_input(A[T], C[T])

assert new.compute(B[T1]) == B[T1](-1)
assert new.compute(B[T2]) == B[T2](-2)


def test_override_input_multiple_consumers() -> None:
A = NewType('A', int)
B = NewType('B', int)
C = NewType('C', int)
Replacement = NewType('Replacement', int)

def plus_1(a: A) -> B:
return B(a + 1)

def times_5(b: B) -> C:
return C(b * 5)

def replicate(val: C, n: A) -> list[C]:
return [val] * n

original = sl.Pipeline([plus_1, times_5, replicate], params={A: 2})
new = original.override_input(A, Replacement)
new[Replacement] = 3

assert original.compute(list[C]) == [C(15), C(15)]
assert new.compute(list[C]) == [C(20), C(20), C(20)]


def test_override_input_multiple_consumers_generic() -> None:
T1 = NewType('T1', int)
T2 = NewType('T2', int)
T = TypeVar('T', T1, T2)

class A(sl.Scope[T, int], int): ...

class B(sl.Scope[T, int], int): ...

class C(sl.Scope[T, int], int): ...

class Replacement(sl.Scope[T, int], int): ...

def plus_1(a: A[T]) -> B[T]:
return B[T](a + 1)

def times_5(b: B[T]) -> C[T]:
return C[T](b * 5)

def replicate(val: C[T], n: A[T]) -> list[C[T]]:
return [val] * n

original = sl.Pipeline([plus_1, times_5, replicate], params={A[T1]: 2, A[T2]: 1})
new = original.override_input(A[T], Replacement[T])
new[Replacement[T1]] = 3
new[Replacement[T2]] = 4

assert original.compute(list[C[T1]]) == [C[T1](15)] * 2
assert original.compute(list[C[T2]]) == [C[T2](10)] * 1
assert new.compute(list[C[T1]]) == [C[T1](20)] * 3
assert new.compute(list[C[T2]]) == [C[T2](25)] * 4


def test_override_input_multiple_consumers_generic_concrete_value() -> None:
T1 = NewType('T1', int)
T2 = NewType('T2', int)
T = TypeVar('T', T1, T2)

class A(sl.Scope[T, int], int): ...

class B(sl.Scope[T, int], int): ...

class C(sl.Scope[T, int], int): ...

class Replacement(sl.Scope[T, int], int): ...

def plus_1(a: A[T]) -> B[T]:
return B[T](a + 1)

def times_5(b: B[T]) -> C[T]:
return C[T](b * 5)

def replicate(val: C[T], n: A[T]) -> list[C[T]]:
return [val] * n

original = sl.Pipeline([plus_1, times_5, replicate], params={A[T1]: 2, A[T2]: 1})
new = original.override_input(A[T1], Replacement[T1])
new[Replacement[T1]] = 3
new[Replacement[T2]] = 4

assert original.compute(list[C[T1]]) == [C[T1](15)] * 2
assert original.compute(list[C[T2]]) == [C[T2](10)] * 1
assert new.compute(list[C[T1]]) == [C[T1](20)] * 3 # replaced
assert new.compute(list[C[T2]]) == [C[T2](10)] * 1 # unchanged


def test_override_input_multiple_consumers_generic_concrete_value_cross_bound() -> None:
T1 = NewType('T1', int)
T2 = NewType('T2', int)
T = TypeVar('T', T1, T2)

class A(sl.Scope[T, int], int): ...

class B(sl.Scope[T, int], int): ...

class C(sl.Scope[T, int], int): ...

class Replacement(sl.Scope[T, int], int): ...

def plus_1(a: A[T]) -> B[T]:
return B[T](a + 1)

def times_5(b: B[T]) -> C[T]:
return C[T](b * 5)

def replicate(val: C[T], n: A[T]) -> list[C[T]]:
return [val] * n

original = sl.Pipeline([plus_1, times_5, replicate], params={A[T1]: 2, A[T2]: 1})
new = original.override_input(A[T2], Replacement[T1]) # T2 -> T1
new[Replacement[T1]] = 3
new[Replacement[T2]] = 4

assert original.compute(list[C[T1]]) == [C[T1](15)] * 2
assert original.compute(list[C[T2]]) == [C[T2](10)] * 1
assert new.compute(list[C[T1]]) == [C[T1](15)] * 2 # unchanged
assert new.compute(list[C[T2]]) == [C[T2](20)] * 3 # got Replacement[T1]


def test_override_input_does_not_remove_nodes() -> None:
def foo(x: int) -> str:
return str(x)

original = sl.Pipeline([foo], params={int: 7, float: -0.3})
new = original.override_input(int, float)
assert new.compute(int) == 7 # int is still there


def test_override_input_does_not_change_outputs() -> None:
def foo(x: int) -> str:
return str(x)

original = sl.Pipeline([foo], params={int: 7})
new = original.override_input(str, float)
assert new.compute(str) == "7" # output of foo is unchanged


def test_override_input_does_nothing_if_input_does_not_exist() -> None:
def foo(x: int) -> str:
return str(x)

original = sl.Pipeline([foo], params={int: 7})
new = original.override_input(float, int) # there is no float in the graph
assert new.compute(str) == "7"


def test_override_input_preserves_uses_of_override_type() -> None:
def foo(x: float) -> str:
return str(x)

original = sl.Pipeline([foo], params={int: 7, float: 4.4})
new = original.override_input(int, float) # keeps float arg of foo
assert new.compute(str) == "4.4"


# TODO generic -> regular is ambiguous
# TODO regular -> generic works, duplicates regular
# TODO multiple type vars
# TODO nested type vars
# TODO what if later provider added?
Loading