diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index f81905a961..b625578049 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -326,6 +326,7 @@ def load( bundle_dir: Optional[PathLike] = None, source: str = download_source, repo: Optional[str] = None, + remove_prefix: Optional[str] = "monai_", progress: bool = True, device: Optional[str] = None, key_in_ckpt: Optional[str] = None, @@ -356,6 +357,10 @@ def load( it should be "ngc" or "github". repo: repo name. This argument is used when `url` is `None` and `source` is "github". If used, it should be in the form of "repo_owner/repo_name/release_tag". + remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles + have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to + maintain the consistency between these two sources, remove prefix is necessary. + Therefore, if specified, downloaded folder name will remove the prefix. progress: whether to display a progress bar when downloading. device: target device of returned weights or module, if `None`, prefer to "cuda" if existing. key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model @@ -379,9 +384,21 @@ def load( if model_file is None: model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt") + if source == "ngc": + name = _add_ngc_prefix(name) + if remove_prefix: + name = _remove_ngc_prefix(name, prefix=remove_prefix) full_path = os.path.join(bundle_dir_, name, model_file) if not os.path.exists(full_path): - download(name=name, version=version, bundle_dir=bundle_dir_, source=source, repo=repo, progress=progress) + download( + name=name, + version=version, + bundle_dir=bundle_dir_, + source=source, + repo=repo, + remove_prefix=remove_prefix, + progress=progress, + ) if device is None: device = "cuda:0" if is_available() else "cpu" diff --git a/tests/ngc_bundle_download.py b/tests/ngc_bundle_download.py index 5e69bf4d63..2b376c3c2d 100644 --- a/tests/ngc_bundle_download.py +++ b/tests/ngc_bundle_download.py @@ -19,10 +19,10 @@ from monai.apps import check_hash from monai.apps.mmars import MODEL_DESC, load_from_mmar -from monai.bundle import download +from monai.bundle import download, load from monai.config import print_debug_info from monai.networks.utils import copy_model_state -from tests.utils import skip_if_downloading_fails, skip_if_quick, skip_if_windows +from tests.utils import assert_allclose, skip_if_downloading_fails, skip_if_quick, skip_if_windows TEST_CASE_NGC_1 = [ "spleen_ct_segmentation", @@ -41,6 +41,30 @@ "b418a2dc8672ce2fd98dc255036e7a3d", ] +TESTCASE_WEIGHTS = { + "key": "model.0.conv.unit0.adn.N.bias", + "value": torch.tensor( + [ + -0.0705, + -0.0937, + -0.0422, + -0.2068, + 0.1023, + -0.2007, + -0.0883, + 0.0018, + -0.1719, + 0.0116, + 0.0285, + -0.0044, + 0.1223, + -0.1287, + -0.1858, + 0.0460, + ] + ), +} + @skip_if_windows class TestNgcBundleDownload(unittest.TestCase): @@ -56,6 +80,13 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download self.assertTrue(os.path.exists(full_file_path)) self.assertTrue(check_hash(filepath=full_file_path, val=hash_val)) + weights = load( + name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix + ) + assert_allclose( + weights[TESTCASE_WEIGHTS["key"]], TESTCASE_WEIGHTS["value"], atol=1e-4, rtol=1e-4, type_test=False + ) + @unittest.skip("deprecating mmar tests") class TestAllDownloadingMMAR(unittest.TestCase):