Skip to content

Commit

Permalink
Merge pull request #1093 from spcl/users/lukas/loc-storage-fix
Browse files Browse the repository at this point in the history
Additional local storage tests
  • Loading branch information
tbennun authored Aug 23, 2022
2 parents 330861a + 2b6269f commit 5636931
Showing 1 changed file with 218 additions and 0 deletions.
218 changes: 218 additions & 0 deletions tests/transformations/local_storage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,223 @@
import dace
import numpy as np
from dace.transformation.dataflow import MapTiling, OutLocalStorage
from dace.transformation.dataflow.local_storage import InLocalStorage

import dace.transformation.helpers as xfh

N = dace.symbol('N')


@dace.program
def copy_sdfg(A: dace.float32[N, N], B: dace.float32[N, N]):
for i, j in dace.map[0:N, 0:N]:
with dace.tasklet:
a << A[i, j]
b >> B[i, j]
b = a


def find_map_entries(sdfg):
outer_map_entry = None
inner_map_entry = None
for node in sdfg.start_state.nodes():
if not isinstance(node, dace.nodes.MapEntry):
continue

if xfh.get_parent_map(sdfg.start_state, node) is None:
assert outer_map_entry is None
outer_map_entry = node
else:
assert inner_map_entry is None
inner_map_entry = node
assert not outer_map_entry is None
assert not inner_map_entry is None

return outer_map_entry, inner_map_entry


def test_in_local_storage_explicit():
sdfg = copy_sdfg.to_sdfg()
sdfg.simplify()

sdfg.apply_transformations([MapTiling], options=[{"tile_sizes": [8]}])

outer_map_entry, inner_map_entry = find_map_entries(sdfg)

InLocalStorage.apply_to(sdfg=sdfg,
node_a=outer_map_entry,
node_b=inner_map_entry,
options={
"array": "A",
"create_array": True,
"prefix": "loc_"
},
save=True)

# Finding relevant node
local_storage_node = None
for node in sdfg.start_state.nodes():
if not isinstance(node, dace.nodes.AccessNode):
continue

if node.data == "loc_A":
assert local_storage_node is None
local_storage_node = node
break

assert not local_storage_node is None

# Check transient array created
trans_array = local_storage_node.data
assert trans_array in sdfg.arrays

# Check properties
desc = sdfg.arrays[local_storage_node.data]
assert desc.shape == (8, 8)
assert desc.transient == True

# Check array was set correctly
serialized = sdfg.transformation_hist[0].to_json()
assert serialized["array"] == "A"


def test_in_local_storage_implicit():
sdfg = copy_sdfg.to_sdfg()
sdfg.simplify()

sdfg.apply_transformations([MapTiling], options=[{"tile_sizes": [8]}])

outer_map_entry, inner_map_entry = find_map_entries(sdfg)

InLocalStorage.apply_to(sdfg=sdfg,
node_a=outer_map_entry,
node_b=inner_map_entry,
options={
"create_array": True,
"prefix": "loc_"
},
save=True)

# Finding relevant node
local_storage_node = None
for node in sdfg.start_state.nodes():
if not isinstance(node, dace.nodes.AccessNode):
continue

if node.data == "loc_A":
assert local_storage_node is None
local_storage_node = node
break

assert not local_storage_node is None

# Check transient array created
trans_array = local_storage_node.data
assert trans_array in sdfg.arrays

# Check properties
desc = sdfg.arrays[local_storage_node.data]
assert desc.shape == (8, 8)
assert desc.transient == True

# Check array was set correctly
serialized = sdfg.transformation_hist[0].to_json()
assert serialized["array"] == None


def test_out_local_storage_explicit():
sdfg = copy_sdfg.to_sdfg()
sdfg.simplify()

sdfg.apply_transformations([MapTiling], options=[{"tile_sizes": [8]}])

outer_map_entry, inner_map_entry = find_map_entries(sdfg)
outer_map_exit = sdfg.start_state.exit_node(outer_map_entry)
inner_map_exit = sdfg.start_state.exit_node(inner_map_entry)

OutLocalStorage.apply_to(sdfg=sdfg,
node_a=inner_map_exit,
node_b=outer_map_exit,
options={
"array": "B",
"create_array": True,
"prefix": "loc_"
},
save=True)

# Finding relevant node
local_storage_node = None
for node in sdfg.start_state.nodes():
if not isinstance(node, dace.nodes.AccessNode):
continue

if node.data == "loc_B":
assert local_storage_node is None
local_storage_node = node
break

assert not local_storage_node is None

# Check transient array created
trans_array = local_storage_node.data
assert trans_array in sdfg.arrays

# Check properties
desc = sdfg.arrays[local_storage_node.data]
assert desc.shape == (8, 8)
assert desc.transient == True

# Check array was set correctly
serialized = sdfg.transformation_hist[0].to_json()
assert serialized["array"] == "B"


def test_out_local_storage_implicit():
sdfg = copy_sdfg.to_sdfg()
sdfg.simplify()

sdfg.apply_transformations([MapTiling], options=[{"tile_sizes": [8]}])

outer_map_entry, inner_map_entry = find_map_entries(sdfg)
outer_map_exit = sdfg.start_state.exit_node(outer_map_entry)
inner_map_exit = sdfg.start_state.exit_node(inner_map_entry)

OutLocalStorage.apply_to(sdfg=sdfg,
node_a=inner_map_exit,
node_b=outer_map_exit,
options={
"create_array": True,
"prefix": "loc_"
},
save=True)

# Finding relevant node
local_storage_node = None
for node in sdfg.start_state.nodes():
if not isinstance(node, dace.nodes.AccessNode):
continue

if node.data == "loc_B":
assert local_storage_node is None
local_storage_node = node
break

assert not local_storage_node is None

# Check transient array created
trans_array = local_storage_node.data
assert trans_array in sdfg.arrays

# Check properties
desc = sdfg.arrays[local_storage_node.data]
assert desc.shape == (8, 8)
assert desc.transient == True

# Check array was set correctly
serialized = sdfg.transformation_hist[0].to_json()
assert serialized["array"] == None


@dace.program
def arange():
out = np.ndarray([N], np.int32)
Expand All @@ -18,6 +231,7 @@ def arange():


class LocalStorageTests(unittest.TestCase):

def test_even(self):
sdfg = arange.to_sdfg()
sdfg.apply_transformations([MapTiling, OutLocalStorage], options=[{'tile_sizes': [8]}, {}])
Expand All @@ -37,3 +251,7 @@ def test_uneven(self):

if __name__ == '__main__':
unittest.main()
test_in_local_storage_explicit()
test_in_local_storage_implicit()
test_out_local_storage_explicit()
test_out_local_storage_implicit()

0 comments on commit 5636931

Please sign in to comment.