diff --git a/gmso/formats/gro.py b/gmso/formats/gro.py index bc113a2b8..f9315648f 100644 --- a/gmso/formats/gro.py +++ b/gmso/formats/gro.py @@ -103,7 +103,7 @@ def read_gro(filename): @saves_as(".gro") -def write_gro(top, filename, precision=3): +def write_gro(top, filename, n_decimals=3, shift_coord=False): """Write a topology to a gro file. The Gromos87 (gro) format is a common plain text structure file used @@ -119,8 +119,12 @@ def write_gro(top, filename, precision=3): The `topology` to write out to the gro file. filename : str or file object The location and name of file to save to disk. - precision : int, optional, default=3 + n_decimals : int, optional, default=3 The number of sig fig to write out the position in. + shift_coord : bool, optional, default=False + If True, shift the coordinates of all sites by the minimum position + to ensure all sites have non-negative positions. This is not a requirement + for GRO files, but can be useful for visualizing. Notes ----- @@ -131,7 +135,8 @@ def write_gro(top, filename, precision=3): """ pos_array = np.ndarray.copy(top.positions) - pos_array = _validate_positions(pos_array) + if shift_coord: + pos_array = _validate_positions(pos_array) with open(filename, "w") as out_file: out_file.write( @@ -142,7 +147,7 @@ def write_gro(top, filename, precision=3): ) ) out_file.write("{:d}\n".format(top.n_sites)) - out_file.write(_prepare_atoms(top, pos_array, precision)) + out_file.write(_prepare_atoms(top, pos_array, n_decimals)) out_file.write(_prepare_box(top)) @@ -154,14 +159,14 @@ def _validate_positions(pos_array): "in order to ensure all coordinates are non-negative." ) min_xyz = np.min(pos_array, axis=0) - for i, minimum in enumerate(min_xyz): - if minimum < 0.0: - for loc in pos_array: - loc[i] = loc[i] - minimum + min_xyz0 = np.where(min_xyz < 0, min_xyz, 0) * min_xyz.units + + pos_array -= min_xyz0 + return pos_array -def _prepare_atoms(top, updated_positions, precision): +def _prepare_atoms(top, updated_positions, n_decimals): out_str = str() warnings.warn( "Residue information is parsed from site.molecule," @@ -221,8 +226,8 @@ def _prepare_atoms(top, updated_positions, precision): atom_id = atom_id % max_val res_id = res_id % max_val - varwidth = 5 + precision - crdfmt = f"{{:{varwidth}.{precision}f}}" + varwidth = 5 + n_decimals + crdfmt = f"{{:{varwidth}.{n_decimals}f}}" # preformat pos str crt_x = crdfmt.format(pos[0].in_units(u.nm).value)[:varwidth] diff --git a/gmso/tests/test_gro.py b/gmso/tests/test_gro.py index 75534fc13..946d9f5dc 100644 --- a/gmso/tests/test_gro.py +++ b/gmso/tests/test_gro.py @@ -48,6 +48,13 @@ def test_write_gro(self): top = from_parmed(pmd.load_file(get_fn("ethane.gro"), structure=True)) top.save("out.gro") + def test_write_gro_with_shift_coord(self): + top = from_parmed(pmd.load_file(get_fn("ethane.mol2"), structure=True)) + top.save("out.gro", shift_coord=True) + + read_top = Topology.load("out.gro") + assert np.all(list(map(lambda x: x.position >= 0, read_top.sites))) + def test_write_gro_non_orthogonal(self): top = from_parmed(pmd.load_file(get_fn("ethane.gro"), structure=True)) top.box.angles = u.degree * [90, 90, 120]