39
39
}
40
40
_unary_funcs = {
41
41
"exp" : sp .exp ,
42
- "log10" : lambda x : - sp .oo if x .is_zero is True else sp .log (x , 10 ),
43
- "log2" : lambda x : - sp .oo if x .is_zero is True else sp .log (x , 2 ),
42
+ "log10" : lambda x , evaluate = True : - sp .oo
43
+ if x .is_zero is True
44
+ else sp .log (x , 10 , evaluate = evaluate ),
45
+ "log2" : lambda x , evaluate = True : - sp .oo
46
+ if x .is_zero is True
47
+ else sp .log (x , 2 , evaluate = evaluate ),
44
48
"ln" : sp .log ,
45
49
"sqrt" : sp .sqrt ,
46
50
"abs" : sp .Abs ,
@@ -75,8 +79,14 @@ class MathVisitorSympy(PetabMathExprParserVisitor):
75
79
76
80
For a general introduction to ANTLR4 visitors, see:
77
81
https://github.com/antlr/antlr4/blob/7d4cea92bc3f7d709f09c3f1ac77c5bbc71a6749/doc/python-target.md
82
+
83
+ :param evaluate: Whether to evaluate the expression.
78
84
"""
79
85
86
+ def __init__ (self , evaluate = True ):
87
+ super ().__init__ ()
88
+ self .evaluate = evaluate
89
+
80
90
def visitPetabExpression (
81
91
self , ctx : PetabMathExprParser .PetabExpressionContext
82
92
) -> sp .Expr | sp .Basic :
@@ -101,9 +111,17 @@ def visitMultExpr(
101
111
operand1 = bool2num (self .visit (ctx .getChild (0 )))
102
112
operand2 = bool2num (self .visit (ctx .getChild (2 )))
103
113
if ctx .ASTERISK ():
104
- return operand1 * operand2
114
+ return sp . Mul ( operand1 , operand2 , evaluate = self . evaluate )
105
115
if ctx .SLASH ():
106
- return operand1 / operand2
116
+ return (
117
+ operand1 / operand2
118
+ if self .evaluate
119
+ else sp .Mul (
120
+ operand1 ,
121
+ sp .Pow (operand2 , - 1 , evaluate = False ),
122
+ evaluate = False ,
123
+ )
124
+ )
107
125
108
126
raise AssertionError (f"Unexpected expression: { ctx .getText ()} " )
109
127
@@ -112,9 +130,9 @@ def visitAddExpr(self, ctx: PetabMathExprParser.AddExprContext) -> sp.Expr:
112
130
op1 = bool2num (self .visit (ctx .getChild (0 )))
113
131
op2 = bool2num (self .visit (ctx .getChild (2 )))
114
132
if ctx .PLUS ():
115
- return op1 + op2
133
+ return sp . Add ( op1 , op2 , evaluate = self . evaluate )
116
134
if ctx .MINUS ():
117
- return op1 - op2
135
+ return sp . Add ( op1 , - op2 , evaluate = self . evaluate )
118
136
119
137
raise AssertionError (
120
138
f"Unexpected operator: { ctx .getChild (1 ).getText ()} "
@@ -146,28 +164,32 @@ def visitFunctionCall(
146
164
f"Unexpected number of arguments: { len (args )} "
147
165
f"in { ctx .getText ()} "
148
166
)
149
- return _trig_funcs [func_name ](* args )
167
+ return _trig_funcs [func_name ](* args , evaluate = self . evaluate )
150
168
if func_name in _unary_funcs :
151
169
if len (args ) != 1 :
152
170
raise AssertionError (
153
171
f"Unexpected number of arguments: { len (args )} "
154
172
f"in { ctx .getText ()} "
155
173
)
156
- return _unary_funcs [func_name ](* args )
174
+ return _unary_funcs [func_name ](* args , evaluate = self . evaluate )
157
175
if func_name in _binary_funcs :
158
176
if len (args ) != 2 :
159
177
raise AssertionError (
160
178
f"Unexpected number of arguments: { len (args )} "
161
179
f"in { ctx .getText ()} "
162
180
)
163
- return _binary_funcs [func_name ](* args )
181
+ return _binary_funcs [func_name ](* args , evaluate = self . evaluate )
164
182
if func_name == "log" :
165
183
if len (args ) not in [1 , 2 ]:
166
184
raise AssertionError (
167
185
f"Unexpected number of arguments: { len (args )} "
168
186
f"in { ctx .getText ()} "
169
187
)
170
- return - sp .oo if args [0 ].is_zero is True else sp .log (* args )
188
+ return (
189
+ - sp .oo
190
+ if args [0 ].is_zero is True
191
+ else sp .log (* args , evaluate = self .evaluate )
192
+ )
171
193
172
194
if func_name == "piecewise" :
173
195
if (len (args ) - 1 ) % 2 != 0 :
@@ -184,7 +206,7 @@ def visitFunctionCall(
184
206
args [::2 ], args [1 ::2 ], strict = True
185
207
)
186
208
)
187
- return sp .Piecewise (* sp_args )
209
+ return sp .Piecewise (* sp_args , evaluate = self . evaluate )
188
210
189
211
raise ValueError (f"Unknown function: { ctx .getText ()} " )
190
212
@@ -203,7 +225,7 @@ def visitPowerExpr(
203
225
)
204
226
operand1 = bool2num (self .visit (ctx .getChild (0 )))
205
227
operand2 = bool2num (self .visit (ctx .getChild (2 )))
206
- return sp .Pow (operand1 , operand2 )
228
+ return sp .Pow (operand1 , operand2 , evaluate = self . evaluate )
207
229
208
230
def visitUnaryExpr (
209
231
self , ctx : PetabMathExprParser .UnaryExprContext
@@ -240,7 +262,7 @@ def visitComparisonExpr(
240
262
if op in ops :
241
263
lhs = bool2num (lhs )
242
264
rhs = bool2num (rhs )
243
- return ops [op ](lhs , rhs )
265
+ return ops [op ](lhs , rhs , evaluate = self . evaluate )
244
266
245
267
raise AssertionError (f"Unexpected operator: { op } " )
246
268
@@ -301,4 +323,6 @@ def num2bool(x: sp.Basic | sp.Expr) -> sp.Basic | sp.Expr:
301
323
return sp .false
302
324
if x .is_zero is False :
303
325
return sp .true
326
+ if isinstance (x , Boolean ):
327
+ return x
304
328
return sp .Piecewise ((True , x != 0.0 ), (False , True ))
0 commit comments