Skip to content

Commit

Permalink
Fix pattern matching and add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanUlbrich committed Mar 8, 2022
1 parent 094da3c commit bb4a795
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "design-by-contract"
version = "0.2"
version = "0.2.1"
description = "Handy decorator to define contracts with dependency injection in Python 3.10 and above"
authors = ["Stefan Ulbrich"]
license = "MIT"
Expand Down
20 changes: 12 additions & 8 deletions src/design_by_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class UnresolvedSymbol:
Overrides the equality operator to behave like an
assignment.
"""

name: str
value: Optional[Any] = None

Expand All @@ -32,11 +33,15 @@ def __eq__(self, other: Any) -> "UnresolvedSymbol":
if self.value is None:
raise ContractViolationError(f"Symbols `{self.name}` and `{other.name}` undefined")
other.value = self.value
case UnresolvedSymbol(value) if value != self.value:
case UnresolvedSymbol(_, value) if self.value is None:
self.value = value
case UnresolvedSymbol(name, value) if value != self.value:
raise ContractViolationError(
f"Symbols `{self.name}` and `{other.name}` do not match: `{self.value}` != `{other.value}`"
f"Symbols `{self.name}` and `{name}` do not match: `{self.value}` != `{value}`"
)
case value if value != value:
case self.value:
return True
case value if self.value is not None:
raise ContractViolationError(
f"Symbols `{self.name}` and `{other}` do not match: `{self.value}` != `{other}`"
)
Expand All @@ -45,13 +50,12 @@ def __eq__(self, other: Any) -> "UnresolvedSymbol":
return self

def __bool__(self) -> bool:
return (self.value is not None)
return self.value is not None


P, R = ParamSpec("P"), TypeVar("R")


#: Test
@decorator
def contract(
func: Callable[P, R],
Expand All @@ -73,7 +77,6 @@ def contract(
If False, the contracts are not evaluated, by default True
"""


if not evaluate:
return func(*args, **kw)

Expand Down Expand Up @@ -117,11 +120,11 @@ def evaluate_annotations(annotations: dict[str, Any]) -> None:

logger.debug("contract for `%s`, unresolved: `%s`, %s", arg_name, unresolved, symbols)

if not meta(*[(symbols | injectables)[i] for i in meta_args]):
if not meta(*[(symbols | injectables)[i] for i in meta_args]):
raise ContractViolationError(f"Contract violated for argument: `{arg_name}`")

if any([i.value is None for i in symbols.values()]):
raise ContractLogicError(f"Not all symbols were resolved `%s`", symbols)
raise ContractLogicError(f"Not all symbols were resolved `{symbols}`", )

injectables |= {k: v.value for k, v in symbols.items()}

Expand All @@ -146,6 +149,7 @@ def evaluate_annotations(annotations: dict[str, Any]) -> None:


if __name__ == "__main__":
# pylint: disable=invalid-name, missing-function-docstring
# Example
import numpy as np

Expand Down
59 changes: 52 additions & 7 deletions tests/test_dbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import numpy as np
import pandas as pd
import pytest
from design_by_contract import contract, ContractViolationError, ContractLogicError

from design_by_contract import contract, ContractViolationError, ContractLogicError, UnresolvedSymbol

# pylint: skip-file
class TestNumpy:
def test_matmult_correct(self):
@contract
Expand Down Expand Up @@ -115,15 +115,16 @@ def spam(
assert str(exc_info.value) == ("Contract violated for argument: `a`")

def test_vstack(self):

@contract
def spam(
a: Annotated[np.ndarray, lambda x, m, o: (m, o) == x.shape],
b: Annotated[np.ndarray, lambda x, n, o: (n, o) == x.shape],
) -> Annotated[np.ndarray, lambda x, m,n,o: x.shape == (m+n, o)]:
print(np.vstack((a,b)).shape)
return np.vstack((a,b))
spam(np.zeros((3, 2)), np.zeros(( 4, 2)))
) -> Annotated[np.ndarray, lambda x, m, n, o: x.shape == (m + n, o)]:
print(np.vstack((a, b)).shape)
return np.vstack((a, b))

spam(np.zeros((3, 2)), np.zeros((4, 2)))


class TestGeneral:
def test_docstring(self):
Expand All @@ -141,6 +142,50 @@ def spam(a: np.ndarray, b: Annotated[np.ndarray, lambda b, m: b.shape == (m, 3)]

assert "(a: numpy.ndarray, b: typing.Annotated[numpy.ndarray," in str(signature(spam))

def test_reserved(self):
@contract(reserved="y")
def spam(
a: Annotated[np.ndarray, lambda y, m, n: (m, n) == y.shape],
b: Annotated[np.ndarray, lambda y, n, o: (n, o) == y.shape],
) -> Annotated[np.ndarray, lambda y, m, o: y.shape == (m, o)]:

return a @ b

def test_match(self):
a, b = UnresolvedSymbol("a"), UnresolvedSymbol("b")
a == 2
b == a
assert a.value == b.value

def test_match_fail(self):
a, b = UnresolvedSymbol("a"), UnresolvedSymbol("b")
a == 2
b == 1
with pytest.raises(ContractViolationError) as exc_info:
a == b

def test_match_symmetry(self):
a, b = UnresolvedSymbol("a"), UnresolvedSymbol("b")
a == 2
assert a.value == 2

b = UnresolvedSymbol("a")
2 == b
assert b.value == 2

def test_match_fail2(self):
a = UnresolvedSymbol("a")
a == 2

with pytest.raises(ContractViolationError) as exc_info:
a == 3

with pytest.raises(ContractViolationError) as exc_info:
3 == a

a == 2
2 == a


class TestPandas:
def test_pandas_correct(self):
Expand Down

0 comments on commit bb4a795

Please sign in to comment.