Skip to content

Relax Representation Invariant Checking Logic for Recursive Classes #1162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 22, 2025
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

### 🐛 Bug fixes

- `check_contracts` no longer makes methods immediately enforce Representation Invariant checks when setting attributes of instances with the same type (one `Node` modifies another `Node` instance) and only checks RIs for these instances after the method returns.

### 🔧 Internal changes

## [2.10.1] - 2025-02-19
Expand Down
40 changes: 39 additions & 1 deletion python_ta/contracts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,10 @@ def new_setattr(self: klass, name: str, value: Any) -> None:
original_attr_value = super(klass, self).__getattribute__(name)
super(klass, self).__setattr__(name, value)
frame_locals = inspect.currentframe().f_back.f_locals
if self is not frame_locals.get("self"):
caller_self = frame_locals.get("self")
if not isinstance(caller_self, type(self)):
# Only validating if the attribute is not being set in a instance/class method
# AND caller_self is an instance of self's type
if klass_mod is not None:
try:
_check_invariants(self, klass, klass_mod.__dict__)
Expand All @@ -221,6 +223,14 @@ def new_setattr(self: klass, name: str, value: Any) -> None:
else:
super(klass, self).__delattr__(name)
raise AssertionError(str(e)) from None
elif caller_self is not self:
# Keep track of mutations to instances that are of the same type as caller_self (and are also not `self`)
# to enforce RIs on them only after the caller function returns.
caller_klass = type(caller_self)
if hasattr(caller_klass, "__mutated_instances__"):
mutated_instances = getattr(caller_klass, "__mutated_instances__")
if self not in mutated_instances:
mutated_instances.append(self)

for attr, value in klass.__dict__.items():
if inspect.isroutine(value):
Expand Down Expand Up @@ -413,6 +423,15 @@ def _get_argument_suggestions(arg: Any, annotation: type) -> str:
def _instance_method_wrapper(wrapped: Callable, klass: type) -> Callable:
@wrapt.decorator
def wrapper(wrapped, instance, args, kwargs):
# Create an accumulator to store the instances mutated across this function call.
# Store and restore existing mutated instance lists in case the instance method
# executes another instance method.
instance_klass = type(instance)
mutated_instances_to_restore = None
if hasattr(instance_klass, "__mutated_instances__"):
mutated_instances_to_restore = getattr(instance_klass, "__mutated_instances__")
setattr(instance_klass, "__mutated_instances__", [])

try:
r = _check_function_contracts(wrapped, instance, args, kwargs)
if _instance_init_in_callstack(instance):
Expand All @@ -421,10 +440,29 @@ def wrapper(wrapped, instance, args, kwargs):
klass_mod = _get_module(klass)
if klass_mod is not None and ENABLE_CONTRACT_CHECKING:
_check_invariants(instance, klass, klass_mod.__dict__)

# Additionally check RI violations on PyTA-decorated instances that were mutated
# across the function call.
mutated_instances = getattr(instance_klass, "__mutated_instances__", [])
for mutated_instance in mutated_instances:
# Mutated instances may be of parent class types so the invariants to check should also be
# for the parent class and not the child class.
mutated_instance_klass = type(mutated_instance)
mutated_instance_klass_mod = _get_module(mutated_instance_klass)
_check_invariants(
mutated_instance,
mutated_instance_klass,
mutated_instance_klass_mod.__dict__,
)
except PyTAContractError as e:
raise AssertionError(str(e)) from None
else:
return r
finally:
if mutated_instances_to_restore is None:
delattr(instance_klass, "__mutated_instances__")
else:
setattr(instance_klass, "__mutated_instances__", mutated_instances_to_restore)

return wrapper(wrapped)

Expand Down
105 changes: 103 additions & 2 deletions tests/test_contracts/test_class_contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import math
from dataclasses import dataclass
from typing import List, Set, Tuple
from typing import List, Optional, Set, Tuple

import pytest
from nested_preconditions_example import Student
Expand All @@ -28,11 +28,13 @@ class Person:
age: int
name: str
fav_foods: List[str]
_other: Optional[Person]

def __init__(self, name, age, fav_food):
def __init__(self, name, age, fav_food, other: Optional[Person] = None):
self.name = name
self.age = age
self.fav_foods = fav_food
self._other = other

def change_name(self, name: str) -> str:
self.name = name
Expand All @@ -50,6 +52,20 @@ def decrease_and_increase_age(self, age: int) -> int:
self.age = age
return age

def decrease_and_increase_others_age(self, other: Person, age: int) -> int:
"""Temporary violates RI for another instance of the same class."""
other.age = -10
other.age = age
return age

def decrease_others_age(self, other: Person) -> None:
"""Violate the RI of an argument"""
other.age = -10

def decrease_attr_others_age(self) -> None:
"""Violate the RI of an instance attribute"""
self._other.age = -10

def add_fav_food(self, food):
self.fav_foods.append(food)

Expand All @@ -62,6 +78,27 @@ def return_mouthful_greeting(self, greeting: str) -> str:
return f"{greeting} {self.name}!"


class Child(Person):
"""Represent a child.

Representation Invariants:
- self.age < 10
"""

def change_someones_name(self, other: Person, name: str) -> None:
"""Temporarily violate an RI of an instance of a parent class

Precondition:
- len(name) > 0
"""
other.name = "" # Violates the length RI of Person.name
other.name = name # Resolves the RI violation

def remove_someones_name(self, other: Person) -> None:
"""Violate an RI of an instance of a parent class"""
other.name = ""


def change_age(person, new_age):
person.age = new_age

Expand Down Expand Up @@ -110,6 +147,16 @@ def person():
return Person("David", 31, ["Sushi"])


@pytest.fixture
def person_2(person):
return Person("Liu", 31, ["Sushi"], person)


@pytest.fixture
def child():
return Child("JackJack", 1, ["Cookies"])


def test_change_age_invalid_over(person) -> None:
"""
Change the age to larger than 150. Expect an exception.
Expand Down Expand Up @@ -147,6 +194,60 @@ def test_change_age_invalid_in_method(person) -> None:
assert age == 10


def test_change_age_of_other_invalid_in_method(person, person_2) -> None:
"""
Call a method that changes age of another instance of the same class to something invalid but
back to something valid.
Expects normal behavior.
"""
age = person.decrease_and_increase_others_age(person_2, 10)
assert age == 10


def test_change_name_of_parent_invalid_in_method(person, child) -> None:
"""
Call a method that changes name of an instance of a parent class to something invalid but
back to something valid.
Expects normal behavior.
This will also check that the child type's RIs are not being enforced on the mutated parent instance.
"""
child.change_someones_name(person, "Davi")
assert person.name == "Davi"


def test_violate_ri_in_other_instance(person, person_2) -> None:
"""
Call a method that changes age of another instance of the same class to something invalid.
Expects the RI to be violated hence an AssertionError to be raised.
"""
with pytest.raises(AssertionError) as excinfo:
person.decrease_others_age(person_2)
msg = str(excinfo.value)
assert "self.age > 0" in msg


def test_violate_ri_in_attribute_instance(person_2) -> None:
"""
Call a method that changes age of an instance attribute of the same class to something invalid.
Expects the RI to be violated hence an AssertionError to be raised.
"""
with pytest.raises(AssertionError) as excinfo:
person_2.decrease_attr_others_age()
msg = str(excinfo.value)
assert "self.age > 0" in msg


def test_violate_ri_in_parent_instance(person, child) -> None:
"""
Call a method that changes name of an instance of a parent class to something invalid.
Expects the RI to be violated hence an AssertionError to be raised.
"""
with pytest.raises(AssertionError) as excinfo:
child.remove_someones_name(person)
msg = str(excinfo.value)
assert "len(self.name) > 0" in msg


def test_same_method_names(person) -> None:
"""
Call a method with the same name as an instance method and ensure reprsentation invariants are checked.
Expand Down