Skip to content

Commit

Permalink
[TEST] Fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
clementpoiret committed Dec 7, 2021
1 parent cb4b273 commit 9b79015
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
5 changes: 3 additions & 2 deletions hsf/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path, PosixPath
from typing import Generator

import ants
import hydra
Expand Down Expand Up @@ -55,13 +56,13 @@ def get_lr_hippocampi(mri: PosixPath, cfg: DictConfig) -> tuple:
original_mri_path=mri)


def predict(mri: PosixPath, engines: list, cfg: DictConfig) -> tuple:
def predict(mri: PosixPath, engines: Generator, cfg: DictConfig) -> tuple:
"""
Predict the hippocampal segmentation for a given MRI.
Args:
mri (PosixPath): Path to the MRI.
engines (list): List of ONNX Runtime engines.
engines (Generator): Generator of InferenceEngines.
cfg (DictConfig): Configuration.
Returns:
Expand Down
5 changes: 3 additions & 2 deletions hsf/segment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import PosixPath
from typing import Generator

import ants
import numpy as np
Expand Down Expand Up @@ -105,7 +106,7 @@ def segment(subject: tio.Subject,
augmentation_cfg: DictConfig,
segmentation_cfg: DictConfig,
n_engines: int,
engines: list,
engines: Generator,
ca_mode: str = "1/2/3",
batch_size: int = 1) -> tuple:
"""
Expand All @@ -115,7 +116,7 @@ def segment(subject: tio.Subject,
subject (tio.Subject): The subject to segment.
augmentation_cfg (DictConfig): Augmentation configuration.
segmentation_cfg (DictConfig): Segmentation configuration.
engines (List[InferenceEngine]): Inference Engines.
engines (Generator[InferenceEngine]): Inference Engines.
ca_mode (str): The cornu ammoni division mode. Defaults to "1/2/3".
batch_size (int): Batch size. Defaults to 1.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_hsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def deepsparse_inference_engines(models_path):
engine_name="deepsparse",
engine_settings=settings)

return engines
return list(engines)


# TESTS
Expand Down

0 comments on commit 9b79015

Please sign in to comment.