-
Notifications
You must be signed in to change notification settings - Fork 51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next]: Infer as_fieldop type without domain #1853
base: main
Are you sure you want to change the base?
Conversation
…ot yet working for tuples
@@ -492,10 +492,7 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: | |||
# probably just change the behaviour of the lowering. Until then we do this more | |||
# complicated comparison. | |||
if isinstance(target_type, ts.FieldType) and isinstance(expr_type, ts.FieldType): | |||
assert ( | |||
set(expr_type.dims).issubset(set(target_type.dims)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the subset requirement is not necessary anymore, e.g. in the case
itir.SetAt(
domain=unstructured_domain,
expr=im.as_fieldop(
im.lambda_("it")(im.reduce("plus", 0.0)(im.deref("it"))),
unstructured_domain,
)(im.ref("inp")),
target=im.ref("out"),
)
with
im.sym("inp", float_vertex_v2e_k_field)
and im.sym("out", float_vertex_k_field)
the expr_type is
[Dimension(value='Vertex', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='KDim', kind=<DimensionKind.VERTICAL: 'vertical'>), Dimension(value='V2E', kind=<DimensionKind.LOCAL: 'local'>)]
and the target_type is
[Dimension(value='Vertex', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='KDim', kind=<DimensionKind.VERTICAL: 'vertical'>)]
which is correct in my opinion.
cf. test_fencil_with_nb_field_input
)(im.ref("inp1", float_i_field), im.ref("inp2", float_i_field)), | ||
ts.TupleType(types=[float_i_field, float_i_field]), | ||
)(im.ref("inp1", float_i_field), im.ref("inp2", float_j_field)), | ||
ts.TupleType(types=[float_ij_field, float_ij_field]), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is ts.TupleType(types=[float_ij_field, float_ij_field]
correct here?
Co-authored-by: Till Ehrengruber <[email protected]>
Co-authored-by: Till Ehrengruber <[email protected]>
Co-authored-by: Till Ehrengruber <[email protected]>
Previously, a broadcast generated an `as_fieldop`. In [PR#1853](#1853), the domain restriction returns a smaller domain for the `as_fieldop`. In order to avoid this, we introduce the broadcast builtin here and transform it into an `im.as_fieldop("deref", ...)` call after the domain inference. --------- Co-authored-by: Till Ehrengruber <[email protected]>
Zero-dimensional fields were supported in the dace backend as program arguments, but not as temporary fields. This PR provides support for zero-dimensional temporary fields, which are computed by `as_fieldop` expressions with empty domain. This feature is enabled by PR #1853. Note that the current representation in the SDFG as scalars does not allow to return a zero-dimensional fields. However, the feature of zero-dimensional fields as program output is not required.
I have fixed the lowering of empty domains in dace backend, so the tests should now pass after you rebase. |
Thanks @edopao! Now the respective tests are passing. |
This extends the GTIR type inference to infer the type of as_fieldop calls without a domain.
TODOs:
test_fieldop_from_scan
:__field_operator_simple_scan_operator(out, __out_0_range) { out @ c⟨ KDimᵥ: [__out_0_range[0], __out_0_range[1][ ⟩ ← as_fieldop(scan(λ(carry) → carry + 1.0, True, 1.0), cartesian_domain())();
has an empty domain. A special case is implemented for that, which ensures, that the vertical dimensions are kept.
isinstance(subexpr.type, ts.TypeSpec)
becausesubexpr.type = None
when running it or debugging without breakpoints. There is however no error when setting a breakpoint in type_synthesizer.py:applied_as_fieldop.tuple_get
It passes when removing Vertex from the position_dims and the defined_dims, i.e.,