Skip to content

Commit

Permalink
Add shift_coord for gro writer (#730)
Browse files Browse the repository at this point in the history
* add shift_coord option

* add simple test to improve coverage

* Update gmso/formats/gro.py

Co-authored-by: CalCraven <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* add additional check for test_write_gro_with_shift_coord

* change precision to n_decimals

* vectorize shift_coord

* address remaining comment

---------

Co-authored-by: CalCraven <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 7, 2023
1 parent d584ffa commit 0011e15
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
27 changes: 16 additions & 11 deletions gmso/formats/gro.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def read_gro(filename):


@saves_as(".gro")
def write_gro(top, filename, precision=3):
def write_gro(top, filename, n_decimals=3, shift_coord=False):
"""Write a topology to a gro file.
The Gromos87 (gro) format is a common plain text structure file used
Expand All @@ -119,8 +119,12 @@ def write_gro(top, filename, precision=3):
The `topology` to write out to the gro file.
filename : str or file object
The location and name of file to save to disk.
precision : int, optional, default=3
n_decimals : int, optional, default=3
The number of sig fig to write out the position in.
shift_coord : bool, optional, default=False
If True, shift the coordinates of all sites by the minimum position
to ensure all sites have non-negative positions. This is not a requirement
for GRO files, but can be useful for visualizing.
Notes
-----
Expand All @@ -131,7 +135,8 @@ def write_gro(top, filename, precision=3):
"""
pos_array = np.ndarray.copy(top.positions)
pos_array = _validate_positions(pos_array)
if shift_coord:
pos_array = _validate_positions(pos_array)

with open(filename, "w") as out_file:
out_file.write(
Expand All @@ -142,7 +147,7 @@ def write_gro(top, filename, precision=3):
)
)
out_file.write("{:d}\n".format(top.n_sites))
out_file.write(_prepare_atoms(top, pos_array, precision))
out_file.write(_prepare_atoms(top, pos_array, n_decimals))
out_file.write(_prepare_box(top))


Expand All @@ -154,14 +159,14 @@ def _validate_positions(pos_array):
"in order to ensure all coordinates are non-negative."
)
min_xyz = np.min(pos_array, axis=0)
for i, minimum in enumerate(min_xyz):
if minimum < 0.0:
for loc in pos_array:
loc[i] = loc[i] - minimum
min_xyz0 = np.where(min_xyz < 0, min_xyz, 0) * min_xyz.units

pos_array -= min_xyz0

return pos_array


def _prepare_atoms(top, updated_positions, precision):
def _prepare_atoms(top, updated_positions, n_decimals):
out_str = str()
warnings.warn(
"Residue information is parsed from site.molecule,"
Expand Down Expand Up @@ -221,8 +226,8 @@ def _prepare_atoms(top, updated_positions, precision):
atom_id = atom_id % max_val
res_id = res_id % max_val

varwidth = 5 + precision
crdfmt = f"{{:{varwidth}.{precision}f}}"
varwidth = 5 + n_decimals
crdfmt = f"{{:{varwidth}.{n_decimals}f}}"

# preformat pos str
crt_x = crdfmt.format(pos[0].in_units(u.nm).value)[:varwidth]
Expand Down
7 changes: 7 additions & 0 deletions gmso/tests/test_gro.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ def test_write_gro(self):
top = from_parmed(pmd.load_file(get_fn("ethane.gro"), structure=True))
top.save("out.gro")

def test_write_gro_with_shift_coord(self):
top = from_parmed(pmd.load_file(get_fn("ethane.mol2"), structure=True))
top.save("out.gro", shift_coord=True)

read_top = Topology.load("out.gro")
assert np.all(list(map(lambda x: x.position >= 0, read_top.sites)))

def test_write_gro_non_orthogonal(self):
top = from_parmed(pmd.load_file(get_fn("ethane.gro"), structure=True))
top.box.angles = u.degree * [90, 90, 120]
Expand Down

0 comments on commit 0011e15

Please sign in to comment.