3939}
4040_unary_funcs = {
4141 "exp" : sp .exp ,
42- "log10" : lambda x : - sp .oo if x .is_zero is True else sp .log (x , 10 ),
43- "log2" : lambda x : - sp .oo if x .is_zero is True else sp .log (x , 2 ),
42+ "log10" : lambda x , evaluate = True : - sp .oo
43+ if x .is_zero is True
44+ else sp .log (x , 10 , evaluate = evaluate ),
45+ "log2" : lambda x , evaluate = True : - sp .oo
46+ if x .is_zero is True
47+ else sp .log (x , 2 , evaluate = evaluate ),
4448 "ln" : sp .log ,
4549 "sqrt" : sp .sqrt ,
4650 "abs" : sp .Abs ,
@@ -75,8 +79,14 @@ class MathVisitorSympy(PetabMathExprParserVisitor):
7579
7680 For a general introduction to ANTLR4 visitors, see:
7781 https://github.com/antlr/antlr4/blob/7d4cea92bc3f7d709f09c3f1ac77c5bbc71a6749/doc/python-target.md
82+
83+ :param evaluate: Whether to evaluate the expression.
7884 """
7985
86+ def __init__ (self , evaluate = True ):
87+ super ().__init__ ()
88+ self .evaluate = evaluate
89+
8090 def visitPetabExpression (
8191 self , ctx : PetabMathExprParser .PetabExpressionContext
8292 ) -> sp .Expr | sp .Basic :
@@ -101,9 +111,17 @@ def visitMultExpr(
101111 operand1 = bool2num (self .visit (ctx .getChild (0 )))
102112 operand2 = bool2num (self .visit (ctx .getChild (2 )))
103113 if ctx .ASTERISK ():
104- return operand1 * operand2
114+ return sp . Mul ( operand1 , operand2 , evaluate = self . evaluate )
105115 if ctx .SLASH ():
106- return operand1 / operand2
116+ return (
117+ operand1 / operand2
118+ if self .evaluate
119+ else sp .Mul (
120+ operand1 ,
121+ sp .Pow (operand2 , - 1 , evaluate = False ),
122+ evaluate = False ,
123+ )
124+ )
107125
108126 raise AssertionError (f"Unexpected expression: { ctx .getText ()} " )
109127
@@ -112,9 +130,9 @@ def visitAddExpr(self, ctx: PetabMathExprParser.AddExprContext) -> sp.Expr:
112130 op1 = bool2num (self .visit (ctx .getChild (0 )))
113131 op2 = bool2num (self .visit (ctx .getChild (2 )))
114132 if ctx .PLUS ():
115- return op1 + op2
133+ return sp . Add ( op1 , op2 , evaluate = self . evaluate )
116134 if ctx .MINUS ():
117- return op1 - op2
135+ return sp . Add ( op1 , - op2 , evaluate = self . evaluate )
118136
119137 raise AssertionError (
120138 f"Unexpected operator: { ctx .getChild (1 ).getText ()} "
@@ -146,28 +164,32 @@ def visitFunctionCall(
146164 f"Unexpected number of arguments: { len (args )} "
147165 f"in { ctx .getText ()} "
148166 )
149- return _trig_funcs [func_name ](* args )
167+ return _trig_funcs [func_name ](* args , evaluate = self . evaluate )
150168 if func_name in _unary_funcs :
151169 if len (args ) != 1 :
152170 raise AssertionError (
153171 f"Unexpected number of arguments: { len (args )} "
154172 f"in { ctx .getText ()} "
155173 )
156- return _unary_funcs [func_name ](* args )
174+ return _unary_funcs [func_name ](* args , evaluate = self . evaluate )
157175 if func_name in _binary_funcs :
158176 if len (args ) != 2 :
159177 raise AssertionError (
160178 f"Unexpected number of arguments: { len (args )} "
161179 f"in { ctx .getText ()} "
162180 )
163- return _binary_funcs [func_name ](* args )
181+ return _binary_funcs [func_name ](* args , evaluate = self . evaluate )
164182 if func_name == "log" :
165183 if len (args ) not in [1 , 2 ]:
166184 raise AssertionError (
167185 f"Unexpected number of arguments: { len (args )} "
168186 f"in { ctx .getText ()} "
169187 )
170- return - sp .oo if args [0 ].is_zero is True else sp .log (* args )
188+ return (
189+ - sp .oo
190+ if args [0 ].is_zero is True
191+ else sp .log (* args , evaluate = self .evaluate )
192+ )
171193
172194 if func_name == "piecewise" :
173195 if (len (args ) - 1 ) % 2 != 0 :
@@ -184,7 +206,7 @@ def visitFunctionCall(
184206 args [::2 ], args [1 ::2 ], strict = True
185207 )
186208 )
187- return sp .Piecewise (* sp_args )
209+ return sp .Piecewise (* sp_args , evaluate = self . evaluate )
188210
189211 raise ValueError (f"Unknown function: { ctx .getText ()} " )
190212
@@ -203,7 +225,7 @@ def visitPowerExpr(
203225 )
204226 operand1 = bool2num (self .visit (ctx .getChild (0 )))
205227 operand2 = bool2num (self .visit (ctx .getChild (2 )))
206- return sp .Pow (operand1 , operand2 )
228+ return sp .Pow (operand1 , operand2 , evaluate = self . evaluate )
207229
208230 def visitUnaryExpr (
209231 self , ctx : PetabMathExprParser .UnaryExprContext
@@ -240,7 +262,7 @@ def visitComparisonExpr(
240262 if op in ops :
241263 lhs = bool2num (lhs )
242264 rhs = bool2num (rhs )
243- return ops [op ](lhs , rhs )
265+ return ops [op ](lhs , rhs , evaluate = self . evaluate )
244266
245267 raise AssertionError (f"Unexpected operator: { op } " )
246268
@@ -301,4 +323,6 @@ def num2bool(x: sp.Basic | sp.Expr) -> sp.Basic | sp.Expr:
301323 return sp .false
302324 if x .is_zero is False :
303325 return sp .true
326+ if isinstance (x , Boolean ):
327+ return x
304328 return sp .Piecewise ((True , x != 0.0 ), (False , True ))
0 commit comments