Skip to content

Commit e800abb

Browse files
committed
Add petab-compatible sympy string-printer
Add a sympy Printer to stringify sympy expressions in a petab-compatible way. For example, we need to avoid `str(sympy.sympify("x^2"))` -> `'x**2'`. Closes PEtab-dev#362.
1 parent 81af370 commit e800abb

File tree

6 files changed

+148
-45
lines changed

6 files changed

+148
-45
lines changed

petab/v1/math/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
"""Functions for parsing and evaluating mathematical expressions."""
22

3+
from .printer import PetabStrPrinter, petab_math_str # noqa: F401
34
from .sympify import sympify_petab # noqa: F401

petab/v1/math/printer.py

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""A PEtab-compatible sympy string-printer."""
2+
3+
from itertools import chain, islice
4+
5+
import sympy as sp
6+
from sympy.printing.str import StrPrinter
7+
8+
9+
class PetabStrPrinter(StrPrinter):
10+
"""A PEtab-compatible sympy string-printer."""
11+
12+
#: Mapping of sympy functions to PEtab functions
13+
_func_map = {
14+
"asin": "arcsin",
15+
"acos": "arccos",
16+
"atan": "arctan",
17+
"acot": "arccot",
18+
"asec": "arcsec",
19+
"acsc": "arccsc",
20+
"asinh": "arcsinh",
21+
"acosh": "arccosh",
22+
"atanh": "arctanh",
23+
"acoth": "arccoth",
24+
"asech": "arcsech",
25+
"acsch": "arccsch",
26+
"Abs": "abs",
27+
}
28+
29+
def _print_BooleanTrue(self, expr):
30+
return "true"
31+
32+
def _print_BooleanFalse(self, expr):
33+
return "false"
34+
35+
def _print_Pow(self, expr: sp.Pow):
36+
"""Custom printing for the power operator"""
37+
base, exp = expr.as_base_exp()
38+
return f"{self._print(base)} ^ {self._print(exp)}"
39+
40+
def _print_Infinity(self, expr):
41+
"""Custom printing for infinity"""
42+
return "inf"
43+
44+
def _print_NegativeInfinity(self, expr):
45+
"""Custom printing for negative infinity"""
46+
return "-inf"
47+
48+
def _print_Function(self, expr):
49+
"""Custom printing for specific functions"""
50+
51+
if expr.func.__name__ == "Piecewise":
52+
return self._print_Piecewise(expr)
53+
54+
if func := self._func_map.get(expr.func.__name__):
55+
return f"{func}({', '.join(map(self._print, expr.args))})"
56+
57+
return super()._print_Function(expr)
58+
59+
def _print_Piecewise(self, expr):
60+
"""Custom printing for Piecewise function"""
61+
# merge the tuples and drop the final `True` condition
62+
str_args = map(
63+
self._print,
64+
islice(chain.from_iterable(expr.args), 2 * len(expr.args) - 1),
65+
)
66+
return f"piecewise({', '.join(str_args)})"
67+
68+
def _print_Min(self, expr):
69+
"""Custom printing for Min function"""
70+
return f"min({', '.join(map(self._print, expr.args))})"
71+
72+
def _print_Max(self, expr):
73+
"""Custom printing for Max function"""
74+
return f"max({', '.join(map(self._print, expr.args))})"
75+
76+
77+
def petab_math_str(expr: sp.Basic | sp.Expr | None) -> str:
78+
"""Convert a sympy expression to a PEtab-compatible math expression string.
79+
80+
:example:
81+
>>> expr = sp.sympify("x**2 + sin(y)")
82+
>>> petab_math_str(expr)
83+
'x ^ 2 + sin(y)'
84+
>>> expr = sp.sympify("Piecewise((1, x > 0), (0, True))")
85+
>>> petab_math_str(expr)
86+
'piecewise(1, x > 0, 0)'
87+
"""
88+
if expr is None:
89+
return ""
90+
91+
return PetabStrPrinter().doprint(expr)

petab/v1/math/sympify.py

+21-16
Original file line numberDiff line numberDiff line change
@@ -5,38 +5,37 @@
55
from antlr4 import CommonTokenStream, InputStream
66
from antlr4.error.ErrorListener import ErrorListener
77

8+
from . import petab_math_str
89
from ._generated.PetabMathExprLexer import PetabMathExprLexer
910
from ._generated.PetabMathExprParser import PetabMathExprParser
1011
from .SympyVisitor import MathVisitorSympy, bool2num
1112

1213
__all__ = ["sympify_petab"]
1314

1415

15-
def sympify_petab(
16-
expr: str | int | float, evaluate: bool = True
17-
) -> sp.Expr | sp.Basic:
16+
def sympify_petab(expr: str | int | float | sp.Basic, evaluate: bool = True) -> sp.Expr | sp.Basic:
1817
"""Convert PEtab math expression to sympy expression.
1918
20-
.. note::
21-
22-
All symbols in the returned expression will have the `real=True`
23-
assumption.
2419
2520
Args:
2621
expr: PEtab math expression.
2722
evaluate: Whether to evaluate the expression.
2823
2924
Raises:
3025
ValueError: Upon lexer/parser errors or if the expression is
31-
otherwise invalid.
26+
otherwise invalid.
3227
3328
Returns:
3429
The sympy expression corresponding to `expr`.
3530
Boolean values are converted to numeric values.
3631
32+
.. note::
33+
34+
All symbols in the returned expression will have the ``real=True``
35+
assumption.
3736
3837
:example:
39-
>>> from petab.math import sympify_petab
38+
>>> from petab.v1.math import sympify_petab
4039
>>> sympify_petab("sin(0)")
4140
0
4241
>>> sympify_petab("sin(0)", evaluate=False)
@@ -61,9 +60,8 @@ def sympify_petab(
6160
>>> sympify_petab("2", evaluate=True)
6261
2.00000000000000
6362
"""
64-
if isinstance(expr, sp.Expr):
65-
# TODO: check if only PEtab-compatible symbols and functions are used
66-
return expr
63+
if isinstance(expr, sp.Basic):
64+
return sympify_petab(petab_math_str(expr))
6765

6866
if isinstance(expr, int) or isinstance(expr, np.integer):
6967
return sp.Integer(expr)
@@ -95,10 +93,17 @@ def sympify_petab(
9593
visitor = MathVisitorSympy(evaluate=evaluate)
9694
expr = visitor.visit(tree)
9795
expr = bool2num(expr)
98-
# check for `False`, we'll accept both `True` and `None`
99-
if expr.is_extended_real is False:
100-
raise ValueError(f"Expression {expr} is not real-valued.")
101-
96+
try:
97+
# check for `False`, we'll accept both `True` and `None`
98+
if expr.is_extended_real is False:
99+
raise ValueError(f"Expression {expr} is not real-valued.")
100+
except AttributeError as e:
101+
# work-around for `sp.sec(0, evaluate=False).is_extended_real` error
102+
if str(e) not in (
103+
"'One' object has no attribute '_eval_is_extended_real'",
104+
"'Float' object has no attribute '_eval_is_extended_real'",
105+
):
106+
raise
102107
return expr
103108

104109

petab/v2/core.py

+3-18
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from typing_extensions import Self
2626

2727
from ..v1.lint import is_valid_identifier
28-
from ..v1.math import sympify_petab
28+
from ..v1.math import petab_math_str, sympify_petab
2929
from . import C, get_observable_df
3030

3131
__all__ = [
@@ -273,23 +273,8 @@ def to_df(self) -> pd.DataFrame:
273273
for record in records:
274274
obs = record[C.OBSERVABLE_FORMULA]
275275
noise = record[C.NOISE_FORMULA]
276-
record[C.OBSERVABLE_FORMULA] = (
277-
None
278-
if obs is None
279-
# TODO: we need a custom printer for sympy expressions
280-
# to avoid '**'
281-
# https://github.com/PEtab-dev/libpetab-python/issues/362
282-
else str(obs)
283-
if not obs.is_number
284-
else float(obs)
285-
)
286-
record[C.NOISE_FORMULA] = (
287-
None
288-
if noise is None
289-
else str(noise)
290-
if not noise.is_number
291-
else float(noise)
292-
)
276+
record[C.OBSERVABLE_FORMULA] = petab_math_str(obs)
277+
record[C.NOISE_FORMULA] = petab_math_str(noise)
293278
return pd.DataFrame(records).set_index([C.OBSERVABLE_ID])
294279

295280
@classmethod

tests/v1/math/test_math.py

+29-10
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import sympy as sp
77
import yaml
88
from sympy.abc import _clash
9-
from sympy.logic.boolalg import Boolean
9+
from sympy.logic.boolalg import Boolean, BooleanFalse, BooleanTrue
1010

11-
from petab.math import sympify_petab
11+
from petab.v1.math import petab_math_str, sympify_petab
1212

1313

1414
def test_sympify_numpy():
@@ -28,6 +28,19 @@ def test_evaluate():
2828
act = sympify_petab("piecewise(1, 1 > 2, 0)", evaluate=False)
2929
assert str(act) == "Piecewise((1.0, 1.0 > 2.0), (0.0, True))"
3030

31+
def test_assumptions():
32+
# in PEtab, all symbols are expected to be real-valued
33+
assert sympify_petab("x").is_real
34+
35+
# non-real symbols are changed to real
36+
assert sympify_petab(sp.Symbol("x", real=False)).is_real
37+
38+
39+
def test_printer():
40+
assert petab_math_str(None) == ""
41+
assert petab_math_str(BooleanTrue()) == "true"
42+
assert petab_math_str(BooleanFalse()) == "false"
43+
3144

3245
def read_cases():
3346
"""Read test cases from YAML file in the petab_test_suite package."""
@@ -60,29 +73,35 @@ def read_cases():
6073
@pytest.mark.parametrize("expr_str, expected", read_cases())
6174
def test_parse_cases(expr_str, expected):
6275
"""Test PEtab math expressions for the PEtab test suite."""
63-
result = sympify_petab(expr_str)
64-
if isinstance(result, Boolean):
65-
assert result == expected
76+
sym_expr = sympify_petab(expr_str)
77+
if isinstance(sym_expr, Boolean):
78+
assert sym_expr == expected
6679
else:
6780
try:
68-
result = float(result.evalf())
81+
result = float(sym_expr.evalf())
6982
assert np.isclose(result, expected), (
7083
f"{expr_str}: Expected {expected}, got {result}"
7184
)
7285
except TypeError:
73-
assert result == expected, (
86+
assert sym_expr == expected, (
7487
f"{expr_str}: Expected {expected}, got {result}"
7588
)
7689

90+
# test parsing, printing, and parsing again
91+
resympified = sympify_petab(petab_math_str(sym_expr))
92+
if sym_expr.is_number:
93+
assert np.isclose(float(resympified), float(sym_expr))
94+
else:
95+
assert resympified.equals(sym_expr), (sym_expr, resympified)
96+
7797

7898
def test_ids():
7999
"""Test symbols in expressions."""
80100
assert sympify_petab("bla * 2") == 2.0 * sp.Symbol("bla", real=True)
81101

82102
# test that sympy expressions that are invalid in PEtab raise an error
83-
# TODO: handle these cases after
84-
# https://github.com/PEtab-dev/libpetab-python/pull/364
85-
# sympify_petab(sp.Symbol("föö"))
103+
with pytest.raises(ValueError):
104+
sympify_petab(sp.Symbol("föö"))
86105

87106

88107
def test_syntax_error():

tests/v2/test_problem.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ def test_modify_problem():
133133
}
134134
).set_index([OBSERVABLE_ID])
135135
assert_frame_equal(
136-
problem.observable_df[[OBSERVABLE_FORMULA, NOISE_FORMULA]],
136+
problem.observable_df[[OBSERVABLE_FORMULA, NOISE_FORMULA]].map(
137+
lambda x: float(x) if x != "" else None
138+
),
137139
exp_observable_df,
138140
check_dtype=False,
139141
)

0 commit comments

Comments
 (0)