Skip to content

Commit

Permalink
Merge pull request #170 from l-johnston/issue_169
Browse files Browse the repository at this point in the history
Fix unyt_array.__getitem__ to support numpy masked arrays
  • Loading branch information
jzuhone authored Aug 19, 2021
2 parents 9bebe05 + f427623 commit e86c050
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
15 changes: 7 additions & 8 deletions unyt/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1679,14 +1679,13 @@ def ua(self):

def __getitem__(self, item):
ret = super(unyt_array, self).__getitem__(item)
if ret.shape == ():
return unyt_quantity(
ret, self.units, bypass_validation=True, name=self.name
)
else:
if hasattr(self, "units"):
ret.units = self.units
return ret
if getattr(ret, "shape", None) == ():
ret = unyt_quantity(ret, bypass_validation=True, name=self.name)
try:
setattr(ret, "units", self.units)
except AttributeError:
pass
return ret

#
# Start operation methods
Expand Down
24 changes: 24 additions & 0 deletions unyt/tests/test_unyt_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2425,6 +2425,30 @@ def test_ksi():
assert_allclose_units(unyt_quantity(1, "lbf/inch**2"), unyt_quantity(0.001, "ksi"))


def test_masked_array():
data = unyt_array([1, 2, 3], "s")
mask = [False, False, True]
marr = np.ma.MaskedArray(data, mask)
assert_array_equal(marr.data, data)
assert all(marr.mask == mask)
assert marr.sum() == unyt_quantity(3, "s")
assert np.ma.notmasked_contiguous(marr) == [slice(0, 2, None)]
assert marr.argmax() == 1
assert marr.max() == unyt_quantity(2, "s")
data = unyt_array([1, 2, np.inf], "s")
marr = np.ma.MaskedArray(data)
marr_masked = np.ma.masked_invalid(marr)
assert all(marr_masked.mask == [False, False, True])
marr_masked.set_fill_value(unyt_quantity(3, "s"))
assert_array_equal(marr_masked.filled(), unyt_array([1, 2, 3], "s"))
marr_fixed = np.ma.fix_invalid(marr)
assert_array_equal(marr_fixed.data, unyt_array([1, 2, 1e20], "s"))
assert_array_equal(np.ma.filled(marr, unyt_quantity(3, "s")), data)
assert_array_equal(np.ma.compressed(marr_masked), unyt_array([1, 2], "s"))
# executing the repr should not raise an exception
marr.__repr__()


def test_complexvalued():
freq = unyt_array([1j, 1j * 10], "Hz")
arr = 1 / (Unit("F") * Unit("Ω") * freq)
Expand Down

0 comments on commit e86c050

Please sign in to comment.