Description
In the following program, `foo()` uses the symbols `N` and `M`, but only ever through the expression `N-M`. When I try `bar.to_sdfg()`, which calls `foo(a)`, the size of `a` gives one constraint on `N-M`, but the other `N-M` in the function body introduces two free symbols. In the end, exactly one free symbol remains (which happens to be `M`). That is, I cannot call just `foo(a)`; I must also pass `M` (e.g., `foo(a, M=M)`).
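The counting argument above can be reproduced with plain sympy. This is a minimal sketch, assuming that the array argument contributes exactly one equation `N - M == s`, where `s` is a stand-in for the concrete length of `a` at the call site; it does not call DaCe itself.

```python
import sympy

# N and M are the symbols from foo's signature; s is a hypothetical
# stand-in for the concrete length of the array `a` passed by the caller.
N, M, s = sympy.symbols("N M s", integer=True)

# The parameter a: float64[N - M] yields a single equation N - M == s.
size_constraint = sympy.Eq(N - M, s)

# One equation can pin down only one unknown: solving for N still
# leaves M in the result, so M must come from somewhere else.
solution_for_N = sympy.solve(size_constraint, N)
print(solution_for_N)                  # [M + s]
print(solution_for_N[0].free_symbols)  # the set {M, s}
```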
This may be considered intended behaviour (e.g., sympy cannot solve arbitrarily complex expressions). However, I cannot call `foo(a, N=N, M=M)` either, because `N` is not a free symbol (even though `N-M` matches the size constraint). So I cannot look at the definition of `foo()` and say "since there are two free symbols, I'll just pass both of them". Instead, I have to inspect the generated SDFG of `foo()` to find out which of the two original free symbols survived.
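Passing both `N` and `M` together with `a` over-determines the system, but an over-determined system can still be checked for consistency instead of being rejected. The helper below, `is_consistent`, is hypothetical and not part of DaCe; it is a sketch of the check one might expect, assuming the only constraint is `len(a) == N - M`.

```python
import sympy

N, M = sympy.symbols("N M", integer=True)
SIZE_EXPR = N - M  # shape expression of `a` in foo's signature

def is_consistent(array_len, n_val, m_val):
    """Hypothetical check: do explicit N and M agree with len(a) == N - M?"""
    # Substitute the caller-supplied values and compare with the real length.
    residual = SIZE_EXPR.subs({N: n_val, M: m_val}) - array_len
    return residual == 0

print(is_consistent(10, 15, 5))  # True: 15 - 5 matches the length 10
print(is_consistent(10, 16, 5))  # False: 16 - 5 does not match
```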
This is a bit awkward, but even that could be considered acceptable (one may need a few attempts, but in the end it works). But then, as I show in `bar_2()`, `foo(a, M=M)` does not work either, because a `__SOLVE_M` symbol is somehow left in the SDFG. I suspect this part is a real bug that cannot be ignored, even given the justifications above.
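One plausible way for a name like `__SOLVE_M` to end up among the free symbols is a temporary placeholder that is never mapped back. The sketch below is a toy model under that assumption only (it is not DaCe's actual code); the helper `leaked_placeholders` is hypothetical and shows the kind of sanity check that would catch the leak before the `KeyError`.

```python
import sympy

# Toy model, NOT DaCe internals: M is renamed to a placeholder while solving.
N, M, s = sympy.symbols("N M s", integer=True)
placeholder = sympy.Symbol("__SOLVE_M", integer=True)

def leaked_placeholders(exprs):
    """Names of any __SOLVE_* placeholder symbols still present in exprs."""
    return {str(sym) for e in exprs for sym in e.free_symbols
            if str(sym).startswith("__SOLVE_")}

# Solve the size equation for N with M hidden behind the placeholder.
solved_N = sympy.solve(sympy.Eq(N - placeholder, s), N)[0]

# Forgetting to substitute the placeholder back leaves it in the result.
print(leaked_placeholders([solved_N]))  # {'__SOLVE_M'}

# Substituting the caller-supplied M back removes it.
fixed = solved_N.subs(placeholder, M)
print(leaked_placeholders([fixed]))     # set()
```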
Finally, it is of course possible to replace `N-M` with a new symbol `N_minus_M` (as in `foo_alt()`), and then I can even call just `foo(a)`, because the size constraint alone fully resolves everything.
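For contrast with the two-symbol case, the same sympy sketch with a single symbol shows why the workaround needs no extra arguments: one equation in one unknown is fully determined. As before, `s` is an assumed stand-in for the length of `a`.

```python
import sympy

# Single symbol from foo_alt's signature; s stands in for len(a).
N_minus_M, s = sympy.symbols("N_minus_M s", integer=True)

# One equation, one unknown: the solution depends only on the array length.
solution = sympy.solve(sympy.Eq(N_minus_M, s), N_minus_M)
print(solution)  # [s]
```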
```python
import dace

N = dace.symbol('N')
M = dace.symbol('M')
N_minus_M = dace.symbol('N_minus_M')


@dace.program
def foo(a: dace.float64[N - M]):
    for i, in dace.map[0:N-M]:
        a[i] = 1


@dace.program
def foo_alt(a: dace.float64[N_minus_M]):
    for i, in dace.map[0:N_minus_M]:
        a[i] = 1


@dace.program
def bar(a: dace.float64[N - M]):
    foo(a)


@dace.program
def bar_2(a: dace.float64[N - M]):
    foo(a, M=M)


@dace.program
def bar_alt(a: dace.float64[N - M]):
    foo_alt(a)


def test_foo_bar():
    g = bar_alt.to_sdfg(simplify=False)
    # OK
    g.validate()
    g.compile()

    g = bar.to_sdfg(simplify=False)
    # raise DaceSyntaxError(
    #     self, node, 'Argument number mismatch in'
    #     ' call to "%s" (expected %d,'
    #     ' got %d). Missing arguments: %s' % (funcname, len(required_args), len(args), missing))
    # E dace.frontend.python.common.DaceSyntaxError: Argument number mismatch in call to "radiation_aerosol_optics_foo" (expected 2, got 1). Missing arguments: {'M'}
    g.validate()
    g.compile()

    g = bar_2.to_sdfg(simplify=False)
    # scalar_args.update({k: dt.Scalar(self.symbols[k]) for k in free_symbols if not k.startswith('__dace')})
    # E KeyError: '__SOLVE_M'
    g.validate()
    g.compile()
```