diff --git a/deepbgc/pipeline/classifier.py b/deepbgc/pipeline/classifier.py index 6a74949..55b4a14 100644 --- a/deepbgc/pipeline/classifier.py +++ b/deepbgc/pipeline/classifier.py @@ -13,7 +13,7 @@ class DeepBGCClassifier(PipelineStep): def __init__(self, classifier, score_threshold=0.5): if classifier is None or not isinstance(classifier, six.string_types): raise ValueError('Expected classifier name or path, got {}'.format(classifier)) - if os.path.exists(classifier) or os.path.sep in classifier: + if (os.path.exists(classifier) or os.path.sep in classifier) and not os.path.isdir(classifier): classifier_path = classifier # Set classifier name to filename without suffix classifier, _ = os.path.splitext(os.path.basename(classifier)) diff --git a/deepbgc/pipeline/detector.py b/deepbgc/pipeline/detector.py index d6b681f..a99d422 100644 --- a/deepbgc/pipeline/detector.py +++ b/deepbgc/pipeline/detector.py @@ -16,7 +16,7 @@ def __init__(self, detector, label=None, score_threshold=0.5, merge_max_protein_ self.score_threshold = score_threshold if detector is None or not isinstance(detector, six.string_types): raise ValueError('Expected detector name or path, got {}'.format(detector)) - if os.path.exists(detector) or os.path.sep in detector: + if (os.path.exists(detector) or os.path.sep in detector) and not os.path.isdir(detector): model_path = detector # Set detector name to filename without suffix detector, _ = os.path.splitext(os.path.basename(detector)) diff --git a/test/integration/pipeline/test_integration_pfam.py b/test/integration/pipeline/test_integration_pfam.py index cd66c6a..54c5b01 100644 --- a/test/integration/pipeline/test_integration_pfam.py +++ b/test/integration/pipeline/test_integration_pfam.py @@ -26,7 +26,7 @@ def test_integration_pfam_annotator(tmpdir): assert pfam.location.start == 249 assert pfam.location.end == 696 assert pfam.location.strand == -1 - assert pfam.qualifiers.get('PFAM_ID') == ['PF00005'] + assert pfam.qualifiers.get('db_xref') == ['PF00005.26'] assert pfam.qualifiers.get('locus_tag') == ['AAK73498.1'] assert pfam.qualifiers.get('description') == ['ABC transporter'] assert pfam.qualifiers.get('database') == ['Pfam-A.31.0.hmm']