Skip to content

Commit

Permalink
auto convert feature_modules to list in build method
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniBodor committed Oct 31, 2023
1 parent 6b63b9b commit eb2c8fe
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
2 changes: 2 additions & 0 deletions deeprank2/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def build(
:class:`Graph`: The resulting :class:`Graph` object with all the features and targets.
"""

if not isinstance(feature_modules, list):
feature_modules = [feature_modules]
feature_modules = [importlib.import_module('deeprank2.features.' + module)
if isinstance(module, str) else module
for module in feature_modules]
Expand Down
18 changes: 9 additions & 9 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,15 +373,15 @@ def test_incorrect_pssm_order():

# check that error is thrown for incorrect pssm
with pytest.raises(ValueError):
_ = q.build([conservation])
_ = q.build(conservation)

# no error if conservation module is not used
_ = q.build([components])
_ = q.build(components)

# check that error suppression works
with pytest.warns(UserWarning):
q.suppress_pssm_errors = True
_ = q.build([conservation])
_ = q.build(conservation)


def test_incomplete_pssm():
Expand All @@ -396,15 +396,15 @@ def test_incomplete_pssm():
)

with pytest.raises(ValueError):
_ = q.build([conservation])
_ = q.build(conservation)

# no error if conservation module is not used
_ = q.build([components])
_ = q.build(components)

# check that error suppression works
with pytest.warns(UserWarning):
q.suppress_pssm_errors = True
_ = q.build([conservation])
_ = q.build(conservation)


def test_no_pssm_provided():
Expand Down Expand Up @@ -480,13 +480,13 @@ def test_variant_query_multiple_chains():

# at radius 10, chain B is included in graph
# no error without conservation module
graph = q.build([components])
graph = q.build(components)
assert 'B' in graph.get_all_chains()
# if we rebuild the graph with conservation module it should fail
with pytest.raises(FileNotFoundError):
_ = q.build([conservation])
_ = q.build(conservation)

# at radius 7, chain B is not included in graph
q.radius = 7.0
graph = q.build([conservation])
graph = q.build(conservation)
assert 'B' not in graph.get_all_chains()

0 comments on commit eb2c8fe

Please sign in to comment.