From bb4a795ac79f9538e61de218ba17147cfcfb2d35 Mon Sep 17 00:00:00 2001 From: Stefan Ulbrich Date: Wed, 9 Mar 2022 00:20:10 +0100 Subject: [PATCH] Fix pattern matching and add unit tests --- pyproject.toml | 2 +- src/design_by_contract.py | 20 +++++++------ tests/test_dbc.py | 59 ++++++++++++++++++++++++++++++++++----- 3 files changed, 65 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 445a837..86f849d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "design-by-contract" -version = "0.2" +version = "0.2.1" description = "Handy decorator to define contracts with dependency injection in Python 3.10 and above" authors = ["Stefan Ulbrich"] license = "MIT" diff --git a/src/design_by_contract.py b/src/design_by_contract.py index 0a6b922..3b673b9 100644 --- a/src/design_by_contract.py +++ b/src/design_by_contract.py @@ -23,6 +23,7 @@ class UnresolvedSymbol: Overrides the equality operator to behave like an assignment. """ + name: str value: Optional[Any] = None @@ -32,11 +33,15 @@ def __eq__(self, other: Any) -> "UnresolvedSymbol": if self.value is None: raise ContractViolationError(f"Symbols `{self.name}` and `{other.name}` undefined") other.value = self.value - case UnresolvedSymbol(value) if value != self.value: + case UnresolvedSymbol(_, value) if self.value is None: + self.value = value + case UnresolvedSymbol(name, value) if value != self.value: raise ContractViolationError( - f"Symbols `{self.name}` and `{other.name}` do not match: `{self.value}` != `{other.value}`" + f"Symbols `{self.name}` and `{name}` do not match: `{self.value}` != `{value}`" ) - case value if value != value: + case self.value: + return True + case value if self.value is not None: raise ContractViolationError( f"Symbols `{self.name}` and `{other}` do not match: `{self.value}` != `{other}`" ) @@ -45,13 +50,12 @@ def __eq__(self, other: Any) -> "UnresolvedSymbol": return self def __bool__(self) -> bool: - return (self.value is not None) + return self.value is not None P, R = ParamSpec("P"), TypeVar("R") -#: Test @decorator def contract( func: Callable[P, R], @@ -73,7 +77,6 @@ def contract( If False, the contracts are not evaluated, by default True """ - if not evaluate: return func(*args, **kw) @@ -117,11 +120,11 @@ def evaluate_annotations(annotations: dict[str, Any]) -> None: logger.debug("contract for `%s`, unresolved: `%s`, %s", arg_name, unresolved, symbols) - if not meta(*[(symbols | injectables)[i] for i in meta_args]): + if not meta(*[(symbols | injectables)[i] for i in meta_args]): raise ContractViolationError(f"Contract violated for argument: `{arg_name}`") if any([i.value is None for i in symbols.values()]): - raise ContractLogicError(f"Not all symbols were resolved `%s`", symbols) + raise ContractLogicError(f"Not all symbols were resolved `{symbols}`", ) injectables |= {k: v.value for k, v in symbols.items()} @@ -146,6 +149,7 @@ def evaluate_annotations(annotations: dict[str, Any]) -> None: if __name__ == "__main__": + # pylint: disable=invalid-name, missing-function-docstring # Example import numpy as np diff --git a/tests/test_dbc.py b/tests/test_dbc.py index 6852cac..1d88f98 100644 --- a/tests/test_dbc.py +++ b/tests/test_dbc.py @@ -4,9 +4,9 @@ import numpy as np import pandas as pd import pytest -from design_by_contract import contract, ContractViolationError, ContractLogicError - +from design_by_contract import contract, ContractViolationError, ContractLogicError, UnresolvedSymbol +# pylint: skip-file class TestNumpy: def test_matmult_correct(self): @contract @@ -115,15 +115,16 @@ def spam( assert str(exc_info.value) == ("Contract violated for argument: `a`") def test_vstack(self): - @contract def spam( a: Annotated[np.ndarray, lambda x, m, o: (m, o) == x.shape], b: Annotated[np.ndarray, lambda x, n, o: (n, o) == x.shape], - ) -> Annotated[np.ndarray, lambda x, m,n,o: x.shape == (m+n, o)]: - print(np.vstack((a,b)).shape) - return np.vstack((a,b)) - spam(np.zeros((3, 2)), np.zeros(( 4, 2))) + ) -> Annotated[np.ndarray, lambda x, m, n, o: x.shape == (m + n, o)]: + print(np.vstack((a, b)).shape) + return np.vstack((a, b)) + + spam(np.zeros((3, 2)), np.zeros((4, 2))) + class TestGeneral: def test_docstring(self): @@ -141,6 +142,50 @@ def spam(a: np.ndarray, b: Annotated[np.ndarray, lambda b, m: b.shape == (m, 3)] assert "(a: numpy.ndarray, b: typing.Annotated[numpy.ndarray," in str(signature(spam)) + def test_reserved(self): + @contract(reserved="y") + def spam( + a: Annotated[np.ndarray, lambda y, m, n: (m, n) == y.shape], + b: Annotated[np.ndarray, lambda y, n, o: (n, o) == y.shape], + ) -> Annotated[np.ndarray, lambda y, m, o: y.shape == (m, o)]: + + return a @ b + + def test_match(self): + a, b = UnresolvedSymbol("a"), UnresolvedSymbol("b") + a == 2 + b == a + assert a.value == b.value + + def test_match_fail(self): + a, b = UnresolvedSymbol("a"), UnresolvedSymbol("b") + a == 2 + b == 1 + with pytest.raises(ContractViolationError) as exc_info: + a == b + + def test_match_symmetry(self): + a, b = UnresolvedSymbol("a"), UnresolvedSymbol("b") + a == 2 + assert a.value == 2 + + b = UnresolvedSymbol("a") + 2 == b + assert b.value == 2 + + def test_match_fail2(self): + a = UnresolvedSymbol("a") + a == 2 + + with pytest.raises(ContractViolationError) as exc_info: + a == 3 + + with pytest.raises(ContractViolationError) as exc_info: + 3 == a + + a == 2 + 2 == a + class TestPandas: def test_pandas_correct(self):