Skip to content

Commit 9cd9879

Browse files
authored
Remove usage of deprecated API dace.Memlet.simple (#1425)
Replace deprecated constructor API dace.Memlet.simple() with dace.Memlet()
1 parent 70f0f88 commit 9cd9879

File tree

3 files changed

+18
-18
lines changed

3 files changed

+18
-18
lines changed

src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def visit_StencilClosure(
338338
out_name, debuginfo=closure_sdfg.debuginfo
339339
)
340340
value = ValueExpr(access, dtype)
341-
memlet = dace.Memlet.simple(out_name, "0")
341+
memlet = dace.Memlet(data=out_name, subset="0")
342342
closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet)
343343
program_arg_syms[name] = value
344344
else:
@@ -427,10 +427,10 @@ def visit_StencilClosure(
427427
edge.src_conn,
428428
transient_access,
429429
None,
430-
dace.Memlet.simple(memlet.data, output_subset, debuginfo=nsdfg.debuginfo),
430+
dace.Memlet(data=memlet.data, subset=output_subset, debuginfo=nsdfg.debuginfo),
431431
)
432-
inner_memlet = dace.Memlet.simple(
433-
memlet.data, output_subset, other_subset_str=memlet.subset
432+
inner_memlet = dace.Memlet(
433+
data=memlet.data, subset=output_subset, other_subset=memlet.subset
434434
)
435435
closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet)
436436
closure_state.remove_edge(edge)
@@ -523,7 +523,7 @@ def _visit_scan_stencil_closure(
523523
"__result",
524524
start_state.add_access(scan_carry_name, debuginfo=scan_sdfg.debuginfo),
525525
None,
526-
dace.Memlet.simple(scan_carry_name, "0"),
526+
dace.Memlet(data=scan_carry_name, subset="0"),
527527
)
528528

529529
# add storage to scan SDFG for inputs
@@ -603,13 +603,13 @@ def _visit_scan_stencil_closure(
603603
connector,
604604
compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo),
605605
None,
606-
dace.Memlet.simple(name, scan_loop_var),
606+
dace.Memlet(data=name, subset=scan_loop_var),
607607
)
608608

609609
update_state.add_nedge(
610610
update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo),
611611
update_state.add_access(scan_carry_name, debuginfo=lambda_context.body.debuginfo),
612-
dace.Memlet.simple(output_names[0], scan_loop_var, other_subset_str="0"),
612+
dace.Memlet(data=output_name, subset=scan_loop_var, other_subset="0"),
613613
)
614614

615615
return scan_sdfg, map_ranges, scan_dim_index

src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def builtin_neighbors(
260260
iterator.indices[shifted_dim],
261261
me,
262262
shift_tasklet,
263-
memlet=dace.Memlet.simple(iterator.indices[shifted_dim].data, "0", debuginfo=di),
263+
memlet=dace.Memlet(data=iterator.indices[shifted_dim].data, subset="0", debuginfo=di),
264264
dst_conn="__idx",
265265
)
266266
state.add_edge(shift_tasklet, "__result", data_access_tasklet, field_index, dace.Memlet())
@@ -280,7 +280,7 @@ def builtin_neighbors(
280280
data_access_tasklet,
281281
mx,
282282
result_access,
283-
memlet=dace.Memlet.simple(result_name, neighbor_index, debuginfo=di),
283+
memlet=dace.Memlet(data=result_name, subset=neighbor_index, debuginfo=di),
284284
src_conn="__result",
285285
)
286286

@@ -315,7 +315,7 @@ def builtin_can_deref(
315315
"_out",
316316
result_node,
317317
None,
318-
dace.Memlet.simple(result_name, "0", debuginfo=di),
318+
dace.Memlet(data=result_name, subset="0", debuginfo=di),
319319
)
320320
return [ValueExpr(result_node, dace.dtypes.bool)]
321321

@@ -385,7 +385,7 @@ def builtin_list_get(
385385
transformer.context.state.add_nedge(
386386
args[1].value,
387387
result_node,
388-
dace.Memlet.simple(args[1].value.data, index_value),
388+
dace.Memlet(data=args[1].value.data, subset=index_value),
389389
)
390390
return [ValueExpr(result_node, args[1].dtype)]
391391

@@ -634,7 +634,7 @@ def visit_Lambda(
634634
lambda_state.add_nedge(
635635
expr.value,
636636
result_access,
637-
dace.Memlet.simple(result_access.data, "0"),
637+
dace.Memlet(data=result_access.data, subset="0"),
638638
)
639639
result = ValueExpr(value=result_access, dtype=expr.dtype)
640640
else:
@@ -801,7 +801,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
801801
iterator.indices[dim] for dim in sorted_dims if dim in iterator.indices
802802
]
803803
deref_memlets = [dace.Memlet.from_array(iterator.field.data, field_array)] + [
804-
dace.Memlet.simple(node.data, "0") for node in deref_nodes[1:]
804+
dace.Memlet(data=node.data, subset="0") for node in deref_nodes[1:]
805805
]
806806

807807
# we create a mapped tasklet for array slicing
@@ -927,7 +927,7 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]:
927927
"get_offset", {}, {"__out"}, f"__out = {offset}", debuginfo=di
928928
)
929929
self.context.state.add_edge(
930-
tasklet_node, "__out", offset_node, None, dace.Memlet.simple(offset_var, "0")
930+
tasklet_node, "__out", offset_node, None, dace.Memlet(data=offset_var, subset="0")
931931
)
932932
return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)]
933933

@@ -1036,7 +1036,7 @@ def _visit_reduce(self, node: itir.FunCall):
10361036
dace.Memlet.from_array(reduce_input_node.data, reduce_input_desc),
10371037
)
10381038
self.context.state.add_nedge(
1039-
reduce_node, result_access, dace.Memlet.simple(result_name, "0")
1039+
reduce_node, result_access, dace.Memlet(data=result_name, subset="0")
10401040
)
10411041

10421042
# we apply map fusion only to the nested-SDFG which is generated for the reduction operator
@@ -1108,7 +1108,7 @@ def add_expr_tasklet(
11081108
)
11091109
self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet)
11101110

1111-
memlet = dace.Memlet.simple(result_access.data, "0", debuginfo=di)
1111+
memlet = dace.Memlet(data=result_access.data, subset="0", debuginfo=di)
11121112
self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet)
11131113

11141114
return [ValueExpr(result_access, result_type)]
@@ -1140,7 +1140,7 @@ def closure_to_tasklet_sdfg(
11401140
)
11411141
access = state.add_access(name, debuginfo=body.debuginfo)
11421142
idx_accesses[dim] = access
1143-
state.add_edge(tasklet, "value", access, None, dace.Memlet.simple(name, "0"))
1143+
state.add_edge(tasklet, "value", access, None, dace.Memlet(data=name, subset="0"))
11441144
for name, ty in inputs:
11451145
if isinstance(ty, ts.FieldType):
11461146
ndim = len(ty.dims)

src/gt4py/next/program_processors/runners/dace_iterator/utility.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def create_memlet_full(source_identifier: str, source_array: dace.data.Array):
6969

7070
def create_memlet_at(source_identifier: str, index: tuple[str, ...]):
7171
subset = ", ".join(index)
72-
return dace.Memlet.simple(source_identifier, subset)
72+
return dace.Memlet(data=source_identifier, subset=subset)
7373

7474

7575
def get_sorted_dims(dims: Sequence[Dimension]) -> Sequence[tuple[int, Dimension]]:

0 commit comments

Comments
 (0)