Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add XYZ file support #13

Draft
wants to merge 19 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions benchmarks/benchmark_xyz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import pathlib
import typing

import flour
import numpy as np
import pytest


def create_xyz_data(i: int):
num_atoms = 100
return flour.XyzData(
comment="Test comment!",
elements=['C']*num_atoms,
positions=np.random.default_rng(i).random((num_atoms, 3)) * 100
)


def create_xyz_structures():
return [create_xyz_data(i) for i in range(500)]


@pytest.mark.benchmark(group="write_xyz")
def benchmark_write_xyz(
benchmark: typing.Any,
tmp_path: pathlib.Path,
) -> None:
xyz_structures = create_xyz_structures()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i would create a fixture

@fixture
def structures() -> list[flour.XyzData]:
    num_atoms = 100
    rng = np.random.default_rng(11)
    return [
        flour.XyzData(
            comment="Test comment",
            elements=["C"]*num_atoms,
            positions=rng.random((num_atoms, 3)) * 100,
        )
        for _ in range(500)
    ]

which i would use in the benchmarks

def benchmark_write_xyz(
    benchmark: typing.Any,
    tmp_path: pathlib.Path,
    structures: list[flour.XyzData],
) -> None:
    ...

benchmark(
flour.write_xyz,
path=tmp_path / "bench.xyz",
xyz_structures=xyz_structures
)


@pytest.mark.benchmark(group="read_xyz")
def benchmark_read_xyz(
benchmark: typing.Any,
tmp_path: pathlib.Path,
) -> None:
path = tmp_path / "bench.xyz"
xyz_structures = create_xyz_structures()
flour.write_xyz(
path=path,
xyz_structures=xyz_structures
)
benchmark(flour.read_xyz, path)
12 changes: 12 additions & 0 deletions flour.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,15 @@ def write_cube(
voxels: npt.NDArray[np.float64],
) -> None:
pass


class XyzData:
elements: list[str]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use a numpy array of uint8 called atoms

this will be more consistent with cube file interface

comment: str
positions: npt.NDArray[np.float64]

def read_xyz(path: pathlib.Path | str) -> list[XyzData]:
pass

def write_xyz(path: pathlib.Path | str, xyz_structures: list[XyzData]) -> None:
pass
92 changes: 91 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use itertools::izip;
use numpy::convert::IntoPyArray;
use numpy::ndarray::Axis;
use numpy::{PyArray1, PyArray2, PyArray3, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3};
use pyo3::{exceptions::PyRuntimeError, prelude::*};
use pyo3::{exceptions::PyRuntimeError, prelude::*, Python};

#[pyclass]
struct VoxelGrid {
Expand Down Expand Up @@ -192,10 +192,100 @@ fn write_cube(
Ok(())
}

#[pyclass]
struct XyzData {
#[pyo3(get)]
elements: Vec<String>,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use

#[pyo3(get)]
atoms: Py<PyArray1<u8>>,

instead of elements

this is more consistent with the cube file interface

also using Vec<String> would be less performant because of redundant copies and could lead to surprising behvaiour since it creates copies each time it is accessed -- see https://pyo3.rs/v0.20.0/faq.html#pyo3get-clones-my-field

#[pyo3(get)]
comment: String,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use Py<PyString> https://pyo3.rs/main/doc/pyo3/types/struct.pystring

here is an example of how to use the type

use pyo3::prelude::*;
use pyo3::types::PyString;

#[pyclass]
struct Foo {
    #[pyo3(get)]
    x: Py<PyString>,
}

#[pymethods]
impl Foo {
    #[new]
    fn new(py: Python, x: &str) -> PyResult<Self> {
        Ok(Foo {
            x: PyString::new(py, x).into_py(py),
        })
    }
}

#[pyo3(get)]
positions: Py<PyArray2<f64>>,
}

#[pymethods]
impl XyzData {
#[new]
fn new(comment: String, elements: Vec<String>, positions: Py<PyArray2<f64>>) -> Self {
XyzData {
elements,
comment,
positions,
}
}
}

/// Read a `.xyz` file.
#[pyfunction]
fn read_xyz(py: Python, path: PathBuf) -> PyResult<Vec<XyzData>> {
let mut multi_xyz_data: Vec<XyzData> = vec![];
let contents = fs::read_to_string(path)?;
let mut lines = contents.lines();
loop {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i would re-write this using awhile-let loop

while let (Some(first_line), Some(second_line)) = (lines.next(), lines.next()) {
...
}

let Some(first_line) = lines.next() else {break};
let mut first_line_words = first_line.split_ascii_whitespace();
let Some(num_atoms) = first_line_words.next() else {break};
let num_atoms = num_atoms.parse::<usize>()?;

let second_line = lines
.next()
.ok_or_else(|| PyRuntimeError::new_err("xyz file is missing lines"))?;

let mut elements = Vec::with_capacity(num_atoms);
let mut positions = Vec::with_capacity(num_atoms * 3);
for _ in 0..num_atoms {
let atom_line = lines.next().ok_or_else(|| {
PyRuntimeError::new_err("xyz file is missing atom definition line")
})?;
let mut words = atom_line.split_ascii_whitespace();
let element = words
.next()
.ok_or_else(|| PyRuntimeError::new_err("xyz file is missing element symbol"))?;
elements.push(element.to_string());
positions.extend(words.map(|word: &str| word.parse::<f64>().unwrap()));
}

multi_xyz_data.push(XyzData {
elements,
comment: second_line.to_string(),
positions: positions
.into_pyarray(py)
.reshape([num_atoms, 3])?
.to_owned(),
});
}
Ok(multi_xyz_data)
}

#[pyfunction]
fn write_xyz(path: PathBuf, xyz_structures: Vec<PyRef<XyzData>>) -> PyResult<()> {
let mut content = String::new();

for xyz_data in xyz_structures {
let positions = Python::with_gil(|py| xyz_data.positions.as_ref(py).to_owned_array());

content.push_str(&xyz_data.elements.len().to_string());
content.push('\n');
content.push_str(&xyz_data.comment);
content.push('\n');

izip!(&xyz_data.elements, positions.axis_iter(Axis(0))).for_each(|(element, position)| {
content.push_str(&format!(
"{} {: >11.6} {: >11.6} {: >11.6} \n",
element, position[0], position[1], position[2],
))
});
}
fs::write(path, content)?;
Ok(())
}

/// A Python module implemented in Rust.
#[pymodule]
fn flour(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(read_cube, m)?)?;
m.add_function(wrap_pyfunction!(write_cube, m)?)?;
m.add_function(wrap_pyfunction!(write_xyz, m)?)?;
m.add_function(wrap_pyfunction!(read_xyz, m)?)?;
m.add_class::<XyzData>()?;
Ok(())
}
46 changes: 46 additions & 0 deletions tests/test_xyz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import pathlib

import flour
import numpy as np


def test_xyz(
tmp_path: pathlib.Path,
) -> None:

xyz_path = tmp_path / "molecule.xyz"
comment = "Test comment!"
elements = ['Pd', 'C', 'O', 'H']
positions = np.array(
[
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0],
[10.0, 11.0, 12.0],
]
)
xyz_data_0 = flour.XyzData(
comment=comment,
elements=elements,
positions=positions
)
xyz_data_1 = flour.XyzData(
comment=comment+'!',
elements=elements[::-1],
positions=positions*-1
)

flour.write_xyz(
path=xyz_path,
xyz_structures=[xyz_data_0, xyz_data_1],
)

xyz_structures = flour.read_xyz(xyz_path)
assert comment == xyz_structures[0].comment
assert np.all(np.equal(elements, xyz_structures[0].elements))
assert np.all(np.isclose(positions, xyz_structures[0].positions))

assert xyz_data_1.comment == xyz_structures[1].comment
assert np.all(np.equal(xyz_data_1.elements, xyz_structures[1].elements))
assert np.all(np.isclose(xyz_data_1.positions, xyz_structures[1].positions))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needs a newline

Loading