Skip to content

Commit 9270262

Browse files
committed
Small fixes. All Icon4Py tests except for as_offsets pass with gtfn_cpu & temporaries
1 parent 38a3fe9 commit 9270262

File tree

5 files changed

+39
-11
lines changed

5 files changed

+39
-11
lines changed

src/gt4py/next/ffront/foast_to_itir.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,17 @@ def visit_BinOp(self, node: foast.BinOp, **kwargs) -> itir.FunCall:
342342
return self._map(node.op.value, node.left, node.right)
343343

344344
def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs) -> itir.FunCall:
345-
return im.if_(im.deref(self.visit(node.condition, **kwargs)), self.visit(node.true_expr, **kwargs), self.visit(node.false_expr, **kwargs))
345+
op = "if_"
346+
args = (node.condition, node.true_expr, node.false_expr)
347+
lowered_args = [to_iterator_of_tuples(self.visit(arg, **kwargs), arg.type) for arg in args]
348+
if any(type_info.contains_local_field(arg.type) for arg in args):
349+
lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)]
350+
op = im.call("map_")(op)
351+
352+
return to_tuples_of_iterator(im.promote_to_lifted_stencil(im.call(op))(*lowered_args), node.type)
353+
354+
# TODO: iterator of tuples?
355+
#return im.if_(im.deref(self.visit(node.condition, **kwargs)), self.visit(node.true_expr, **kwargs), self.visit(node.false_expr, **kwargs))
346356

347357
def visit_Compare(self, node: foast.Compare, **kwargs) -> itir.FunCall:
348358
return self._map(node.op.value, node.left, node.right)

src/gt4py/next/iterator/transforms/inline_lifts.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,11 @@ def visit_FunCall(
218218
):
219219
symtable = kwargs["symtable"]
220220

221+
ignore_recorded_shifts_missing = (kwargs.get("ignore_recorded_shifts_missing", False) or (
222+
hasattr(node.annex, "recorded_shifts") and len(
223+
node.annex.recorded_shifts) == 0))
224+
kwargs = {**kwargs, "ignore_recorded_shifts_missing": ignore_recorded_shifts_missing}
225+
221226
recorded_shifts_annex = getattr(node.annex, "recorded_shifts", None)
222227
old_node = node
223228
node = (
@@ -306,13 +311,13 @@ def visit_FunCall(
306311
and len(node.args) > 0
307312
and self.predicate(node, is_scan_pass_context)
308313
):
309-
if not hasattr(node.annex, "recorded_shifts"):
314+
if not ignore_recorded_shifts_missing and not hasattr(node.annex, "recorded_shifts"):
310315
breakpoint()
311316

312317
# if the lift is never derefed its params also don't have a recorded_shifts attr and the
313318
# following will fail. We don't care about such lifts anyway, as they disappear
314319
# later on.
315-
if len(node.annex.recorded_shifts) == 0:
320+
if ignore_recorded_shifts_missing or len(node.annex.recorded_shifts) == 0:
316321
return node
317322

318323
stencil = node.fun.args[0] # type: ignore[attr-defined] # node already asserted to be of type ir.FunCall

src/gt4py/next/iterator/transforms/pass_manager.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -100,21 +100,22 @@ def visit_StencilClosure(self, node: ir.StencilClosure):
100100
ValidateRecordedShiftsAnnex().visit(node)
101101
return self.generic_visit(node)
102102

103-
def visit_FunCall(self, node: ir.FunCall):
103+
def visit_FunCall(self, node: ir.FunCall, **kwargs):
104104
old_node = node
105-
node = self.generic_visit(node)
105+
node = self.generic_visit(node,
106+
ignore_recorded_shifts_missing=(kwargs.get("ignore_recorded_shifts_missing", False) or (hasattr(node.annex, "recorded_shifts") and len(node.annex.recorded_shifts) == 0)))
106107
#ValidateRecordedShiftsAnnex().visit(node)
107108
if isinstance(node.fun, ir.Lambda):
108109
eligible_params = [False] * len(node.fun.params)
109110

110111
# force inline lift args dereferenced at no more than a single position
111112
new_args = []
112113
bound_scalars = {}
113-
# TODO: what is node.fun is not a lambda? e.g. directly deref?
114+
# TODO: what if node.fun is not a lambda? e.g. directly deref?
114115
for i, (param, arg) in enumerate(zip(node.fun.params, node.args)):
115-
if common_pattern_matcher.is_applied_lift(arg) and not hasattr(param.annex, "recorded_shifts"):
116+
if not kwargs.get("ignore_recorded_shifts_missing", False) and common_pattern_matcher.is_applied_lift(arg) and not hasattr(param.annex, "recorded_shifts"):
116117
breakpoint()
117-
if common_pattern_matcher.is_applied_lift(arg) and param.annex.recorded_shifts in [set(), {()}]:
118+
if not kwargs.get("ignore_recorded_shifts_missing", False) and common_pattern_matcher.is_applied_lift(arg) and param.annex.recorded_shifts in [set(), {()}]:
118119
eligible_params[i] = True
119120
global unique_id
120121
bound_arg_name = f"__wtf{unique_id}"
@@ -123,6 +124,7 @@ def visit_FunCall(self, node: ir.FunCall):
123124
capture_lift.annex.recorded_shifts = param.annex.recorded_shifts
124125
new_args.append(capture_lift)
125126
bound_scalars[bound_arg_name] = InlineLifts(flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT).visit(im.deref(arg), recurse=False)
127+
ValidateRecordedShiftsAnnex().visit(bound_scalars[bound_arg_name])
126128
else:
127129
new_args.append(arg)
128130

@@ -198,7 +200,7 @@ def apply_common_transforms(
198200
] = None,
199201
symbolic_domain_sizes: Optional[dict[str, str]] = None,
200202
):
201-
lift_mode = LiftMode.FORCE_TEMPORARIES
203+
#lift_mode = LiftMode.FORCE_TEMPORARIES
202204

203205
if lift_mode is None:
204206
lift_mode = LiftMode.FORCE_INLINE
@@ -231,7 +233,7 @@ def apply_common_transforms(
231233
inlined = InlineLambdas.apply(
232234
inlined,
233235
opcount_preserving=True,
234-
force_inline_lift_args=True,
236+
force_inline_lift_args=True, # todo: this is still needed as we can not extract a lift from a conditional
235237
)
236238
if inlined == ir:
237239
break

src/gt4py/next/program_processors/codegens/gtfn/codegen.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str:
108108
case _:
109109
result = node.value
110110
# TODO: isn't this wrong, and shouldn't int32 also be cast to int32?
111-
if node.type in ["float64", "float32", "int32", "int64"]:
111+
if node.type in ["float64", "float32", "int32", "int64", "bool"]:
112112
result = f"({result})"
113113
elif node.type == "axis_literal":
114114
pass

src/gt4py/next/program_processors/runners/gtfn.py

+11
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,14 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int:
214214
executor=gtfn_gpu_cached_executor,
215215
allocator=next_allocators.StandardGPUFieldBufferAllocator(),
216216
)
217+
218+
run_gtfn_with_temporaries_cached_executor = otf_compile_executor.CachedOTFCompileExecutor(
219+
name="run_gtfn_with_temporaries_cached",
220+
otf_workflow=workflow.CachedStep(
221+
step=run_gtfn_with_temporaries.executor.otf_workflow, hash_function=compilation_hash
222+
),
223+
)
224+
run_gtfn_with_temporaries_cached = otf_compile_executor.OTFBackend(
225+
executor=run_gtfn_with_temporaries_cached_executor,
226+
allocator=next_allocators.StandardCPUFieldBufferAllocator(),
227+
)

0 commit comments

Comments
 (0)