Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: detect mutation to values of reactive vars #595

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 202 additions & 10 deletions solara/toestand.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import contextlib
import dataclasses
import inspect
import logging
import sys
import threading
from types import FrameType
import warnings
import copy
from abc import ABC, abstractmethod
from collections import defaultdict
from operator import getitem
Expand All @@ -19,6 +22,7 @@
TypeVar,
Union,
cast,
List,
overload,
)

Expand All @@ -35,6 +39,7 @@
# Module-level logger, named after this module.
logger = logging.getLogger("solara.toestand")

# NOTE(review): not referenced in the visible part of this file; presumably toggles extra debug behaviour — confirm.
_DEBUG = False
# When True, Reactive stores deep copies of the values it is given and hands out,
# so that in-place mutations can be detected and reported (see Reactive.set,
# Reactive.peek and Reactive.check_mutations).
_CHECK_MUTATIONS = True


class ThreadLocal(threading.local):
Expand Down Expand Up @@ -265,11 +270,67 @@ def initial_value(self) -> S:
pass


@dataclasses.dataclass
class ShouldNotMutate:
    """Record pairing a tracked object with the snapshot taken when tracking began.

    Instances are created by ``_track_initial_mutation`` and inspected by the
    module-level ``check_mutations``.
    """

    # The live object, held by reference; user code may still mutate it.
    value: Any
    # Deep copy of ``value`` taken at registration time, used as the baseline.
    original: Any
    # Frame info of the registering call site, or None when it could not be determined.
    traceback: Optional[inspect.Traceback]


# Module-global registry of tracked initial values: appended to by
# _track_initial_mutation and scanned by check_mutations.
should_not_mutate: List[ShouldNotMutate] = []


def _track_initial_mutation(value):
    """Register an initial value so later in-place mutations can be detected.

    A reference to ``value`` and a deep copy of it are appended to the
    module-level ``should_not_mutate`` list, together with the frame info of
    the first caller outside of solara (when one is found), so that
    ``check_mutations`` can report where the value came from.

    Args:
        value: The initial value to track. ``None`` is not tracked.
    """
    if value is None:
        # Nothing to snapshot: return before walking the stack and loading
        # source context, whose result would be discarded anyway.
        return

    frame = _find_outside_solara_frame()
    frame_info = inspect.getframeinfo(frame) if frame is not None else None
    should_not_mutate.append(ShouldNotMutate(value, copy.deepcopy(value), frame_info))


def check_mutations():
    """Raise if any tracked initial value no longer matches its snapshot.

    Walks the module-level ``should_not_mutate`` list and compares every
    tracked value with the deep copy taken when it was registered.

    Raises:
        ValueError: For the first tracked value that differs from its
            snapshot. The message includes the registration location when it
            was recorded.
    """
    for tracked in should_not_mutate:
        # Compare with equals2 (the default comparison used by Reactive) instead
        # of a bare ``!=``: values such as pandas DataFrames do not yield a
        # plain bool from ``!=`` and would raise an unrelated error here.
        if equals2(tracked.value, tracked.original):
            continue
        tb = tracked.traceback
        if tb:
            if tb.code_context:
                code = tb.code_context[0].strip()
            else:
                code = "No code context available"
            msg = (
                f"Reactive variable initialized at {tb.filename}:{tb.lineno} was initialized with a value of {tracked.original!r}, but was mutated to {tracked.value!r}.\n"
                f"{code}"
            )
        else:
            msg = f"Reactive variable was initialized with a value of {tracked.original!r}, but was mutated to {tracked.value!r} (unable to report the location in the source code)."
        raise ValueError(msg)


def _find_outside_solara_frame() -> Optional[FrameType]:
    """Return the innermost stack frame that does not belong to a solara module.

    Starting at this function's own frame, the stack is walked outwards via
    ``f_back``. The first frame whose module cannot be determined, or whose
    module name does not start with ``"solara"``, is returned.

    Returns:
        The matching frame, or ``None`` when the stack is exhausted (or no
        current frame is available) without finding one.
    """
    frame = inspect.currentframe()
    while frame is not None:
        owner = inspect.getmodule(frame)
        inside_solara = owner is not None and owner.__name__.startswith("solara")
        if not inside_solara:
            return frame
        frame = frame.f_back
    return None


class KernelStoreValue(KernelStore[S]):
default_value: S

def __init__(self, default_value: S, key=None):
self.default_value = default_value
_track_initial_mutation(default_value)
cls = type(default_value)
if key is None:
with KernelStoreValue.scope_lock:
Expand All @@ -279,7 +340,7 @@ def __init__(self, default_value: S, key=None):
super().__init__(key=key)

def initial_value(self) -> S:
return self.default_value
return copy.deepcopy(self.default_value)


class KernelStoreFactory(KernelStore[S]):
Expand All @@ -303,18 +364,49 @@ def initial_value(self) -> S:
return self.factory()


@dataclasses.dataclass
class StoreValue(Generic[S]):
    """Wrapper stored in the storage backing a ``Reactive``.

    Keeps the internal value alongside the copies and call-site information
    needed to detect and report in-place mutations.
    """

    # The internal private value; should never be mutated.
    # When _CHECK_MUTATIONS is False, this is what .get() exposes.
    private: S
    # When _CHECK_MUTATIONS is True, this is the value exposed in .get();
    # it is a deep copy of `private`.
    public: Optional[S]
    # Frame info recorded when `public` was created, or None.
    get_traceback: Optional[inspect.Traceback]
    # The value that was passed to .set(..); `private` is a deepcopy of it
    # when _CHECK_MUTATIONS is True.
    set_value: Optional[S]
    # Frame info recorded by .set(..), or None.
    set_traceback: Optional[inspect.Traceback]


def equals2(a, b):
    """Return True when ``a`` and ``b`` should be considered equal.

    First defers to ``equals``. When that reports a difference, the pickled
    representations are compared as a fallback, so that objects holding the
    same data (e.g. dataframes) are not reported as different merely because
    they are distinct objects.

    TODO: how do we reconcile this with the original equals function?

    """
    if equals(a, b):
        return True
    import pickle

    try:
        same_bytes = pickle.dumps(a) == pickle.dumps(b)
    except Exception:
        # Best effort: values that cannot be pickled count as not equal.
        return False
    return same_bytes


class Reactive(ValueBase[S]):
_storage: ValueBase[S]
_storage: ValueBase[StoreValue[S]]

def __init__(self, default_value: Union[S, ValueBase[S]], key=None):
def __init__(self, default_value: Union[S, ValueBase[StoreValue[S]]], key=None, equals=equals2):
super().__init__()
if not isinstance(default_value, ValueBase):
self._storage = KernelStoreValue(default_value, key=key)
self._storage = KernelStoreValue(StoreValue[S](private=default_value, public=None, get_traceback=None, set_value=None, set_traceback=None), key=key)
else:
self._storage = default_value
self.__post__init__()
self._name = None
self._owner = None
self.equals = equals

def __set_name__(self, owner, name):
self._name = name
Expand Down Expand Up @@ -346,25 +438,125 @@ def update(self, *args, **kwargs):
def set(self, value: S):
if value is self:
raise ValueError("Can't set a reactive to itself")
self._storage.set(value)
if _CHECK_MUTATIONS:
private = copy.deepcopy(value)
frame = _find_outside_solara_frame()
if frame is not None:
frame_info = inspect.getframeinfo(frame)
else:
frame_info = None
store_value = StoreValue[S](private=private, public=None, get_traceback=None, set_value=value, set_traceback=frame_info)
else:
store_value = StoreValue[S](private=value, public=None, get_traceback=None, set_value=None, set_traceback=None)
self._storage.set(store_value)

def check_mutations(self):
    """Raise if the value handed out by, or given to, this reactive was mutated in place.

    Does nothing when ``_CHECK_MUTATIONS`` is False. Otherwise the stored
    private value is compared (using ``self.equals``) with:

    * the public copy handed out on read, when one exists, and
    * the object that was passed to ``.set(..)``, when one was recorded.

    Raises:
        ValueError: When either of those differs from the private value. The
            message includes the location of the last read / set when it was
            recorded.
    """
    if not _CHECK_MUTATIONS:
        return
    store_value = self._storage.peek()
    if store_value.public is not None and not self.equals(store_value.public, store_value.private):
        tb = store_value.get_traceback
        # TODO: make the error message as elaborate as below
        msg = (
            f"Reactive variable was read when it had the value of {store_value.private!r}, but was later mutated to {store_value.public!r}.\n"
            "Mutation should not be done on the value of a reactive variable, as in production mode we would be unable to track changes.\n"
        )
        if tb:
            if tb.code_context:
                code = tb.code_context[0]
            else:
                code = "<No code context available>"
            msg += f"The last value was read in the following code:\n" f"{tb.filename}:{tb.lineno}\n" f"{code}"
        raise ValueError(msg)
    elif store_value.set_value is not None and not self.equals(store_value.set_value, store_value.private):
        tb = store_value.set_traceback
        # Message fixes: removed the duplicated word "mutated" and balanced the
        # brackets in the `reactive([])` examples.
        msg = f"""Reactive variable was set with a value of {store_value.private!r}, but was later mutated to {store_value.set_value!r}.

Mutation should not be done on the value of a reactive variable, as in production mode we would be unable to track changes.

Bad:
mylist = reactive([])
some_values = [1, 2, 3]
mylist.value = some_values # you give solara a reference to your list
some_values.append(4) # but later mutate it (solara cannot detect this change, so a render will not be triggered)
# if later on a re-render happens for a different reason, you will read of the mutated list.

Good (if you want the reactive variable to be updated):
mylist = reactive([])
some_values = [1, 2, 3]
mylist.value = some_values
mylist.value = some_values + [4]

Good (if you want to keep mutating your own list):
mylist = reactive([])
some_values = [1, 2, 3]
mylist.value = some_values.copy() # this gives solara a copy of the list
some_values.append(4) # you are free to mutate your own list, solara will not see this

"""
        if tb:
            if tb.code_context:
                code = tb.code_context[0]
            else:
                code = "<No code context available>"
            msg += "The last time the value was set was at:\n" f"{tb.filename}:{tb.lineno}\n" f"{code}"
        raise ValueError(msg)

def get(self, add_watch=None) -> S:
self.check_mutations()
# peek also avoid parents also adding themselves to the reactive_used set
value = self.peek()
if add_watch is not None:
warnings.warn("add_watch is deprecated, use .peek()", DeprecationWarning)
if thread_local.reactive_used is not None:
thread_local.reactive_used.add(self)
# peek to avoid parents also adding themselves to the reactive_used set
return self._storage.peek()
return value

def peek(self) -> S:
"""Return the value without automatically subscribing to listeners."""
return self._storage.peek()
store_value = self._storage.peek()
self._ensure_public_exists()
if _CHECK_MUTATIONS:
assert store_value.public is not None
return store_value.public
else:
return store_value.private

def _ensure_public_exists(self):
    """Create the public deep copy of the stored value if it is missing.

    Only acts when ``_CHECK_MUTATIONS`` is True and ``public`` is None. The
    copy is made while holding ``self.lock``, and the frame info of the first
    caller outside of solara (if found) is stored as ``get_traceback``.

    NOTE(review): when ``private`` itself is None, ``public`` is still None
    after the copy, so callers asserting ``public is not None`` would fail —
    confirm whether None is a supported value.
    """
    stored = self._storage.peek()
    if not _CHECK_MUTATIONS or stored.public is not None:
        return
    with self.lock:
        # Re-check under the lock: the copy may have been created since the first check.
        if stored.public is not None:
            return
        caller = _find_outside_solara_frame()
        caller_info = None if caller is None else inspect.getframeinfo(caller)
        stored.public = copy.deepcopy(stored.private)
        stored.get_traceback = caller_info

def subscribe(self, listener: Callable[[S], None], scope: Optional[ContextManager] = None):
return self._storage.subscribe(listener, scope=scope)
def listener_wrapper(value: StoreValue[S]):
if _CHECK_MUTATIONS:
self._ensure_public_exists()
assert value.public is not None
listener(value.public)
else:
listener(value.private)

return self._storage.subscribe(listener_wrapper, scope=scope)

def subscribe_change(self, listener: Callable[[S, S], None], scope: Optional[ContextManager] = None):
return self._storage.subscribe_change(listener, scope=scope)
def listener_wrapper(new: StoreValue[S], previous: StoreValue[S]):
if _CHECK_MUTATIONS:
self._ensure_public_exists()
assert new.public is not None
assert previous.public is not None
listener(new.public, previous.public)
else:
listener(new.private, previous.private)

return self._storage.subscribe_change(listener_wrapper, scope=scope)

def computed(self, f: Callable[[S], T]) -> "Computed[T]":
def func():
Expand Down
47 changes: 47 additions & 0 deletions tests/unit/toestand_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set, TypeVar

import pytest

import ipyvuetify as v
import react_ipywidgets as react
from typing_extensions import TypedDict
Expand All @@ -14,6 +16,7 @@
import solara.toestand as toestand
from solara.server import kernel, kernel_context
from solara.toestand import Reactive, Ref, State, use_sync_external_store
import pandas as pd

from .common import click

Expand Down Expand Up @@ -1204,3 +1207,47 @@ def test_computed_reload(no_kernel_context):
assert text.widget.v_model == "4.0"
finally:
app.close()


def test_mutate_initial_value():
    """Mutating the object passed as initial value is caught by ``toestand.check_mutations``."""
    # The registry is module-global: restore it afterwards so the deliberately
    # mutated entry does not make later check_mutations() calls fail.
    tracked_before = list(toestand.should_not_mutate)
    try:
        initial_values = [1, 2, 3]
        reactive = Reactive(initial_values)
        assert reactive.value == initial_values
        initial_values.append(4)
        # the reactive holds its own copy, so it must not see the append
        assert reactive.value != initial_values
        with pytest.raises(ValueError):
            toestand.check_mutations()
    finally:
        toestand.should_not_mutate[:] = tracked_before


def test_mutate_value_public_value():
    """Mutating the object returned by ``.value`` is reported by ``check_mutations``."""
    values = [1, 2, 3]
    reactive = Reactive(values)
    reactive.value.append(4)
    # Raw string: "\[" in a regular string literal is an invalid escape sequence.
    with pytest.raises(ValueError, match=r"Reactive variable was read when it had the value of \[1, 2, 3\].*"):
        reactive.check_mutations()


def test_mutate_value_set_value():
    """Mutating an object after assigning it to ``.value`` is reported by ``check_mutations``."""
    reactive = Reactive([1, 2, 3])
    assigned = [1, 2, 3, 4]
    reactive.value = assigned
    # mutate the caller-owned list after it was handed to the reactive
    assigned.append(5)
    with pytest.raises(ValueError, match="Reactive variable was set.*"):
        reactive.check_mutations()


def test_mutate_value_set_value_dataframe():
    """A dataframe mutated after being assigned to ``.value`` is reported by ``check_mutations``."""
    # Dataframes do not support simple equality checks, and equals(a, b) must
    # not return False when the two frames actually hold the same data.
    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    df_orig = df.copy()
    reactive_df = Reactive[Optional[pd.DataFrame]](None)
    assert reactive_df.equals(df, df_orig)
    reactive_df.value = df
    # Use .loc rather than chained assignment (df["a"][0] = ...), which does
    # not modify df when pandas Copy-on-Write is active.
    df.loc[0, "a"] = 100
    assert not reactive_df.equals(df, df_orig)
    with pytest.raises(ValueError, match="Reactive variable was set.*"):
        reactive_df.check_mutations()
Loading