Skip to content

Commit

Permalink
Update test_align.py
Browse files Browse the repository at this point in the history
test again
  • Loading branch information
talagayev authored Jan 11, 2025
1 parent 69311d4 commit fdb4f35
Showing 1 changed file with 17 additions and 147 deletions.
164 changes: 17 additions & 147 deletions testsuite/MDAnalysisTests/analysis/test_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,18 +303,11 @@ def test_AlignTraj_outfile_default(self, universe, reference, tmpdir):
x._writer.close()

def test_AlignTraj_outfile_default_exists(
self, universe, reference, tmpdir, client_AlignTraj
):
reference.trajectory[-1]
outfile = str(tmpdir.join('align_test.dcd'))
align.AlignTraj(universe, reference, filename=outfile).run(
**client_AlignTraj
)
self, universe, reference, tmpdir
self, universe, reference, tmpdir, client_AlignTraj
):
reference.trajectory[-1]
outfile = str(tmpdir.join("align_test.dcd"))
align.AlignTraj(universe, reference, filename=outfile).run()
align.AlignTraj(universe, reference, filename=outfile).run(**client_AlignTraj)
fitted = mda.Universe(PSF, outfile)

# ensure default file exists
Expand All @@ -331,39 +324,23 @@ def test_AlignTraj_outfile_default_exists(
with pytest.raises(IOError):
align.AlignTraj(fitted, reference, force=False)

def test_AlignTraj_step_works(
self, universe, reference, tmpdir, client_AlignTraj
):
def test_AlignTraj_step_works(self, universe, reference, tmpdir, client_AlignTraj):
reference.trajectory[-1]
outfile = str(tmpdir.join("align_test.dcd"))
# this shouldn't throw an exception
align.AlignTraj(universe, reference, filename=outfile).run(
step=10, **client_AlignTraj
)
align.AlignTraj(universe, reference, filename=outfile).run(step=10)

def test_AlignTraj_deprecated_attribute(
self, universe, reference, tmpdir, client_AlignTraj
):
def test_AlignTraj_deprecated_attribute(self, universe, reference, tmpdir, client_AlignTraj):
reference.trajectory[-1]
outfile = str(tmpdir.join('align_test.dcd'))
x = align.AlignTraj(universe, reference, filename=outfile).run(
stop=2, **client_AlignTraj
)
outfile = str(tmpdir.join("align_test.dcd"))
x = align.AlignTraj(universe, reference, filename=outfile).run(stop=2)

wmsg = "The `rmsd` attribute was deprecated in MDAnalysis 2.0.0"
with pytest.warns(DeprecationWarning, match=wmsg):
assert_equal(x.rmsd, x.results.rmsd)

def test_AlignTraj(
self, universe, reference, tmpdir, client_AlignTraj
):
def test_AlignTraj(self, universe, reference, tmpdir, client_AlignTraj):
reference.trajectory[-1]
outfile = str(tmpdir.join('align_test.dcd'))
x = align.AlignTraj(universe, reference, filename=outfile).run(
**client_AlignTraj
)
outfile = str(tmpdir.join("align_test.dcd"))
x = align.AlignTraj(universe, reference, filename=outfile).run()
fitted = mda.Universe(PSF, outfile)
Expand All @@ -377,15 +354,7 @@ def test_AlignTraj(
self._assert_rmsd(reference, fitted, 0, 6.929083044751061)
self._assert_rmsd(reference, fitted, -1, 0.0)

def test_AlignTraj_weighted(
self, universe, reference, tmpdir, client_AlignTraj
):
outfile = str(tmpdir.join('align_test.dcd'))
x = align.AlignTraj(universe, reference,
filename=outfile, weights='mass').run(
**client_AlignTraj
)
def test_AlignTraj_weighted(self, universe, reference, tmpdir):
def test_AlignTraj_weighted(self, universe, reference, tmpdir, client_AlignTraj):
outfile = str(tmpdir.join("align_test.dcd"))
x = align.AlignTraj(
universe, reference, filename=outfile, weights="mass"
Expand All @@ -405,23 +374,13 @@ def test_AlignTraj_weighted(self, universe, reference, tmpdir):
weights=universe.atoms.masses,
)

def test_AlignTraj_custom_weights(
self, universe, reference, tmpdir, client_AlignTraj
):
def test_AlignTraj_custom_weights(self, universe, reference, tmpdir, client_AlignTraj):
weights = np.zeros(universe.atoms.n_atoms)
ca = universe.select_atoms("name CA")
weights[ca.indices] = 1

outfile = str(tmpdir.join("align_test.dcd"))

x = align.AlignTraj(universe, reference,
filename=outfile, select='name CA').run(
**client_AlignTraj
)
x_weights = align.AlignTraj(universe, reference,
filename=outfile, weights=weights).run(
**client_AlignTraj
)
x = align.AlignTraj(
universe, reference, filename=outfile, select="name CA"
).run()
Expand All @@ -433,16 +392,6 @@ def test_AlignTraj_custom_weights(
x.results.rmsd, x_weights.results.rmsd, rtol=0, atol=1.5e-7
)

def test_AlignTraj_custom_mass_weights(
self, universe, reference, tmpdir, client_AlignTraj
):
outfile = str(tmpdir.join('align_test.dcd'))
x = align.AlignTraj(universe, reference,
filename=outfile,
weights=reference.atoms.masses).run(
**client_AlignTraj
)

def test_AlignTraj_custom_mass_weights(self, universe, reference, tmpdir):
outfile = str(tmpdir.join("align_test.dcd"))
x = align.AlignTraj(
Expand All @@ -466,25 +415,6 @@ def test_AlignTraj_custom_mass_weights(self, universe, reference, tmpdir):
weights=universe.atoms.masses,
)

def test_AlignTraj_partial_fit(
self, universe, reference, tmpdir, client_AlignTraj
):
outfile = str(tmpdir.join('align_test.dcd'))
# fitting on a partial selection should still write the whole topology
align.AlignTraj(universe, reference, select='resid 1-20',
filename=outfile, weights='mass').run(
**client_AlignTraj
)
mda.Universe(PSF, outfile)

def test_AlignTraj_in_memory(
self, universe, reference, tmpdir, client_AlignTraj
):
outfile = str(tmpdir.join('align_test.dcd'))
reference.trajectory[-1]
x = align.AlignTraj(universe, reference, filename=outfile,
in_memory=True).run(**client_AlignTraj)

def test_AlignTraj_partial_fit(self, universe, reference, tmpdir):
outfile = str(tmpdir.join("align_test.dcd"))
# fitting on a partial selection should still write the whole topology
Expand All @@ -511,17 +441,10 @@ def test_AlignTraj_in_memory(self, universe, reference, tmpdir):
self._assert_rmsd(reference, universe, 0, 6.929083044751061)
self._assert_rmsd(reference, universe, -1, 0.0)

def test_AlignTraj_writer_kwargs(
self, universe, reference, tmpdir, client_AlignTraj
):
def test_AlignTraj_writer_kwargs(self, universe, reference, tmpdir):
# Issue 4564
writer_kwargs = dict(precision=2)
with tmpdir.as_cwd():
aligner = align.AlignTraj(universe, reference,
select='protein and name CA',
filename='aligned_traj.xtc',
writer_kwargs=writer_kwargs,
in_memory=False).run(**client_AlignTraj)
aligner = align.AlignTraj(
universe,
reference,
Expand Down Expand Up @@ -584,7 +507,7 @@ def test_alignto_partial_universe(self, universe, reference):
)


def _get_aligned_average_positions(ref_files, ref, select="all", **kwargs):
def _get_aligned_average_positions(ref_files, ref, select="all", **kwargs, ):
u = mda.Universe(*ref_files, in_memory=True)
prealigner = align.AlignTraj(u, ref, select=select, **kwargs).run()
ag = u.select_atoms(select)
Expand All @@ -605,13 +528,9 @@ def universe(self):
def reference(self):
return mda.Universe(PSF, CRD)

def test_average_structure_deprecated_attrs(
self, universe, reference, client_AverageStructure
):
def test_average_structure_deprecated_attrs(self, universe, reference):
# Issue #3278 - remove in MDAnalysis 3.0.0
avg = align.AverageStructure(universe, reference).run(
stop=2, **client_AverageStructure
)
avg = align.AverageStructure(universe, reference).run(stop=2,)

wmsg = "The `universe` attribute was deprecated in MDAnalysis 2.0.0"
with pytest.warns(DeprecationWarning, match=wmsg):
Expand All @@ -628,38 +547,8 @@ def test_average_structure_deprecated_attrs(
with pytest.warns(DeprecationWarning, match=wmsg):
assert avg.rmsd == avg.results.rmsd

def test_average_structure(
self, universe, reference, client_AverageStructure
):
def test_average_structure(self, universe, reference):
ref, rmsd = _get_aligned_average_positions(self.ref_files, reference)
avg = align.AverageStructure(universe, reference).run(
**client_AverageStructure
)
assert_allclose(avg.results.universe.atoms.positions, ref, rtol=0, atol=1.5e-4)
assert_allclose(avg.results.rmsd, rmsd, rtol=0, atol=1.5e-7)

def test_average_structure_mass_weighted(
self, universe, reference, client_AverageStructure
):
ref, rmsd = _get_aligned_average_positions(self.ref_files, reference, weights='mass')
avg = align.AverageStructure(universe, reference, weights='mass').run(
**client_AverageStructure
)
assert_allclose(avg.results.universe.atoms.positions, ref,
rtol=0, atol=1.5e-4)
assert_allclose(avg.results.rmsd, rmsd, rtol=0, atol=1.5e-7)

def test_average_structure_select(
self, universe, reference, client_AverageStructure
):
select = 'protein and name CA and resid 3-5'
ref, rmsd = _get_aligned_average_positions(self.ref_files, reference, select=select)
avg = align.AverageStructure(universe, reference, select=select).run(
**client_AverageStructure
)
assert_allclose(avg.results.universe.atoms.positions, ref,
rtol=0, atol=1.5e-4)
=======
avg = align.AverageStructure(universe, reference).run()
assert_allclose(
avg.results.universe.atoms.positions, ref, rtol=0, atol=1.5e-4
Expand Down Expand Up @@ -687,16 +576,8 @@ def test_average_structure_select(self, universe, reference):
)
assert_allclose(avg.results.rmsd, rmsd, rtol=0, atol=1.5e-7)

def test_average_structure_no_ref(self, universe, client_AverageStructure):
def test_average_structure_no_ref(self, universe):
ref, rmsd = _get_aligned_average_positions(self.ref_files, universe)
avg = align.AverageStructure(universe).run(**client_AverageStructure)
assert_allclose(avg.results.universe.atoms.positions, ref,
rtol=0, atol=1.5e-4)
assert_allclose(avg.results.rmsd, rmsd, rtol=0, atol=1.5e-7)

def test_average_structure_no_msf(self, universe, client_AverageStructure):
avg = align.AverageStructure(universe).run(**client_AverageStructure)
assert not hasattr(avg, 'msf')
avg = align.AverageStructure(universe).run()
assert_allclose(
avg.results.universe.atoms.positions, ref, rtol=0, atol=1.5e-4
Expand All @@ -712,9 +593,7 @@ def test_mismatch_atoms(self, universe):
with pytest.raises(SelectionError):
align.AverageStructure(universe, u)

def test_average_structure_ref_frame(
self, universe, client_AverageStructure
):
def test_average_structure_ref_frame(self, universe):
ref_frame = 3
u = mda.Merge(universe.atoms)

Expand All @@ -725,23 +604,14 @@ def test_average_structure_ref_frame(
# back to start
universe.trajectory[0]
ref, rmsd = _get_aligned_average_positions(self.ref_files, u)
avg = align.AverageStructure(universe, ref_frame=ref_frame).run(
**client_AverageStructure
)
assert_allclose(avg.results.universe.atoms.positions, ref,
rtol=0, atol=1.5e-4)
avg = align.AverageStructure(universe, ref_frame=ref_frame).run()
assert_allclose(
avg.results.universe.atoms.positions, ref, rtol=0, atol=1.5e-4
)
assert_allclose(avg.results.rmsd, rmsd, rtol=0, atol=1.5e-7)

def test_average_structure_in_memory(
self, universe, client_AverageStructure
):
avg = align.AverageStructure(universe, in_memory=True).run(
**client_AverageStructure
)
def test_average_structure_in_memory(self, universe):
avg = align.AverageStructure(universe, in_memory=True).run()
reference_coordinates = universe.trajectory.timeseries().mean(axis=1)
assert_allclose(
avg.results.universe.atoms.positions,
Expand Down

0 comments on commit fdb4f35

Please sign in to comment.