Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DatView: Fix zero() #727

Merged
merged 1 commit into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 15 additions & 18 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,7 @@ def __init__(self, dat, index):
if not (0 <= i < d):
raise ex.IndexValueError("Can't create DatView with index %s for Dat with shape %s" % (index, dat.dim))
self.index = index
self._idx = (slice(None), *index)
self._parent = dat
# Point at underlying data
super(DatView, self).__init__(dat.dataset,
Expand Down Expand Up @@ -720,41 +721,37 @@ def halo_valid(self):
def halo_valid(self, value):
self._parent.halo_valid = value

@property
def dat_version(self):
return self._parent.dat_version

@property
def _data(self):
return self._parent._data[self._idx]
dham marked this conversation as resolved.
Show resolved Hide resolved

@property
def data(self):
full = self._parent.data
idx = (slice(None), *self.index)
return full[idx]
return self._parent.data[self._idx]

@property
def data_ro(self):
full = self._parent.data_ro
idx = (slice(None), *self.index)
return full[idx]
return self._parent.data_ro[self._idx]

@property
def data_wo(self):
full = self._parent.data_wo
idx = (slice(None), *self.index)
return full[idx]
return self._parent.data_wo[self._idx]

@property
def data_with_halos(self):
full = self._parent.data_with_halos
idx = (slice(None), *self.index)
return full[idx]
return self._parent.data_with_halos[self._idx]

@property
def data_ro_with_halos(self):
full = self._parent.data_ro_with_halos
idx = (slice(None), *self.index)
return full[idx]
return self._parent.data_ro_with_halos[self._idx]

@property
def data_wo_with_halos(self):
full = self._parent.data_wo_with_halos
idx = (slice(None), *self.index)
return full[idx]
return self._parent.data_wo_with_halos[self._idx]


class Dat(AbstractDat, VecAccessMixin):
Expand Down
64 changes: 64 additions & 0 deletions test/unit/test_dats.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def mdat(d1):
return op2.MixedDat([d1, d1])


@pytest.fixture(scope='module')
def s2(s):
return op2.DataSet(s, 2)


@pytest.fixture
def vdat(s2):
return op2.Dat(s2, np.zeros(2 * nelems), dtype=np.float64)


class TestDat:

"""
Expand Down Expand Up @@ -254,6 +264,60 @@ def test_accessing_data_with_halos_increments_dat_version(self, d1):
assert d1.dat_version == 1


class TestDatView():

def test_dat_view_assign(self, vdat):
vdat.data[:, 0] = 3
vdat.data[:, 1] = 4
comp = op2.DatView(vdat, 1)
comp.data[:] = 7
assert not vdat.halo_valid
assert not comp.halo_valid

expected = np.zeros_like(vdat.data)
expected[:, 0] = 3
expected[:, 1] = 7
assert all(comp.data == expected[:, 1])
assert all(vdat.data[:, 0] == expected[:, 0])
assert all(vdat.data[:, 1] == expected[:, 1])

def test_dat_view_zero(self, vdat):
vdat.data[:, 0] = 3
vdat.data[:, 1] = 4
comp = op2.DatView(vdat, 1)
comp.zero()
assert vdat.halo_valid
assert comp.halo_valid

expected = np.zeros_like(vdat.data)
expected[:, 0] = 3
expected[:, 1] = 0
assert all(comp.data == expected[:, 1])
assert all(vdat.data[:, 0] == expected[:, 0])
assert all(vdat.data[:, 1] == expected[:, 1])

def test_dat_view_halo_valid(self, vdat):
"""Check halo validity for DatView"""
comp = op2.DatView(vdat, 1)
assert vdat.halo_valid
assert comp.halo_valid
assert vdat.dat_version == 0
assert comp.dat_version == 0

comp.data_ro_with_halos
assert vdat.halo_valid
assert comp.halo_valid
assert vdat.dat_version == 0
assert comp.dat_version == 0

# accessing comp.data_with_halos should mark the parent halo as dirty
comp.data_with_halos
assert not vdat.halo_valid
assert not comp.halo_valid
assert vdat.dat_version == 1
assert comp.dat_version == 1


if __name__ == '__main__':
import os
pytest.main(os.path.abspath(__file__))
Loading