From 9b790159f1a402412da9c302b6a1116622bbaf1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20POIRET?= Date: Tue, 7 Dec 2021 20:28:00 +0100 Subject: [PATCH] [TEST] Fixed tests --- hsf/factory.py | 5 +++-- hsf/segment.py | 5 +++-- tests/test_hsf.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/hsf/factory.py b/hsf/factory.py index 2f24b1a..1b7a2ba 100644 --- a/hsf/factory.py +++ b/hsf/factory.py @@ -1,4 +1,5 @@ from pathlib import Path, PosixPath +from typing import Generator import ants import hydra @@ -55,13 +56,13 @@ def get_lr_hippocampi(mri: PosixPath, cfg: DictConfig) -> tuple: original_mri_path=mri) -def predict(mri: PosixPath, engines: list, cfg: DictConfig) -> tuple: +def predict(mri: PosixPath, engines: Generator, cfg: DictConfig) -> tuple: """ Predict the hippocampal segmentation for a given MRI. Args: mri (PosixPath): Path to the MRI. - engines (list): List of ONNX Runtime engines. + engines (Generator): Generator of InferenceEngines. cfg (DictConfig): Configuration. Returns: diff --git a/hsf/segment.py b/hsf/segment.py index 16cc7e4..7f9d228 100644 --- a/hsf/segment.py +++ b/hsf/segment.py @@ -1,4 +1,5 @@ from pathlib import PosixPath +from typing import Generator import ants import numpy as np @@ -105,7 +106,7 @@ def segment(subject: tio.Subject, augmentation_cfg: DictConfig, segmentation_cfg: DictConfig, n_engines: int, - engines: list, + engines: Generator, ca_mode: str = "1/2/3", batch_size: int = 1) -> tuple: """ @@ -115,7 +116,7 @@ def segment(subject: tio.Subject, subject (tio.Subject): The subject to segment. augmentation_cfg (DictConfig): Augmentation configuration. segmentation_cfg (DictConfig): Segmentation configuration. - engines (List[InferenceEngine]): Inference Engines. + engines (Generator[InferenceEngine]): Inference Engines. ca_mode (str): The cornu ammoni division mode. Defaults to "1/2/3". batch_size (int): Batch size. Defaults to 1. diff --git a/tests/test_hsf.py b/tests/test_hsf.py index 59cd1aa..7431d4e 100644 --- a/tests/test_hsf.py +++ b/tests/test_hsf.py @@ -116,7 +116,7 @@ def deepsparse_inference_engines(models_path): engine_name="deepsparse", engine_settings=settings) - return engines + return list(engines) # TESTS