diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index 5ee339bcc..fb877c1a8 100644 --- a/pyop2/types/dat.py +++ b/pyop2/types/dat.py @@ -681,6 +681,7 @@ def __init__(self, dat, index): if not (0 <= i < d): raise ex.IndexValueError("Can't create DatView with index %s for Dat with shape %s" % (index, dat.dim)) self.index = index + self._idx = (slice(None), *index) self._parent = dat # Point at underlying data super(DatView, self).__init__(dat.dataset, @@ -720,41 +721,37 @@ def halo_valid(self): def halo_valid(self, value): self._parent.halo_valid = value + @property + def dat_version(self): + return self._parent.dat_version + + @property + def _data(self): + return self._parent._data[self._idx] + @property def data(self): - full = self._parent.data - idx = (slice(None), *self.index) - return full[idx] + return self._parent.data[self._idx] @property def data_ro(self): - full = self._parent.data_ro - idx = (slice(None), *self.index) - return full[idx] + return self._parent.data_ro[self._idx] @property def data_wo(self): - full = self._parent.data_wo - idx = (slice(None), *self.index) - return full[idx] + return self._parent.data_wo[self._idx] @property def data_with_halos(self): - full = self._parent.data_with_halos - idx = (slice(None), *self.index) - return full[idx] + return self._parent.data_with_halos[self._idx] @property def data_ro_with_halos(self): - full = self._parent.data_ro_with_halos - idx = (slice(None), *self.index) - return full[idx] + return self._parent.data_ro_with_halos[self._idx] @property def data_wo_with_halos(self): - full = self._parent.data_wo_with_halos - idx = (slice(None), *self.index) - return full[idx] + return self._parent.data_wo_with_halos[self._idx] class Dat(AbstractDat, VecAccessMixin): diff --git a/test/unit/test_dats.py b/test/unit/test_dats.py index d43b5a1e4..2b8cf2efb 100644 --- a/test/unit/test_dats.py +++ b/test/unit/test_dats.py @@ -55,6 +55,16 @@ def mdat(d1): return op2.MixedDat([d1, d1]) +@pytest.fixture(scope='module') +def s2(s): + return op2.DataSet(s, 2) + + +@pytest.fixture +def vdat(s2): + return op2.Dat(s2, np.zeros(2 * nelems), dtype=np.float64) + + class TestDat: """ @@ -254,6 +264,60 @@ def test_accessing_data_with_halos_increments_dat_version(self, d1): assert d1.dat_version == 1 +class TestDatView(): + + def test_dat_view_assign(self, vdat): + vdat.data[:, 0] = 3 + vdat.data[:, 1] = 4 + comp = op2.DatView(vdat, 1) + comp.data[:] = 7 + assert not vdat.halo_valid + assert not comp.halo_valid + + expected = np.zeros_like(vdat.data) + expected[:, 0] = 3 + expected[:, 1] = 7 + assert all(comp.data == expected[:, 1]) + assert all(vdat.data[:, 0] == expected[:, 0]) + assert all(vdat.data[:, 1] == expected[:, 1]) + + def test_dat_view_zero(self, vdat): + vdat.data[:, 0] = 3 + vdat.data[:, 1] = 4 + comp = op2.DatView(vdat, 1) + comp.zero() + assert vdat.halo_valid + assert comp.halo_valid + + expected = np.zeros_like(vdat.data) + expected[:, 0] = 3 + expected[:, 1] = 0 + assert all(comp.data == expected[:, 1]) + assert all(vdat.data[:, 0] == expected[:, 0]) + assert all(vdat.data[:, 1] == expected[:, 1]) + + def test_dat_view_halo_valid(self, vdat): + """Check halo validity for DatView""" + comp = op2.DatView(vdat, 1) + assert vdat.halo_valid + assert comp.halo_valid + assert vdat.dat_version == 0 + assert comp.dat_version == 0 + + comp.data_ro_with_halos + assert vdat.halo_valid + assert comp.halo_valid + assert vdat.dat_version == 0 + assert comp.dat_version == 0 + + # accessing comp.data_with_halos should mark the parent halo as dirty + comp.data_with_halos + assert not vdat.halo_valid + assert not comp.halo_valid + assert vdat.dat_version == 1 + assert comp.dat_version == 1 + + if __name__ == '__main__': import os pytest.main(os.path.abspath(__file__))