Skip to content

Commit

Permalink
account for various possible float dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
martinvoegele committed Jul 3, 2021
1 parent d867e66 commit f35a47e
Showing 1 changed file with 24 additions and 18 deletions.
42 changes: 24 additions & 18 deletions tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,21 @@ def test_load_dataset_lmdb():
assert len(dataset) == 4
for df in dataset:
print(df)
assert df['atoms'].x.dtype == 'float'
assert df['atoms'].y.dtype == 'float'
assert df['atoms'].z.dtype == 'float'
print(df['atoms'].x.dtype)
assert df['atoms'].x.dtype in ['float', 'float32', 'float64']
assert df['atoms'].y.dtype in ['float', 'float32', 'float64']
assert df['atoms'].z.dtype in ['float', 'float32', 'float64']


def test_load_dataset_list():
dataset = da.load_dataset('tests/test_data/list/pdbs.txt', 'pdb')
assert len(dataset) == 4
for df in dataset:
print(df)
assert df['atoms'].x.dtype == 'float'
assert df['atoms'].y.dtype == 'float'
assert df['atoms'].z.dtype == 'float'
print(df['atoms'].x.dtype)
assert df['atoms'].x.dtype in ['float', 'float32', 'float64']
assert df['atoms'].y.dtype in ['float', 'float32', 'float64']
assert df['atoms'].z.dtype in ['float', 'float32', 'float64']

def test_load_dataset_list_nonexistent():
dataset = da.load_dataset('tests/test_data/list/nonexistent.txt', 'pdb')
Expand All @@ -43,9 +45,10 @@ def test_load_dataset_pdb():
assert len(dataset) == 4
for df in dataset:
print(df)
assert df['atoms'].x.dtype == 'float'
assert df['atoms'].y.dtype == 'float'
assert df['atoms'].z.dtype == 'float'
print(df['atoms'].x.dtype)
assert df['atoms'].x.dtype in ['float', 'float32', 'float64']
assert df['atoms'].y.dtype in ['float', 'float32', 'float64']
assert df['atoms'].z.dtype in ['float', 'float32', 'float64']


@pytest.mark.skipif(not importlib.util.find_spec("rdkit") is not None,
Expand All @@ -55,9 +58,10 @@ def test_load_dataset_sdf():
assert len(dataset) == 4
for df in dataset:
print(df)
assert df['atoms'].x.dtype == 'float'
assert df['atoms'].y.dtype == 'float'
assert df['atoms'].z.dtype == 'float'
print(df['atoms'].x.dtype)
assert df['atoms'].x.dtype in ['float', 'float32', 'float64']
assert df['atoms'].y.dtype in ['float', 'float32', 'float64']
assert df['atoms'].z.dtype in ['float', 'float32', 'float64']


@pytest.mark.skipif(not importlib.util.find_spec("rosetta") is not None,
Expand All @@ -74,9 +78,10 @@ def test_load_dataset_xyz():
assert len(dataset) == 3
for df in dataset:
print(df)
assert df['atoms'].x.dtype == 'float'
assert df['atoms'].y.dtype == 'float'
assert df['atoms'].z.dtype == 'float'
print(df['atoms'].x.dtype)
assert df['atoms'].x.dtype in ['float', 'float32', 'float64']
assert df['atoms'].y.dtype in ['float', 'float32', 'float64']
assert df['atoms'].z.dtype in ['float', 'float32', 'float64']


def test_load_dataset_xyzgdb():
Expand All @@ -87,9 +92,10 @@ def test_load_dataset_xyzgdb():
assert len(dataset) == 3
for df in dataset:
print(df)
assert df['atoms'].x.dtype == 'float'
assert df['atoms'].y.dtype == 'float'
assert df['atoms'].z.dtype == 'float'
print(df['atoms'].x.dtype)
assert df['atoms'].x.dtype in ['float', 'float32', 'float64']
assert df['atoms'].y.dtype in ['float', 'float32', 'float64']
assert df['atoms'].z.dtype in ['float', 'float32', 'float64']


# -- Creator for LMDB dataset
Expand Down

0 comments on commit f35a47e

Please sign in to comment.