diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 6a97434215..f301c2dd5c 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -822,7 +822,7 @@ def _onnx_trt_compile( output_names = [] if not output_names else output_names # set up the TensorRT builder - torch_tensorrt.set_device(device) + torch.cuda.set_device(device) logger = trt.Logger(trt.Logger.WARNING) builder = trt.Builder(logger) network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) @@ -931,7 +931,7 @@ def convert_to_trt( warnings.warn(f"The dynamic batch range sequence should have 3 elements, but got {dynamic_batchsize} elements.") device = device if device else 0 - target_device = torch.device(f"cuda:{device}") if device else torch.device("cuda:0") + target_device = torch.device(f"cuda:{device}") convert_precision = torch.float32 if precision == "fp32" else torch.half inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)] @@ -986,7 +986,7 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int): ir_model, inputs=input_placeholder, enabled_precisions=convert_precision, - device=target_device, + device=torch_tensorrt.Device(f"cuda:{device}"), ir="torchscript", **kwargs, ) diff --git a/monai/utils/module.py b/monai/utils/module.py index 1ac8140b39..78087aef84 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -214,7 +214,6 @@ def load_submodules( loader = mod_spec.loader loader.exec_module(mod) submodules.append(mod) - except OptionalImportError: pass # could not import the optional deps., they are ignored except ImportError as e: