diff --git a/benchmarks/benchmark_xyz.py b/benchmarks/benchmark_xyz.py new file mode 100644 index 0000000..7e2d1ce --- /dev/null +++ b/benchmarks/benchmark_xyz.py @@ -0,0 +1,47 @@ +import pathlib +import typing + +import flour +import numpy as np +import pytest + + +@pytest.fixture +def structures() -> list[flour.XyzData]: + num_atoms = 100 + rng = np.random.default_rng(11) + return [ + flour.XyzData( + comment="Test comment", + elements=["C"]*num_atoms, + positions=rng.random((num_atoms, 3)) * 100, + ) + for _ in range(500) + ] + + +@pytest.mark.benchmark(group="write_xyz") +def benchmark_write_xyz( + benchmark: typing.Any, + tmp_path: pathlib.Path, + structures: list[flour.XyzData], +) -> None: + benchmark( + flour.write_xyz, + path=tmp_path / "bench.xyz", + xyz_structures=structures + ) + + +@pytest.mark.benchmark(group="read_xyz") +def benchmark_read_xyz( + benchmark: typing.Any, + tmp_path: pathlib.Path, + structures: list[flour.XyzData], +) -> None: + path = tmp_path / "bench.xyz" + flour.write_xyz( + path=path, + xyz_structures=structures + ) + benchmark(flour.read_xyz, path) diff --git a/flour.pyi b/flour.pyi index cef22e4..7d5e2ec 100644 --- a/flour.pyi +++ b/flour.pyi @@ -29,3 +29,15 @@ def write_cube( voxels: npt.NDArray[np.float64], ) -> None: pass + + +class XyzData: + elements: list[str] + comment: str + positions: npt.NDArray[np.float64] + +def read_xyz(path: pathlib.Path | str) -> list[XyzData]: + pass + +def write_xyz(path: pathlib.Path | str, xyz_structures: list[XyzData]) -> None: + pass diff --git a/src/lib.rs b/src/lib.rs index d2f1a26..2d24b24 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,7 @@ use itertools::izip; use numpy::convert::IntoPyArray; use numpy::ndarray::Axis; use numpy::{PyArray1, PyArray2, PyArray3, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3}; -use pyo3::{exceptions::PyRuntimeError, prelude::*}; +use pyo3::{exceptions::PyRuntimeError, prelude::*, Python}; #[pyclass] struct VoxelGrid { @@ -192,10 +192,95 @@ fn write_cube( Ok(()) } +#[pyclass] +struct XyzData { + #[pyo3(get)] + elements: Vec, + #[pyo3(get)] + comment: String, + #[pyo3(get)] + positions: Py>, +} + +#[pymethods] +impl XyzData { + #[new] + fn new(comment: String, elements: Vec, positions: Py>) -> Self { + XyzData { + elements, + comment, + positions, + } + } +} + +/// Read a `.xyz` file. +#[pyfunction] +fn read_xyz(py: Python, path: PathBuf) -> PyResult> { + let mut multi_xyz_data: Vec = vec![]; + let contents = fs::read_to_string(path)?; + let mut lines = contents.lines(); + while let (Some(first_line), Some(second_line)) = (lines.next(), lines.next()) { + let mut first_line_words = first_line.split_ascii_whitespace(); + let Some(num_atoms) = first_line_words.next() else {break}; + let num_atoms = num_atoms.parse::()?; + + let mut elements = Vec::with_capacity(num_atoms); + let mut positions = Vec::with_capacity(num_atoms * 3); + for _ in 0..num_atoms { + let atom_line = lines.next().ok_or_else(|| { + PyRuntimeError::new_err("xyz file is missing atom definition line") + })?; + let mut words = atom_line.split_ascii_whitespace(); + let element = words + .next() + .ok_or_else(|| PyRuntimeError::new_err("xyz file is missing element symbol"))?; + elements.push(element.to_string()); + positions.extend(words.map(|word: &str| word.parse::().unwrap())); + } + + multi_xyz_data.push(XyzData { + elements, + comment: second_line.to_string(), + positions: positions + .into_pyarray(py) + .reshape([num_atoms, 3])? + .to_owned(), + }); + } + Ok(multi_xyz_data) +} + +#[pyfunction] +fn write_xyz(path: PathBuf, xyz_structures: Vec>) -> PyResult<()> { + let mut content = String::new(); + + for xyz_data in xyz_structures { + let positions = Python::with_gil(|py| xyz_data.positions.as_ref(py).to_owned_array()); + + content.push_str(&xyz_data.elements.len().to_string()); + content.push('\n'); + content.push_str(&xyz_data.comment); + content.push('\n'); + + izip!(&xyz_data.elements, positions.axis_iter(Axis(0))).for_each(|(element, position)| { + content.push_str(&format!( + "{} {: >11.6} {: >11.6} {: >11.6} \n", + element, position[0], position[1], position[2], + )) + }); + } + fs::write(path, content)?; + Ok(()) +} + /// A Python module implemented in Rust. #[pymodule] fn flour(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(read_cube, m)?)?; m.add_function(wrap_pyfunction!(write_cube, m)?)?; + m.add_function(wrap_pyfunction!(write_xyz, m)?)?; + m.add_function(wrap_pyfunction!(read_xyz, m)?)?; + m.add_class::()?; Ok(()) } diff --git a/tests/test_xyz.py b/tests/test_xyz.py new file mode 100644 index 0000000..395307d --- /dev/null +++ b/tests/test_xyz.py @@ -0,0 +1,45 @@ +import pathlib + +import flour +import numpy as np + + +def test_xyz( + tmp_path: pathlib.Path, +) -> None: + + xyz_path = tmp_path / "molecule.xyz" + comment = "Test comment!" + elements = ['Pd', 'C', 'O', 'H'] + positions = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + [10.0, 11.0, 12.0], + ] + ) + xyz_data_0 = flour.XyzData( + comment=comment, + elements=elements, + positions=positions + ) + xyz_data_1 = flour.XyzData( + comment=comment+'!', + elements=elements[::-1], + positions=positions*-1 + ) + + flour.write_xyz( + path=xyz_path, + xyz_structures=[xyz_data_0, xyz_data_1], + ) + + xyz_structures = flour.read_xyz(xyz_path) + assert comment == xyz_structures[0].comment + assert np.all(np.equal(elements, xyz_structures[0].elements)) + assert np.all(np.isclose(positions, xyz_structures[0].positions)) + + assert xyz_data_1.comment == xyz_structures[1].comment + assert np.all(np.equal(xyz_data_1.elements, xyz_structures[1].elements)) + assert np.all(np.isclose(xyz_data_1.positions, xyz_structures[1].positions))