
Python Frontend (maybe bug, maybe WAI): Resolution of free symbols may break in a couple of ways #1791

@pratyai

Description


In the following program, foo() uses symbols N and M, but always as the expression N-M. If I run bar.to_sdfg(), which calls foo(a), the size of a gives one constraint on N-M, but the N-M in the function body still involves two free symbols. The size constraint eliminates only one of them, so one free symbol remains (which happens to be M). That is, I cannot call just foo(a); I must also pass M (e.g., foo(a, M=M)).
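A minimal sketch of why one size constraint cannot pin down both symbols, written in plain sympy rather than DaCe (this is an illustration, not the frontend's actual resolution code). The array length of 10 is an example value. Solving N - M == 10 for N leaves an answer that still depends on M, which is exactly the symbol that survived in this report.

```python
import sympy

# Integer symbols standing in for the DaCe symbols N and M
N, M = sympy.symbols('N M', integer=True)
size = 10  # example length of the array `a` passed to foo

# The annotation dace.float64[N - M] yields a single equation: N - M == size
solutions = sympy.solve(sympy.Eq(N - M, size), N)
print(solutions)  # N is only known in terms of M

# M stays free, so a value must come from somewhere else (cf. foo(a, M=M))
n_value = solutions[0].subs(M, 3)
print(n_value)
```

One equation in two unknowns can eliminate only one of them; the other has to be supplied by the caller.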

This may be considered intended behaviour (e.g., sympy cannot solve arbitrarily complex expressions). However, I cannot call foo(a, N=N, M=M) either, because N is no longer a free symbol (even though N-M matches the size constraint). So I cannot look at the definition of foo() and say "since there are two free symbols, I'll just pass both of them". I have to look at the generated SDFG of foo() to find out which of the two original free symbols survived.
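The following sympy sketch models the choice involved: either N or M can be eliminated from the size equation, and whichever one is eliminated stops being a free symbol of the callee, so the set of acceptable keyword arguments depends on that choice. The helper `remaining_free_symbols` and the symbol `s` (the actual array length) are hypothetical names for illustration; this does not reproduce how DaCe picks the symbol to eliminate.

```python
import sympy

N, M, s = sympy.symbols('N M s', integer=True)
size_expr = N - M  # the shape expression from foo's annotation


def remaining_free_symbols(expr, eliminated):
    """Hypothetical helper: symbols still free after solving expr == s for `eliminated`."""
    if not sympy.solve(sympy.Eq(expr, s), eliminated):
        raise ValueError(f'cannot solve for {eliminated}')
    return expr.free_symbols - {eliminated}


# Each choice leaves a different symbol that the caller must pass explicitly
for sym in sorted(size_expr.free_symbols, key=str):
    print(sym, '->', remaining_free_symbols(size_expr, sym))
```

Both choices are valid, which is why the signature of foo() alone does not tell the caller which keyword argument is expected.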

This is a bit awkward, but even that could be considered acceptable (e.g., one needs to try a few times, but in the end it works). But then, as I show in bar_2(), foo(a, M=M) does not work either, because somehow a __SOLVE_M symbol is left in the SDFG. I suspect that this part is a real bug that cannot be ignored, even after the earlier justifications.
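The KeyError from bar_2 comes from the comprehension quoted in the traceback, which only skips names starting with `__dace`. The sketch below is a simplified model of that line using plain dicts and strings instead of DaCe symbols and data descriptors (an assumption, not DaCe's real objects), with `collect_scalar_args` as a hypothetical name. It shows how a leftover `__SOLVE_M` placeholder produces the same failure.

```python
# Simplified model: strings stand in for DaCe symbol types (assumption)
symbols = {'N': 'int64', 'M': 'int64'}  # symbols known at the call site
free_symbols = {'M', '__SOLVE_M'}       # placeholder left over from size solving


def collect_scalar_args(symbols, free_symbols, skip_prefixes=('__dace',)):
    # Mirrors the filter in the traceback: only names with a skipped prefix are ignored
    return {k: symbols[k] for k in free_symbols if not k.startswith(skip_prefixes)}


try:
    collect_scalar_args(symbols, free_symbols)
except KeyError as err:
    print('KeyError:', err)  # same failure shape as in the issue

# Hypothetical workaround: also skip the solver's placeholder prefix
args = collect_scalar_args(symbols, free_symbols, skip_prefixes=('__dace', '__SOLVE_'))
print(args)
```

Skipping the prefix only hides the symptom; presumably the placeholder should have been substituted away before this point, so this is not offered as the fix.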

Finally, of course it's possible to replace N-M with a new symbol N_minus_M (see foo_alt() below), and then I can even call foo_alt(a) with no extra arguments, because the size constraint alone fully determines N_minus_M.
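In sympy terms, the workaround turns the size into an expression with a single unknown. This sketch (plain sympy, not DaCe; the length 10 is an example value) substitutes N = N_minus_M + M and checks that the array length then determines everything.

```python
import sympy

N, M = sympy.symbols('N M', integer=True)
N_minus_M = sympy.Symbol('N_minus_M', integer=True)

size_expr = N - M
# Rewrite the size in terms of the combined symbol
rewritten = size_expr.subs(N, N_minus_M + M)
print(rewritten)  # the M terms cancel

# One equation, one unknown: the array length pins it down completely
solutions = sympy.solve(sympy.Eq(rewritten, 10), N_minus_M)
print(solutions)

# No other symbol is left for the caller to pass
print(rewritten.free_symbols - {N_minus_M})
```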

```python
import dace

N = dace.symbol('N')
M = dace.symbol('M')
N_minus_M = dace.symbol('N_minus_M')


@dace.program
def foo(a: dace.float64[N - M]):
    # N and M only ever appear together as N - M
    for i, in dace.map[0:N - M]:
        a[i] = 1


@dace.program
def foo_alt(a: dace.float64[N_minus_M]):
    # Same as foo, but with a single combined symbol
    for i, in dace.map[0:N_minus_M]:
        a[i] = 1


@dace.program
def bar(a: dace.float64[N - M]):
    foo(a)


@dace.program
def bar_2(a: dace.float64[N - M]):
    foo(a, M=M)


@dace.program
def bar_alt(a: dace.float64[N - M]):
    foo_alt(a)


def test_foo_bar():
    g = bar_alt.to_sdfg(simplify=False)
    # OK
    g.validate()
    g.compile()

    g = bar.to_sdfg(simplify=False)
    # raise DaceSyntaxError(
    #                     self, node, 'Argument number mismatch in'
    #                     ' call to "%s" (expected %d,'
    #                     ' got %d). Missing arguments: %s' % (funcname, len(required_args), len(args), missing))
    # E               dace.frontend.python.common.DaceSyntaxError: Argument number mismatch in call to "radiation_aerosol_optics_foo" (expected 2, got 1). Missing arguments: {'M'}
    g.validate()
    g.compile()

    g = bar_2.to_sdfg(simplify=False)
    # scalar_args.update({k: dt.Scalar(self.symbols[k]) for k in free_symbols if not k.startswith('__dace')})
    # E       KeyError: '__SOLVE_M'
    g.validate()
    g.compile()


if __name__ == '__main__':
    test_foo_bar()
```

Labels: bug (Something isn't working), frontend
