Skip to content

Commit

Permalink
always cast atoms.arrays to numpy (#162)
Browse files Browse the repository at this point in the history
* always cast `atoms.arrays` to numpy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove comment

* update test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Jan 15, 2025
1 parent b052fd6 commit ed78db8
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "znh5md"
version = "0.4.2"
version = "0.4.3"
description = "ASE Interface for the H5MD format."
authors = ["zincwarecode <[email protected]>"]
license = "Apache-2.0"
Expand Down
2 changes: 1 addition & 1 deletion tests/format/test_info_array_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

# Define assertion functions for different data types
def assert_equal(actual, expected):
assert actual == expected
npt.assert_equal(actual, expected)


def assert_allclose(actual, expected):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_arrays_str.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy as np
import numpy.testing as npt
import rdkit2ase

import znh5md


def test_pdb_repeat(tmp_path):
water = rdkit2ase.smiles2conformers(smiles="O", numConfs=1)[0]
water.arrays["atomtypes"] = np.array(["H", "O", "H"])
assert isinstance(water.arrays["atomtypes"], np.ndarray)

io = znh5md.IO(tmp_path / "test.h5")
io.append(water)

atoms = io[0]
assert isinstance(atoms.arrays["atomtypes"], np.ndarray)
npt.assert_array_equal(atoms.arrays["atomtypes"], water.arrays["atomtypes"])
2 changes: 1 addition & 1 deletion tests/test_znh5md.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def test_version():
assert znh5md.__version__ == "0.4.2"
assert znh5md.__version__ == "0.4.3"


def test_creator(tmp_path):
Expand Down
10 changes: 9 additions & 1 deletion znh5md/interface/read.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import typing as t
import warnings

import ase
import numpy as np
Expand Down Expand Up @@ -114,7 +115,14 @@ def handle_origin_data(self, name: str, data: list, origin: ORIGIN_TYPE) -> None
elif origin == "info":
self.info[name] = data
elif origin == "arrays":
self.arrays[name] = data
try:
self.arrays[name] = np.array(data)
except ValueError:
warnings.warn(
f"Could not convert data to array for '{name}'. "
"Storing as list instead."
)
self.arrays[name] = data
elif origin == "atoms":
raise ValueError(f"Origin 'atoms' is not allowed for {name}")
else:
Expand Down

0 comments on commit ed78db8

Please sign in to comment.