-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add XYZ file support #13
base: master
Are you sure you want to change the base?
Changes from all commits
7e6a6d3
61cd085
791475d
7af3bc9
804fc1c
c0ddac9
8e6ed28
af3c465
53e4532
caa4c64
09e6e4b
da1c9f5
8a689c2
0cbfb8f
4db2d68
1c9e102
4cf70cd
c223a14
ee3626c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import pathlib | ||
import typing | ||
|
||
import flour | ||
import numpy as np | ||
import pytest | ||
|
||
|
||
@pytest.fixture
def structures() -> list[flour.XyzData]:
    """Build the structures shared by the XYZ read/write benchmarks.

    Returns 500 structures of 100 carbon atoms each. Coordinates are drawn
    from a generator seeded with 11 and scaled by 100, so every run sees
    the same data.
    """
    atom_count = 100
    frame_count = 500
    generator = np.random.default_rng(11)

    frames = []
    for _ in range(frame_count):
        # One draw per frame, in order, so the random stream is consumed
        # deterministically.
        coordinates = generator.random((atom_count, 3)) * 100
        frames.append(
            flour.XyzData(
                comment="Test comment",
                elements=["C"] * atom_count,
                positions=coordinates,
            )
        )
    return frames
|
||
|
||
@pytest.mark.benchmark(group="write_xyz")
def benchmark_write_xyz(
    benchmark: typing.Any,
    tmp_path: pathlib.Path,
    structures: list[flour.XyzData],
) -> None:
    """Benchmark writing the fixture structures to a single ``.xyz`` file."""
    output_path = tmp_path / "bench.xyz"
    # Arguments are forwarded by keyword, exactly as write_xyz is called
    # elsewhere.
    benchmark(flour.write_xyz, path=output_path, xyz_structures=structures)
|
||
|
||
@pytest.mark.benchmark(group="read_xyz")
def benchmark_read_xyz(
    benchmark: typing.Any,
    tmp_path: pathlib.Path,
    structures: list[flour.XyzData],
) -> None:
    """Benchmark reading back an ``.xyz`` file written from the fixture."""
    input_path = tmp_path / "bench.xyz"
    # Setup (not timed): produce the file that the benchmark will read.
    flour.write_xyz(path=input_path, xyz_structures=structures)
    benchmark(flour.read_xyz, input_path)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ use itertools::izip; | |
use numpy::convert::IntoPyArray; | ||
use numpy::ndarray::Axis; | ||
use numpy::{PyArray1, PyArray2, PyArray3, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3}; | ||
use pyo3::{exceptions::PyRuntimeError, prelude::*}; | ||
use pyo3::{exceptions::PyRuntimeError, prelude::*, Python}; | ||
|
||
#[pyclass] | ||
struct VoxelGrid { | ||
|
@@ -192,10 +192,95 @@ fn write_cube( | |
Ok(()) | ||
} | ||
|
||
/// One structure of an `.xyz` file, exposed to Python as a class.
///
/// `read_xyz` produces one value per structure in the file and `write_xyz`
/// consumes a list of them. Each field is marked `#[pyo3(get)]`.
#[pyclass] | ||
struct XyzData { | ||
/// Element symbol of each atom; `read_xyz` pushes one entry per atom line.
#[pyo3(get)] | ||
elements: Vec<String>, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use #[pyo3(get)]
atoms: Py<PyArray1<u8>>, instead of elements this is more consistent with the cube file interface also using |
||
/// Comment line of the structure (the line after the atom count in `read_xyz`).
#[pyo3(get)] | ||
comment: String, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use here is an example of how to use the type use pyo3::prelude::*;
use pyo3::types::PyString;
#[pyclass]
struct Foo {
#[pyo3(get)]
x: Py<PyString>,
}
#[pymethods]
impl Foo {
#[new]
fn new(py: Python, x: &str) -> PyResult<Self> {
Ok(Foo {
x: PyString::new(py, x).into_py(py),
})
}
} |
||
/// Atom coordinates as a 2-D `f64` NumPy array; `read_xyz` reshapes it to
/// `[num_atoms, 3]`. The constructor does not check the shape.
#[pyo3(get)] | ||
positions: Py<PyArray2<f64>>, | ||
} | ||
|
||
#[pymethods] | ||
impl XyzData { | ||
#[new] | ||
fn new(comment: String, elements: Vec<String>, positions: Py<PyArray2<f64>>) -> Self { | ||
XyzData { | ||
elements, | ||
comment, | ||
positions, | ||
} | ||
} | ||
} | ||
|
||
/// Read a `.xyz` file. | ||
#[pyfunction] | ||
fn read_xyz(py: Python, path: PathBuf) -> PyResult<Vec<XyzData>> { | ||
let mut multi_xyz_data: Vec<XyzData> = vec![]; | ||
let contents = fs::read_to_string(path)?; | ||
let mut lines = contents.lines(); | ||
while let (Some(first_line), Some(second_line)) = (lines.next(), lines.next()) { | ||
let mut first_line_words = first_line.split_ascii_whitespace(); | ||
let Some(num_atoms) = first_line_words.next() else {break}; | ||
let num_atoms = num_atoms.parse::<usize>()?; | ||
|
||
let mut elements = Vec::with_capacity(num_atoms); | ||
let mut positions = Vec::with_capacity(num_atoms * 3); | ||
for _ in 0..num_atoms { | ||
let atom_line = lines.next().ok_or_else(|| { | ||
PyRuntimeError::new_err("xyz file is missing atom definition line") | ||
})?; | ||
let mut words = atom_line.split_ascii_whitespace(); | ||
let element = words | ||
.next() | ||
.ok_or_else(|| PyRuntimeError::new_err("xyz file is missing element symbol"))?; | ||
elements.push(element.to_string()); | ||
positions.extend(words.map(|word: &str| word.parse::<f64>().unwrap())); | ||
} | ||
|
||
multi_xyz_data.push(XyzData { | ||
elements, | ||
comment: second_line.to_string(), | ||
positions: positions | ||
.into_pyarray(py) | ||
.reshape([num_atoms, 3])? | ||
.to_owned(), | ||
}); | ||
} | ||
Ok(multi_xyz_data) | ||
} | ||
|
||
#[pyfunction] | ||
fn write_xyz(path: PathBuf, xyz_structures: Vec<PyRef<XyzData>>) -> PyResult<()> { | ||
let mut content = String::new(); | ||
|
||
for xyz_data in xyz_structures { | ||
let positions = Python::with_gil(|py| xyz_data.positions.as_ref(py).to_owned_array()); | ||
|
||
content.push_str(&xyz_data.elements.len().to_string()); | ||
content.push('\n'); | ||
content.push_str(&xyz_data.comment); | ||
content.push('\n'); | ||
|
||
izip!(&xyz_data.elements, positions.axis_iter(Axis(0))).for_each(|(element, position)| { | ||
content.push_str(&format!( | ||
"{} {: >11.6} {: >11.6} {: >11.6} \n", | ||
element, position[0], position[1], position[2], | ||
)) | ||
}); | ||
} | ||
fs::write(path, content)?; | ||
Ok(()) | ||
} | ||
|
||
/// A Python module implemented in Rust.
///
/// Registers the cube-file functions (`read_cube`, `write_cube`), the
/// xyz-file functions (`write_xyz`, `read_xyz`) and the `XyzData` class on
/// the module `m`.
#[pymodule] | ||
fn flour(_py: Python, m: &PyModule) -> PyResult<()> { | ||
m.add_function(wrap_pyfunction!(read_cube, m)?)?; | ||
m.add_function(wrap_pyfunction!(write_cube, m)?)?; | ||
m.add_function(wrap_pyfunction!(write_xyz, m)?)?; | ||
m.add_function(wrap_pyfunction!(read_xyz, m)?)?; | ||
// The class must be registered too so that Python code can call `flour.XyzData(...)`.
m.add_class::<XyzData>()?; | ||
Ok(()) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import pathlib | ||
|
||
import flour | ||
import numpy as np | ||
|
||
|
||
def test_xyz(
    tmp_path: pathlib.Path,
) -> None:
    """Round-trip two structures through ``write_xyz`` and ``read_xyz``."""
    xyz_path = tmp_path / "molecule.xyz"
    comment = "Test comment!"
    elements = ['Pd', 'C', 'O', 'H']
    positions = np.array(
        [
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ]
    )

    xyz_data_0 = flour.XyzData(
        comment=comment,
        elements=elements,
        positions=positions,
    )
    # The second structure differs in every field so that a mix-up between
    # the two structures in the file would be detected.
    xyz_data_1 = flour.XyzData(
        comment=comment + '!',
        elements=elements[::-1],
        positions=positions * -1,
    )

    flour.write_xyz(
        path=xyz_path,
        xyz_structures=[xyz_data_0, xyz_data_1],
    )
    xyz_structures = flour.read_xyz(xyz_path)

    # Expected (comment, elements, positions) for each structure, in file order.
    expectations = [
        (comment, elements, positions),
        (xyz_data_1.comment, xyz_data_1.elements, xyz_data_1.positions),
    ]
    for index, (want_comment, want_elements, want_positions) in enumerate(expectations):
        loaded = xyz_structures[index]
        assert want_comment == loaded.comment
        assert np.all(np.equal(want_elements, loaded.elements))
        assert np.all(np.isclose(want_positions, loaded.positions))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use a NumPy array of `uint8` called `atoms`;
this will be more consistent with the cube file interface.