
Python Frontend (maybe bug, maybe WAI): Resolution of free symbols may break in a couple of ways #1791

Open
pratyai opened this issue Nov 23, 2024 · 0 comments
Labels
bug Something isn't working frontend


pratyai commented Nov 23, 2024

In the following program, foo() uses the symbols N and M, but only ever in the expression N-M. If I call bar.to_sdfg(), which calls foo(a), the size of a gives one constraint on N-M, but the other N-M in the function body contributes two free symbols. In the end, this resolves to a single free symbol (which happens to be M). That is, I cannot just call foo(a); I must also pass M (e.g., foo(a, M=M)).
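To see why only one symbol can be inferred, here is a minimal SymPy sketch. It assumes SymPy is installed (DaCe builds its symbolic expressions on it), and `n` is a hypothetical stand-in for the concrete length of `a`. The call site supplies a single equation, `N - M == n`, and one equation cannot pin down two unknowns: solving for either symbol leaves the other free. This only illustrates the algebra; it is not how the DaCe frontend actually performs the resolution.

```python
import sympy as sp

# N and M mirror the symbols in the issue; n is a hypothetical stand-in
# for the concrete length of the array argument `a`.
N, M, n = sp.symbols("N M n")

# The only information the call site provides: len(a) == N - M.
constraint = sp.Eq(N - M, n)

# Solving for N expresses it in terms of M, so M stays free ...
sol_for_N = sp.solve(constraint, N)

# ... and solving for M expresses it in terms of N, so N stays free.
sol_for_M = sp.solve(constraint, M)

print(sol_for_N, sol_for_N[0].free_symbols)
print(sol_for_M, sol_for_M[0].free_symbols)
```

Which symbol "survives" therefore depends only on which one the solver chooses to eliminate.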

This may be intended behaviour (e.g., SymPy cannot solve arbitrarily complex expressions). However, I also cannot call foo(a, N=N, M=M), because N is no longer a free symbol (even though N-M matches the size constraint). So I cannot look at the definition of foo() and reason "since there are two free symbols, I'll just pass both of them". I have to inspect the generated SDFG of foo() to find out which of the two original free symbols survived.
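Here is a plain SymPy sketch of the elimination step. It assumes SymPy is installed, uses `n` as a hypothetical symbol for the length of `a`, and solves for `N` rather than `M` only to mirror the outcome described above. Once the solution for `N` is substituted, `N` no longer occurs in the result, which is consistent with `N=N` being rejected as an argument. In pure SymPy the loop bound `N - M` then cancels all the way down to `n`. According to the report, the frontend keeps `M` as a free symbol instead; this sketch does not model the frontend's actual code path.

```python
import sympy as sp

N, M, n = sp.symbols("N M n")

# Assume the size constraint len(a) == N - M has been solved for N.
binding = {N: M + n}

# Before substitution, the loop bound in foo() mentions both symbols.
loop_bound = N - M

# After substitution, N is gone; SymPy also cancels M automatically,
# so the bound reduces to the known length n.
resolved = loop_bound.subs(binding)

print(loop_bound.free_symbols, "->", resolved.free_symbols)
```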

This is a bit awkward, but even that could be considered acceptable (e.g., one may need a few attempts, but in the end it works). However, as bar_2() shows, foo(a, M=M) does not work either, because a __SOLVE_M symbol is somehow left in the SDFG. I suspect this part is a real bug that cannot be dismissed, even with the justifications above.
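The KeyError recorded for bar_2 in the test below is raised by the line `scalar_args.update({k: dt.Scalar(self.symbols[k]) for k in free_symbols if not k.startswith('__dace')})`. The following stdlib-only sketch models that filtering pattern. It uses a hypothetical `known_symbols` dict in place of `self.symbols` and a plain tuple in place of `dt.Scalar`. The `startswith('__dace')` check skips only internal names with the `__dace` prefix, so a leftover name such as `__SOLVE_M` passes the filter, and looking it up fails. This is a simplified model of the failure, not DaCe's code. The extra `__SOLVE_` prefix at the end is just one hypothetical mitigation.

```python
def collect_scalar_args(free_symbols, known_symbols, skip_prefixes=("__dace",)):
    """Map each free symbol name to a (hypothetical) scalar descriptor.

    Names starting with any of skip_prefixes are treated as internal and
    skipped; known_symbols stands in for the frontend's self.symbols.
    """
    return {
        k: ("Scalar", known_symbols[k])  # stand-in for dt.Scalar(self.symbols[k])
        for k in free_symbols
        if not k.startswith(skip_prefixes)
    }


known_symbols = {"M": "int32"}  # hypothetical: only user symbols are registered

# The leftover solver symbol passes the '__dace' filter and triggers a KeyError.
failed_name = None
try:
    collect_scalar_args({"M", "__SOLVE_M"}, known_symbols)
except KeyError as err:
    failed_name = err.args[0]
    print("KeyError:", err)  # KeyError: '__SOLVE_M'

# Hypothetical mitigation: also treat '__SOLVE_' names as internal.
args = collect_scalar_args(
    {"M", "__SOLVE_M"}, known_symbols, skip_prefixes=("__dace", "__SOLVE_")
)
print(args)
```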

Finally, it is of course possible to replace N-M with a new symbol N_minus_M (as in foo_alt() below). Then I can even call foo_alt(a) with just the array, because the size constraint alone fully resolves everything.
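Viewed purely as algebra in SymPy (assuming SymPy is installed; the length 10 is an arbitrary example value), this shows why the N_minus_M workaround resolves everything. When the shape is a single symbol, the size constraint determines it uniquely and leaves no free symbol. With N - M and the same length, one symbol is still undetermined.

```python
import sympy as sp

N, M, N_minus_M = sp.symbols("N M N_minus_M")
length = 10  # example length of the array argument `a`

# Single-symbol shape: one equation, one unknown, fully determined.
alt = sp.solve(sp.Eq(N_minus_M, length), N_minus_M)

# Two-symbol shape: one equation, two unknowns; the solution for N still involves M.
orig = sp.solve(sp.Eq(N - M, length), N)

print(alt, alt[0].free_symbols)
print(orig, orig[0].free_symbols)
```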

```python
import dace

N = dace.symbol('N')
M = dace.symbol('M')
N_minus_M = dace.symbol('N_minus_M')


@dace.program
def foo(a: dace.float64[N - M]):
    for i, in dace.map[0:N-M]:
        a[i] = 1


@dace.program
def foo_alt(a: dace.float64[N_minus_M]):
    for i, in dace.map[0:N_minus_M]:
        a[i] = 1


@dace.program
def bar(a: dace.float64[N - M]):
    foo(a)


@dace.program
def bar_2(a: dace.float64[N - M]):
    foo(a, M=M)


@dace.program
def bar_alt(a: dace.float64[N - M]):
    foo_alt(a)


def test_foo_bar():
    g = bar_alt.to_sdfg(simplify=False)
    # OK
    g.validate()
    g.compile()

    g = bar.to_sdfg(simplify=False)
    # raise DaceSyntaxError(
    #                     self, node, 'Argument number mismatch in'
    #                     ' call to "%s" (expected %d,'
    #                     ' got %d). Missing arguments: %s' % (funcname, len(required_args), len(args), missing))
    # E               dace.frontend.python.common.DaceSyntaxError: Argument number mismatch in call to "radiation_aerosol_optics_foo" (expected 2, got 1). Missing arguments: {'M'}
    g.validate()
    g.compile()

    g = bar_2.to_sdfg(simplify=False)
    # scalar_args.update({k: dt.Scalar(self.symbols[k]) for k in free_symbols if not k.startswith('__dace')})
    # E       KeyError: '__SOLVE_M'
    g.validate()
    g.compile()
```

pratyai added the bug (Something isn't working) and frontend labels on Nov 23, 2024