Skip to content

Commit

Permalink
Merge pull request #4414 from chrishavlin/gadget_header_array_reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros authored Jun 20, 2023
2 parents 4fcbc8c + 10b8a8b commit 84a035d
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 0 deletions.
1 change: 1 addition & 0 deletions nose_ignores.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@
--ignore-file=test_stream_particles\.py
--ignore-file=test_stream_stretched\.py
--ignore-file=test_version\.py
--ignore-file=test_gadget_pytest\.py
--ignore-file=test_vr_orientation\.py
1 change: 1 addition & 0 deletions tests/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ other_tests:
- "--ignore-file=test_stream_particles\\.py"
- "--ignore-file=test_stream_stretched\\.py"
- "--ignore-file=test_version\\.py"
- "--ignore-file=test_gadget_pytest\\.py"
- "--ignore-file=test_vr_orientation\\.py"
- "--exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF"
- "--exclude-test=yt.frontends.adaptahop.tests.test_outputs"
Expand Down
9 changes: 9 additions & 0 deletions yt/frontends/gadget/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,15 @@ def _get_hvals(self):
if "Parameters" in handle:
hvals.update((str(k), v) for k, v in handle["/Parameters"].attrs.items())
handle.close()

# ensure that 1-element arrays are reduced to scalars
updated_hvals = {}
for hvalname, value in hvals.items():
if isinstance(value, np.ndarray) and value.size == 1:
mylog.info("Reducing single-element array %s to scalar.", hvalname)
updated_hvals[hvalname] = value.item()
hvals.update(updated_hvals)

return hvals

def _get_uvals(self):
Expand Down
34 changes: 34 additions & 0 deletions yt/frontends/gadget/tests/test_gadget_pytest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import numpy as np

import yt
from yt.testing import requires_file, requires_module
from yt.utilities.on_demand_imports import _h5py as h5py


@requires_file("snapshot_033/snap_033.0.hdf5")
@requires_module("h5py")
def test_gadget_header_array_reduction(tmp_path):
# first get a real header
ds = yt.load("snapshot_033/snap_033.0.hdf5")
hvals = ds._get_hvals()
hvals_orig = hvals.copy()
# wrap some of the scalar values in nested arrays
hvals["Redshift"] = np.array([hvals["Redshift"]])
hvals["Omega0"] = np.array([[hvals["Omega0"]]])

# drop those header values into a fake header-only file
tmp_snpshot_dir = tmp_path / "snapshot_033"
tmp_snpshot_dir.mkdir()
tmp_header_only_file = str(tmp_snpshot_dir / "fake_gadget_header.hdf5")
with h5py.File(tmp_header_only_file, mode="w") as f:
headergrp = f.create_group("Header")
for field, val in hvals.items():
headergrp.attrs[field] = val

# trick the dataset into using the header file and make sure the
# arrays are reduced
ds._input_filename = tmp_header_only_file
hvals = ds._get_hvals()
for attr in ("Redshift", "Omega0"):
assert hvals[attr] == hvals_orig[attr]
assert isinstance(hvals[attr], np.ndarray) is False

0 comments on commit 84a035d

Please sign in to comment.