Skip to content

Commit

Permalink
fix dat version (#709)
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam authored Sep 27, 2023
1 parent 8e4ad93 commit da14715
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
12 changes: 12 additions & 0 deletions pyop2/parloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pyop2.local_kernel import LocalKernel, CStringLocalKernel, LoopyLocalKernel
from pyop2.types import (Access, Global, AbstractDat, Dat, DatView, MixedDat, Mat, Set,
MixedSet, ExtrudedSet, Subset, Map, ComposedMap, MixedMap)
from pyop2.types.data_carrier import DataCarrier
from pyop2.utils import cached_property


Expand Down Expand Up @@ -209,6 +210,7 @@ def compute(self):
@mpi.collective
def __call__(self):
"""Execute the kernel over all members of the iteration space."""
self.increment_dat_version()
self.zero_global_increments()
orig_lgmaps = self.replace_lgmaps()
self.global_to_local_begin()
Expand All @@ -223,6 +225,16 @@ def __call__(self):
self.finalize_global_increments()
self.local_to_global_end()

def increment_dat_version(self):
"""Increment dat versions of :class:`DataCarrier`s in the arguments."""
for lk_arg, gk_arg, pl_arg in self.zipped_arguments:
assert isinstance(pl_arg.data, DataCarrier)
if lk_arg.access is not Access.READ:
if pl_arg.data in self.reduced_globals:
self.reduced_globals[pl_arg.data].data.increment_dat_version()
else:
pl_arg.data.increment_dat_version()

def zero_global_increments(self):
"""Zero any global increments every time the loop is executed."""
for g in self.reduced_globals.keys():
Expand Down
4 changes: 4 additions & 0 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,10 @@ def what(x):
def dat_version(self):
return sum(d.dat_version for d in self._dats)

def increment_dat_version(self):
for d in self:
d.increment_dat_version()

def __call__(self, access, path=None):
from pyop2.parloop import MixedDatLegacyArg
return MixedDatLegacyArg(self, path, access)
Expand Down
30 changes: 30 additions & 0 deletions test/unit/test_dats.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,17 @@ def test_dat_version(self, s, d1):
assert d1.dat_version == 4
assert d2.dat_version == 2

# ParLoop
d3 = op2.Dat(s ** 1, data=None, dtype=np.uint32)
assert d3.dat_version == 0
k = op2.Kernel("""
static void write(unsigned int* v) {
*v = 1;
}
""", "write")
op2.par_loop(k, s, d3(op2.WRITE))
assert d3.dat_version == 1

def test_mixed_dat_version(self, s, d1, mdat):
"""Check object versioning for MixedDat"""
d2 = op2.Dat(s)
Expand Down Expand Up @@ -216,6 +227,25 @@ def test_mixed_dat_version(self, s, d1, mdat):
assert mdat.dat_version == 8
assert mdat2.dat_version == 5

# ParLoop
d3 = op2.Dat(s ** 1, data=None, dtype=np.uint32)
d4 = op2.Dat(s ** 1, data=None, dtype=np.uint32)
d3d4 = op2.MixedDat([d3, d4])
assert d3.dat_version == 0
assert d4.dat_version == 0
assert d3d4.dat_version == 0
k = op2.Kernel("""
static void write(unsigned int* v) {
v[0] = 1;
v[1] = 2;
}
""", "write")
m = op2.Map(s, op2.Set(nelems), 1, values=[0, 1, 2, 3, 4])
op2.par_loop(k, s, d3d4(op2.WRITE, op2.MixedMap([m, m])))
assert d3.dat_version == 1
assert d4.dat_version == 1
assert d3d4.dat_version == 2

def test_accessing_data_with_halos_increments_dat_version(self, d1):
assert d1.dat_version == 0
d1.data_ro_with_halos
Expand Down

0 comments on commit da14715

Please sign in to comment.