From f046fd08142023dfd83e3ff09897a0752ce88dea Mon Sep 17 00:00:00 2001 From: Jan-Lukas Wynen Date: Wed, 22 Jan 2025 15:24:14 +0100 Subject: [PATCH 1/2] Start DataGraph.override_input --- src/sciline/data_graph.py | 23 ++++ tests/pipeline_test.py | 253 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 276 insertions(+) diff --git a/src/sciline/data_graph.py b/src/sciline/data_graph.py index 5eab25c2..e7530c2a 100644 --- a/src/sciline/data_graph.py +++ b/src/sciline/data_graph.py @@ -177,6 +177,29 @@ def visualize_data_graph(self, **kwargs: Any) -> graphviz.Digraph: # type: igno dot.edge(str(edge[0]), str(edge[1]), label=label) return dot + def override_input(self, original: Key, replacement: Key) -> DataGraph: + dg = self.copy() + dg._override_input_in_place(original, replacement) + return dg + + def _override_input_in_place(self, original: Key, replacement: Key) -> None: + formal_args = getattr(replacement, '__args__', ()) + assert getattr(original, '__args__', ()) == formal_args + for formal in formal_args: + # If `formal` is a type var, plug in all possible concrete types: + for actual in getattr(formal, '__constraints__', ()): + self._override_input_in_place(original[actual], replacement[actual]) + + _move_out_edges(self, original, replacement) + + +def _move_out_edges(graph: DataGraph, original: Key, replacement: Key) -> None: + """Change all out edges from ``original`` to start from ``replacement``.""" + g = graph.underlying_graph + to_replace = [edge for edge in g.edges(data=True) if edge[0] == original] + g.remove_edges_from(to_replace) + g.add_edges_from((replacement, *edge[1:]) for edge in to_replace) + _no_value = object() diff --git a/tests/pipeline_test.py b/tests/pipeline_test.py index c3a63183..705e89fb 100644 --- a/tests/pipeline_test.py +++ b/tests/pipeline_test.py @@ -1440,3 +1440,256 @@ def make_str(x: int) -> str: getattr(pl, get_method)(int) for key in pl.output_keys(): assert key_name(key) in info.value.args[0] + + +def test_override_input_disconnects_from_existing_parent() -> None: + def foo(x: int) -> str: + return str(x) + + pipeline = sl.Pipeline([foo], params={int: 3}) + new = pipeline.override_input(int, float) + + with pytest.raises(sl.UnsatisfiedRequirement, match="No provider.*float"): + new.compute(str) + + +def test_override_input_disconnects_from_existing_parent_generic() -> None: + T1 = NewType('T1', int) + T2 = NewType('T2', int) + T = TypeVar('T', T1, T2) + + class A(sl.Scope[T, int], int): ... + + class B(sl.Scope[T, int], int): ... + + class C(sl.Scope[T, int], int): ... + + def foo(a: A[T]) -> B[T]: + return B[T](a + 2) + + pipeline = sl.Pipeline([foo], params={A[T1]: A[T1](3), A[T2]: A[T2](4)}) + new = pipeline.override_input(A[T], C[T]) + + with pytest.raises(sl.UnsatisfiedRequirement, match="No provider.*C"): + new.compute(B[T1]) + + +def test_override_input_connects_to_new_parent() -> None: + def foo(x: int) -> str: + return str(x) + + original = sl.Pipeline([foo], params={float: 6.5}) + new = original.override_input(int, float) + + assert new.compute(str) == "6.5" + + +def test_override_input_connects_to_new_parent_generic() -> None: + T1 = NewType('T1', int) + T2 = NewType('T2', int) + T = TypeVar('T', T1, T2) + + class A(sl.Scope[T, int], int): ... + + class B(sl.Scope[T, int], int): ... + + class C(sl.Scope[T, int], int): ... + + def foo(a: A[T]) -> B[T]: + return B[T](a + 2) + + pipeline = sl.Pipeline([foo], params={C[T1]: C[T1](3), C[T2]: C[T2](4)}) + new = pipeline.override_input(A[T], C[T]) + + assert new.compute(B[T1]) == B[T1](5) + assert new.compute(B[T2]) == B[T2](6) + + +def test_override_input_connects_to_new_parent_from_existing() -> None: + def foo(x: int) -> str: + return str(x) + + original = sl.Pipeline([foo], params={int: 7, float: -0.3}) + new = original.override_input(int, float) + + assert new.compute(str) == "-0.3" + + +def test_override_input_connects_to_new_parent_from_existing_generic() -> None: + T1 = NewType('T1', int) + T2 = NewType('T2', int) + T = TypeVar('T', T1, T2) + + class A(sl.Scope[T, int], int): ... + + class B(sl.Scope[T, int], int): ... + + class C(sl.Scope[T, int], int): ... + + def foo(a: A[T]) -> B[T]: + return B[T](a + 2) + + pipeline = sl.Pipeline( + [foo], + params={A[T1]: A[T1](3), A[T2]: A[T2](4), C[T1]: C[T1](-3), C[T2]: C[T2](-4)}, + ) + new = pipeline.override_input(A[T], C[T]) + + assert new.compute(B[T1]) == B[T1](-1) + assert new.compute(B[T2]) == B[T2](-2) + + +def test_override_input_multiple_consumers() -> None: + A = NewType('A', int) + B = NewType('B', int) + C = NewType('C', int) + Replacement = NewType('Replacement', int) + + def plus_1(a: A) -> B: + return B(a + 1) + + def times_5(b: B) -> C: + return C(b * 5) + + def replicate(val: C, n: A) -> list[C]: + return [val] * n + + original = sl.Pipeline([plus_1, times_5, replicate], params={A: 2}) + new = original.override_input(A, Replacement) + new[Replacement] = 3 + + assert original.compute(list[C]) == [C(15), C(15)] + assert new.compute(list[C]) == [C(20), C(20), C(20)] + + +def test_override_input_multiple_consumers_generic() -> None: + T1 = NewType('T1', int) + T2 = NewType('T2', int) + T = TypeVar('T', T1, T2) + + class A(sl.Scope[T, int], int): ... + + class B(sl.Scope[T, int], int): ... + + class C(sl.Scope[T, int], int): ... + + class Replacement(sl.Scope[T, int], int): ... + + def plus_1(a: A[T]) -> B[T]: + return B[T](a + 1) + + def times_5(b: B[T]) -> C[T]: + return C[T](b * 5) + + def replicate(val: C[T], n: A[T]) -> list[C[T]]: + return [val] * n + + original = sl.Pipeline([plus_1, times_5, replicate], params={A[T1]: 2, A[T2]: 1}) + new = original.override_input(A[T], Replacement[T]) + new[Replacement[T1]] = 3 + new[Replacement[T2]] = 4 + + assert original.compute(list[C[T1]]) == [C[T1](15)] * 2 + assert original.compute(list[C[T2]]) == [C[T2](10)] * 1 + assert new.compute(list[C[T1]]) == [C[T1](20)] * 3 + assert new.compute(list[C[T2]]) == [C[T2](25)] * 4 + + +def test_override_input_multiple_consumers_generic_concrete_value() -> None: + T1 = NewType('T1', int) + T2 = NewType('T2', int) + T = TypeVar('T', T1, T2) + + class A(sl.Scope[T, int], int): ... + + class B(sl.Scope[T, int], int): ... + + class C(sl.Scope[T, int], int): ... + + class Replacement(sl.Scope[T, int], int): ... + + def plus_1(a: A[T]) -> B[T]: + return B[T](a + 1) + + def times_5(b: B[T]) -> C[T]: + return C[T](b * 5) + + def replicate(val: C[T], n: A[T]) -> list[C[T]]: + return [val] * n + + original = sl.Pipeline([plus_1, times_5, replicate], params={A[T1]: 2, A[T2]: 1}) + new = original.override_input(A[T1], Replacement[T1]) + new[Replacement[T1]] = 3 + new[Replacement[T2]] = 4 + + assert original.compute(list[C[T1]]) == [C[T1](15)] * 2 + assert original.compute(list[C[T2]]) == [C[T2](10)] * 1 + assert new.compute(list[C[T1]]) == [C[T1](20)] * 3 # replaced + assert new.compute(list[C[T2]]) == [C[T2](10)] * 1 # unchanged + + +def test_override_input_multiple_consumers_generic_concrete_value_cross_bound() -> None: + T1 = NewType('T1', int) + T2 = NewType('T2', int) + T = TypeVar('T', T1, T2) + + class A(sl.Scope[T, int], int): ... + + class B(sl.Scope[T, int], int): ... + + class C(sl.Scope[T, int], int): ... + + class Replacement(sl.Scope[T, int], int): ... + + def plus_1(a: A[T]) -> B[T]: + return B[T](a + 1) + + def times_5(b: B[T]) -> C[T]: + return C[T](b * 5) + + def replicate(val: C[T], n: A[T]) -> list[C[T]]: + return [val] * n + + original = sl.Pipeline([plus_1, times_5, replicate], params={A[T1]: 2, A[T2]: 1}) + new = original.override_input(A[T2], Replacement[T1]) # T2 -> T1 + new[Replacement[T1]] = 3 + new[Replacement[T2]] = 4 + + assert original.compute(list[C[T1]]) == [C[T1](15)] * 2 + assert original.compute(list[C[T2]]) == [C[T2](10)] * 1 + assert new.compute(list[C[T1]]) == [C[T1](15)] * 2 # unchanged + assert new.compute(list[C[T2]]) == [C[T2](20)] * 3 # got Replacement[T1] + + +def test_override_input_does_not_remove_nodes() -> None: + def foo(x: int) -> str: + return str(x) + + original = sl.Pipeline([foo], params={int: 7, float: -0.3}) + new = original.override_input(int, float) + assert new.compute(int) == 7 # int is still there + + +def test_override_input_does_not_change_outputs() -> None: + def foo(x: int) -> str: + return str(x) + + original = sl.Pipeline([foo], params={int: 7}) + new = original.override_input(str, float) + assert new.compute(str) == "7" # output of foo is unchanged + + +def test_override_input_does_nothing_if_input_does_not_exist() -> None: + def foo(x: int) -> str: + return str(x) + + original = sl.Pipeline([foo], params={int: 7}) + new = original.override_input(float, int) # there is no float in the graph + assert new.compute(str) == "7" + + +# TODO generic -> regular is ambiguous +# TODO regular -> generic works, duplicates regular +# TODO multiple type vars +# TODO nested type vars +# TODO what if later provider added? From 777218b2b8c0ea9232b743124693f8e4d803bbbe Mon Sep 17 00:00:00 2001 From: Jan-Lukas Wynen Date: Thu, 23 Jan 2025 09:18:37 +0100 Subject: [PATCH 2/2] Add another test --- tests/pipeline_test.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/pipeline_test.py b/tests/pipeline_test.py index 705e89fb..a45819a4 100644 --- a/tests/pipeline_test.py +++ b/tests/pipeline_test.py @@ -1688,6 +1688,15 @@ def foo(x: int) -> str: assert new.compute(str) == "7" +def test_override_input_preserves_uses_of_override_type() -> None: + def foo(x: float) -> str: + return str(x) + + original = sl.Pipeline([foo], params={int: 7, float: 4.4}) + new = original.override_input(int, float) # keeps float arg of foo + assert new.compute(str) == "4.4" + + # TODO generic -> regular is ambiguous # TODO regular -> generic works, duplicates regular # TODO multiple type vars