diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 6b34627a6a..711f585159 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -541,11 +541,16 @@ def load( return model +@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.2", replaced="1.5") def _get_all_bundles_info( repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None ) -> dict[str, dict[str, dict[str, Any]]]: if has_requests: - request_url = f"https://api.github.com/repos/{repo}/releases" + if tag == "hosting_storage_v1": + request_url = f"https://api.github.com/repos/{repo}/releases" + else: + request_url = f"https://raw.githubusercontent.com/{repo}/{tag}/models/model_info.json" + if auth_token is not None: headers = {"Authorization": f"Bearer {auth_token}"} resp = requests_get(request_url, headers=headers) @@ -558,33 +563,39 @@ def _get_all_bundles_info( bundle_name_pattern = re.compile(r"_v\d*.") bundles_info: dict[str, dict[str, dict[str, Any]]] = {} - for release in releases_list: - if release["tag_name"] == tag: - for asset in release["assets"]: - asset_name = bundle_name_pattern.split(asset["name"])[0] - if asset_name not in bundles_info: - bundles_info[asset_name] = {} - asset_version = asset["name"].split(f"{asset_name}_v")[-1].replace(".zip", "") - bundles_info[asset_name][asset_version] = { - "id": asset["id"], - "name": asset["name"], - "size": asset["size"], - "download_count": asset["download_count"], - "browser_download_url": asset["browser_download_url"], - "created_at": asset["created_at"], - "updated_at": asset["updated_at"], - } - return bundles_info + if tag == "hosting_storage_v1": + for release in releases_list: + if release["tag_name"] == tag: + for asset in release["assets"]: + asset_name = bundle_name_pattern.split(asset["name"])[0] + if asset_name not in bundles_info: + bundles_info[asset_name] = {} + asset_version = asset["name"].split(f"{asset_name}_v")[-1].replace(".zip", "") + bundles_info[asset_name][asset_version] = dict(asset) + return bundles_info + else: + for asset in releases_list.keys(): + asset_name = bundle_name_pattern.split(asset)[0] + if asset_name not in bundles_info: + bundles_info[asset_name] = {} + asset_version = asset.split(f"{asset_name}_v")[-1] + bundles_info[asset_name][asset_version] = { + "name": asset, + "browser_download_url": releases_list[asset]["source"], + } return bundles_info +@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.3", replaced="1.5") def get_all_bundles_list( repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None ) -> list[tuple[str, str]]: """ Get all bundles names (and the latest versions) that are stored in the release of specified repository - with the provided tag. The default values of arguments correspond to the release of MONAI model zoo. - In order to increase the rate limits of calling Github APIs, you can input your personal access token. + with the provided tag. If tag is "dev", will get model information from + https://raw.githubusercontent.com/repo_owner/repo_name/dev/models/model_info.json. + The default values of arguments correspond to the release of MONAI model zoo. In order to increase the + rate limits of calling Github APIs, you can input your personal access token. Please check the following link for more details about rate limiting: https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting @@ -610,6 +621,7 @@ def get_all_bundles_list( return bundles_list +@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.3", replaced="1.5") def get_bundle_versions( bundle_name: str, repo: str = "Project-MONAI/model-zoo", @@ -618,7 +630,8 @@ def get_bundle_versions( ) -> dict[str, list[str] | str]: """ Get the latest version, as well as all existing versions of a bundle that is stored in the release of specified - repository with the provided tag. + repository with the provided tag. If tag is "dev", will get model information from + https://raw.githubusercontent.com/repo_owner/repo_name/dev/models/model_info.json. In order to increase the rate limits of calling Github APIs, you can input your personal access token. Please check the following link for more details about rate limiting: https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting @@ -646,6 +659,7 @@ def get_bundle_versions( return {"latest_version": all_versions[-1], "all_versions": all_versions} +@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.3", replaced="1.5") def get_bundle_info( bundle_name: str, version: str | None = None, @@ -656,7 +670,9 @@ def get_bundle_info( """ Get all information (include "id", "name", "size", "download_count", "browser_download_url", "created_at", "updated_at") of a bundle - with the specified bundle name and version. + with the specified bundle name and version which is stored in the release of specified repository with the provided tag. + Since v1.5, "hosting_storage_v1" will be deprecated in favor of 'dev', which contains only "name" and "browser_download_url". + information about a bundle. In order to increase the rate limits of calling Github APIs, you can input your personal access token. Please check the following link for more details about rate limiting: https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting @@ -685,7 +701,7 @@ def get_bundle_info( if version not in bundle_info: raise ValueError(f"version: {version} of bundle: {bundle_name} is not existing.") - return bundle_info[version] + return bundle_info[version] # type: ignore[no-any-return] @deprecated_arg("runner_id", since="1.1", removed="1.3", new_name="run_id", msg_suffix="please use `run_id` instead.") diff --git a/tests/test_bundle_get_data.py b/tests/test_bundle_get_data.py index a560f3945f..a2e6f642e5 100644 --- a/tests/test_bundle_get_data.py +++ b/tests/test_bundle_get_data.py @@ -25,21 +25,34 @@ TEST_CASE_2 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": None}] -TEST_CASE_FAKE_TOKEN = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken"}] +TEST_CASE_3 = [{"tag": "hosting_storage_v1"}] + +TEST_CASE_4 = [{"tag": "dev"}] + +TEST_CASE_5 = [{"bundle_name": "brats_mri_segmentation", "tag": "dev"}] + +TEST_CASE_6 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": None, "tag": "dev"}] + +TEST_CASE_FAKE_TOKEN_1 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken"}] + +TEST_CASE_FAKE_TOKEN_2 = [ + {"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken", "tag": "dev"} +] @skip_if_windows @SkipIfNoModule("requests") class TestGetBundleData(unittest.TestCase): + @parameterized.expand([TEST_CASE_3, TEST_CASE_4]) @skip_if_quick - def test_get_all_bundles_list(self): + def test_get_all_bundles_list(self, params): with skip_if_downloading_fails(): - output = get_all_bundles_list() + output = get_all_bundles_list(**params) self.assertTrue(isinstance(output, list)) self.assertTrue(isinstance(output[0], tuple)) self.assertTrue(len(output[0]) == 2) - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_5]) @skip_if_quick def test_get_bundle_versions(self, params): with skip_if_downloading_fails(): @@ -57,7 +70,16 @@ def test_get_bundle_info(self, params): for key in ["id", "name", "size", "download_count", "browser_download_url"]: self.assertTrue(key in output) - @parameterized.expand([TEST_CASE_FAKE_TOKEN]) + @parameterized.expand([TEST_CASE_5, TEST_CASE_6]) + @skip_if_quick + def test_get_bundle_info_monaihosting(self, params): + with skip_if_downloading_fails(): + output = get_bundle_info(**params) + self.assertTrue(isinstance(output, dict)) + for key in ["name", "browser_download_url"]: + self.assertTrue(key in output) + + @parameterized.expand([TEST_CASE_FAKE_TOKEN_1, TEST_CASE_FAKE_TOKEN_2]) @skip_if_quick def test_fake_token(self, params): with skip_if_downloading_fails():