Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
yuxuanzhuang committed Jun 23, 2022
1 parent cbb8448 commit fe595b2
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 15 deletions.
2 changes: 1 addition & 1 deletion package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Enhancements
* Timestep has been converted to a Cython Extension type
(CZI Performance track, PR #3683)
* Refactor make_downshift_arrays with vector operation in numpy(PR #3724)
* Optimize topology serialization by reconstructing _RA and _SR.
* Optimize topology serialization/construction by lazy-building _RA and _SR.

Changes

Expand Down
3 changes: 3 additions & 0 deletions package/MDAnalysis/core/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ class TransTable(object):
size : tuple
tuple ``(n_atoms, n_residues, n_segments)`` describing the shape of
the TransTable
.. versionchanged:: 2.3.0
Lazy building RA and SR.
"""
def __init__(self,
n_atoms, n_residues, n_segments, # Size of tables
Expand Down
4 changes: 2 additions & 2 deletions testsuite/MDAnalysisTests/core/test_copying.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_size_independent(self, refTT):
refTT.n_atoms = -10
assert new.n_atoms == old

@pytest.mark.parametrize('attr', ['_AR', '_RA', '_RS', '_SR'])
@pytest.mark.parametrize('attr', ['_AR', 'RA', '_RS', 'SR'])
def test_AR(self, refTT, attr):
new = refTT.copy()
ref = getattr(refTT, attr)
Expand All @@ -66,7 +66,7 @@ def test_AR(self, refTT, attr):
for a, b in zip(ref, other):
assert_equal(a, b)

@pytest.mark.parametrize('attr', ['_AR', '_RA', '_RS', '_SR'])
@pytest.mark.parametrize('attr', ['_AR', 'RA', '_RS', 'SR'])
def test_AR_independent(self, refTT, attr):
new = refTT.copy()
ref = getattr(refTT, attr)
Expand Down
58 changes: 46 additions & 12 deletions testsuite/MDAnalysisTests/core/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
import pytest
import numpy as np
import pickle

from MDAnalysisTests import make_Universe

Expand All @@ -23,6 +24,11 @@
import MDAnalysis


def assert_rows_match(a, b):
for row_a, row_b in zip(a, b):
assert_equal(row_a, row_b)


class TestTransTable(object):
@pytest.fixture()
def tt(self):
Expand Down Expand Up @@ -144,6 +150,39 @@ def test_move_residue_simple(self, tt):
assert_equal(len(tt.segments2residues_1d(0)), 3)
assert_equal(len(tt.segments2residues_1d(1)), 1)

def test_lazy_building_RA(self, tt):
assert_equal(tt._RA, None)
RA = tt.RA
assert_rows_match(tt.RA,
np.array([np.array([0, 1]),
np.array([4, 5, 8]),
np.array([2, 3, 9]),
np.array([6, 7]),
None], dtype=object))

tt.move_atom(1, 3)
assert_equal(tt._RA, None)

def test_lazy_building_SR(self, tt):
assert_equal(tt._SR, None)
SR = tt.SR
assert_rows_match(tt.SR,
np.array([np.array([0, 3]),
np.array([1, 2]),
None], dtype=object))

tt.move_residue(1, 0)
assert_equal(tt._SR, None)

def test_serialization(self, tt):
_ = tt.RA
_ = tt.SR
tt_loaded = pickle.loads(pickle.dumps(tt))
assert_equal(tt_loaded._RA, None)
assert_equal(tt_loaded._SR, None)
assert_rows_match(tt_loaded.RA, tt.RA)
assert_rows_match(tt_loaded.SR, tt.SR)


class TestLevelMoves(object):
"""Tests for moving atoms/residues between residues/segments
Expand Down Expand Up @@ -467,11 +506,6 @@ def ragged_result(self):
return np.array([[0, 4, 7], [1, 5, 8], [2, 3, 6, 9]],
dtype=object)

@staticmethod
def assert_rows_match(a, b):
for row_a, row_b in zip(a, b):
assert_equal(row_a, row_b)

# The array as a whole must be dtype object
# While the subarrays must be integers
def test_downshift_dtype_square(self, square, square_size):
Expand All @@ -498,16 +532,16 @@ def test_shape_ragged(self, ragged, ragged_size):

def test_contents_square(self, square, square_size, square_result):
out = make_downshift_arrays(square, square_size)
self.assert_rows_match(out, square_result)
assert_rows_match(out, square_result)

def test_contents_ragged(self, ragged, ragged_size, ragged_result):
out = make_downshift_arrays(ragged, ragged_size)
self.assert_rows_match(out, ragged_result)
assert_rows_match(out, ragged_result)

def test_missing_intra_values(self):
out = make_downshift_arrays(
np.array([0, 0, 2, 2, 3, 3]), 4)
self.assert_rows_match(out,
assert_rows_match(out,
np.array([np.array([0, 1]),
np.array([], dtype=int),
np.array([2, 3]),
Expand All @@ -517,7 +551,7 @@ def test_missing_intra_values(self):
def test_missing_intra_values_2(self):
out = make_downshift_arrays(
np.array([0, 0, 3, 3, 4, 4]), 5)
self.assert_rows_match(out,
assert_rows_match(out,
np.array([np.array([0, 1]),
np.array([], dtype=int),
np.array([], dtype=int),
Expand All @@ -527,7 +561,7 @@ def test_missing_intra_values_2(self):

def test_missing_end_values(self):
out = make_downshift_arrays(np.array([0, 0, 1, 1, 2, 2]), 4)
self.assert_rows_match(out,
assert_rows_match(out,
np.array([np.array([0, 1]),
np.array([2, 3]),
np.array([4, 5]),
Expand All @@ -536,7 +570,7 @@ def test_missing_end_values(self):

def test_missing_end_values_2(self):
out = make_downshift_arrays(np.array([0, 0, 1, 1, 2, 2]), 6)
self.assert_rows_match(out,
assert_rows_match(out,
np.array([np.array([0, 1]),
np.array([2, 3]),
np.array([4, 5]),
Expand All @@ -546,7 +580,7 @@ def test_missing_end_values_2(self):

def test_missing_start_values_2(self):
out = make_downshift_arrays(np.array([1, 1, 2, 2, 3, 3]), 4)
self.assert_rows_match(out,
assert_rows_match(out,
np.array([np.array([], dtype=int),
np.array([0, 1]),
np.array([2, 3]),
Expand Down

0 comments on commit fe595b2

Please sign in to comment.