Skip to content

Commit

Permalink
parse rigid sites from site.molecule, update test
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjonesBSU committed Sep 26, 2024
1 parent 8842973 commit 6a383c6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 19 deletions.
21 changes: 11 additions & 10 deletions gmso/external/convert_hoomd.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,18 +293,16 @@ def _parse_particle_information(
else 1 * base_units["mass"]
)
charges[idx] = site.charge if site.charge else 0 * u.elementary_charge
# Check for rigid IDs
rigid_ids = [site.rigid_id for site in top.sites]
rigid_ids_set = set(rigid_ids)
if None not in rigid_ids:
n_rigid = len(rigid_ids_set)
write_rigid = True
else:
write_rigid = False
n_rigid = 0

unique_types = sorted(list(set(types)))
typeids = np.array([unique_types.index(t) for t in types])
if write_rigid:
# Check for rigid molecules
rigid_mols = any([site.molecule.isrigid for site in top.sites])
if rigid_mols:
rigid_ids = [site.molecule.number for site in top.sites]
rigid_ids_set = set(rigid_ids)
n_rigid = len(rigid_ids_set)
write_rigid = True
rigid_masses = np.zeros(n_rigid)
rigid_xyz = np.zeros((n_rigid, 3))
# Rigid particle type defaults to "R"; add to front of list
Expand All @@ -326,6 +324,9 @@ def _parse_particle_information(
masses = np.concatenate((rigid_masses, masses))
xyz = np.concatenate((rigid_xyz, xyz))
rigid_id_tags = np.concatenate((np.arange(n_rigid), np.array(rigid_ids)))
else:
write_rigid = False
n_rigid = 0

"""
Permittivity of free space = 2.39725e-4 e^2/((kcal/mol)(angstrom)),
Expand Down
16 changes: 7 additions & 9 deletions gmso/tests/test_hoomd.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,19 @@ class TestGsd(BaseTest):
def test_rigid_bodies(self):
ethane = mb.lib.molecules.Ethane()
box = mb.fill_box(ethane, n_compounds=2, box=[2, 2, 2])
box_no_rigid = mb.clone(box)
for i, child in enumerate(box.children):
for p in child.particles():
p.rigid_id = i

top = from_mbuild(box)
top_no_rigid = from_mbuild(box_no_rigid)
for site in top.sites:
site.molecule.isrigid = True

top_no_rigid = from_mbuild(box)

rigid_ids = [site.rigid_id for site in top.sites]
assert len(rigid_ids) == box.n_particles
assert len(set(rigid_ids)) == 2
rigid_ids = [site.molecule.number for site in top.sites]
assert set(rigid_ids) == {0, 1}

snapshot, refs = to_gsd_snapshot(top)
snapshot_no_rigid, refs = to_gsd_snapshot(top_no_rigid)
assert "R" in snapshot.particles.types
assert "R" not in snapshot_no_rigid.particles.types
assert snapshot.particles.N - 2 == snapshot_no_rigid.particles.N

@pytest.mark.skipif(
Expand Down

0 comments on commit 6a383c6

Please sign in to comment.