@@ -100,21 +100,22 @@ def visit_StencilClosure(self, node: ir.StencilClosure):
100
100
ValidateRecordedShiftsAnnex ().visit (node )
101
101
return self .generic_visit (node )
102
102
103
- def visit_FunCall (self , node : ir .FunCall ):
103
+ def visit_FunCall (self , node : ir .FunCall , ** kwargs ):
104
104
old_node = node
105
- node = self .generic_visit (node )
105
+ node = self .generic_visit (node ,
106
+ ignore_recorded_shifts_missing = (kwargs .get ("ignore_recorded_shifts_missing" , False ) or (hasattr (node .annex , "recorded_shifts" ) and len (node .annex .recorded_shifts ) == 0 )))
106
107
#ValidateRecordedShiftsAnnex().visit(node)
107
108
if isinstance (node .fun , ir .Lambda ):
108
109
eligible_params = [False ] * len (node .fun .params )
109
110
110
111
# force inline lift args derefed at at most a single position
111
112
new_args = []
112
113
bound_scalars = {}
113
- # TODO: what is node.fun is not a lambda? e.g. directly deref?
114
+ # TODO: what if node.fun is not a lambda? e.g. directly deref?
114
115
for i , (param , arg ) in enumerate (zip (node .fun .params , node .args )):
115
- if common_pattern_matcher .is_applied_lift (arg ) and not hasattr (param .annex , "recorded_shifts" ):
116
+ if not kwargs . get ( "ignore_recorded_shifts_missing" , False ) and common_pattern_matcher .is_applied_lift (arg ) and not hasattr (param .annex , "recorded_shifts" ):
116
117
breakpoint ()
117
- if common_pattern_matcher .is_applied_lift (arg ) and param .annex .recorded_shifts in [set (), {()}]:
118
+ if not kwargs . get ( "ignore_recorded_shifts_missing" , False ) and common_pattern_matcher .is_applied_lift (arg ) and param .annex .recorded_shifts in [set (), {()}]:
118
119
eligible_params [i ] = True
119
120
global unique_id
120
121
bound_arg_name = f"__wtf{ unique_id } "
@@ -123,6 +124,7 @@ def visit_FunCall(self, node: ir.FunCall):
123
124
capture_lift .annex .recorded_shifts = param .annex .recorded_shifts
124
125
new_args .append (capture_lift )
125
126
bound_scalars [bound_arg_name ] = InlineLifts (flags = InlineLifts .Flag .INLINE_TRIVIAL_DEREF_LIFT ).visit (im .deref (arg ), recurse = False )
127
+ ValidateRecordedShiftsAnnex ().visit (bound_scalars [bound_arg_name ])
126
128
else :
127
129
new_args .append (arg )
128
130
@@ -198,7 +200,7 @@ def apply_common_transforms(
198
200
] = None ,
199
201
symbolic_domain_sizes : Optional [dict [str , str ]] = None ,
200
202
):
201
- lift_mode = LiftMode .FORCE_TEMPORARIES
203
+ # lift_mode = LiftMode.FORCE_TEMPORARIES
202
204
203
205
if lift_mode is None :
204
206
lift_mode = LiftMode .FORCE_INLINE
@@ -231,7 +233,7 @@ def apply_common_transforms(
231
233
inlined = InlineLambdas .apply (
232
234
inlined ,
233
235
opcount_preserving = True ,
234
- force_inline_lift_args = True ,
236
+ force_inline_lift_args = True , # todo: this is still needed as we can not extract a lift from a conditional
235
237
)
236
238
if inlined == ir :
237
239
break
0 commit comments