Skip to content

Commit

Permalink
add ngc download bundle (#5710)
Browse files Browse the repository at this point in the history
Signed-off-by: Yiheng Wang <[email protected]>

Fixes #5679  and #5320

### Description

This PR adds the support of download bundles from ngc:
https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai
In addition, when "version" is not provided, it changes to download the
latest version in default.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Yiheng Wang <[email protected]>
  • Loading branch information
yiheng-wang-nv authored Dec 17, 2022
1 parent 0abd04e commit 67d84d3
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 49 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# daily tests for clara mmar models
name: cron-mmar
# daily tests for ngc bundles
name: cron-ngc-bundle

on:
# schedule:
# - cron: "0 2 * * *" # at 02:00 UTC
schedule:
- cron: "0 2 * * *" # at 02:00 UTC
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:

concurrency:
# automatically cancel the previously triggered workflows when there's a newer version
group: mmar-tests-${{ github.event.pull_request.number || github.ref }}
group: bundle-tests-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true

jobs:
Expand All @@ -33,12 +33,12 @@ jobs:
key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
- name: Install dependencies
run: |
rm -rf /github/home/.cache/torch/hub/mmars/
rm -rf /github/home/.cache/torch/hub/bundle/
python -m pip install --upgrade pip wheel
python -m pip install -r requirements-dev.txt
- name: Loading MMARs
- name: Loading Bundles
run: |
# clean up temporary files
$(pwd)/runtests.sh --build --clean
# run tests
python -m tests.ngc_mmar_loading
python -m tests.ngc_bundle_download
139 changes: 110 additions & 29 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import torch
from torch.cuda import is_available

from monai.apps.mmars.mmars import _get_all_ngc_models
from monai.apps.utils import _basename, download_url, extractall, get_logger
from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
Expand All @@ -42,6 +43,9 @@

logger = get_logger(module_name=__name__)

# set BUNDLE_DOWNLOAD_SRC="ngc" to use NGC source in default for bundle download
download_source = os.environ.get("BUNDLE_DOWNLOAD_SRC", "github")


def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = True, **kwargs) -> Dict:
"""
Expand Down Expand Up @@ -130,9 +134,11 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam
return f"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}"


def _get_ngc_bundle_url(model_name: str, version: str):
return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name}/versions/{version}/zip"


def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True):
if len(repo.split("/")) != 3:
raise ValueError("if source is `github`, repo should be in the form of `repo_owner/repo_name/release_tag`.")
repo_owner, repo_name, tag_name = repo.split("/")
if ".zip" not in filename:
filename += ".zip"
Expand All @@ -142,6 +148,45 @@ def _download_from_github(repo: str, download_path: Path, filename: str, progres
extractall(filepath=filepath, output_dir=download_path, has_base=True)


def _add_ngc_prefix(name: str, prefix: str = "monai_"):
if name.startswith(prefix):
return name
return f"{prefix}{name}"


def _remove_ngc_prefix(name: str, prefix: str = "monai_"):
if name.startswith(prefix):
return name[len(prefix) :]
return name


def _download_from_ngc(download_path: Path, filename: str, version: str, remove_prefix: Optional[str], progress: bool):
# ensure prefix is contained
filename = _add_ngc_prefix(filename)
url = _get_ngc_bundle_url(model_name=filename, version=version)
filepath = download_path / f"{filename}_v{version}.zip"
if remove_prefix:
filename = _remove_ngc_prefix(filename)
extract_path = download_path / f"{filename}"
download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
extractall(filepath=filepath, output_dir=extract_path, has_base=True)


def _get_latest_bundle_version(source: str, name: str, repo: str):
if source == "ngc":
name = _add_ngc_prefix(name)
model_dict = _get_all_ngc_models(name)
for v in model_dict.values():
if v["name"] == name:
return v["latest"]
return None
elif source == "github":
repo_owner, repo_name, tag_name = repo.split("/")
return get_bundle_versions(name, repo=os.path.join(repo_owner, repo_name), tag=tag_name)["latest_version"]
else:
raise ValueError(f"To get the latest bundle version, source should be 'github' or 'ngc', got {source}.")


def _process_bundle_dir(bundle_dir: Optional[PathLike] = None):
if bundle_dir is None:
get_dir, has_home = optional_import("torch.hub", name="get_dir")
Expand All @@ -156,9 +201,10 @@ def download(
name: Optional[str] = None,
version: Optional[str] = None,
bundle_dir: Optional[PathLike] = None,
source: str = "github",
repo: str = "Project-MONAI/model-zoo/hosting_storage_v1",
source: str = download_source,
repo: Optional[str] = None,
url: Optional[str] = None,
remove_prefix: Optional[str] = "monai_",
progress: bool = True,
args_file: Optional[str] = None,
):
Expand All @@ -175,9 +221,12 @@ def download(
# Execute this module as a CLI entry, and download bundle from the model-zoo repo:
python -m monai.bundle download --name <bundle_name> --version "0.1.0" --bundle_dir "./"
# Execute this module as a CLI entry, and download bundle:
# Execute this module as a CLI entry, and download bundle from specified github repo:
python -m monai.bundle download --name <bundle_name> --source "github" --repo "repo_owner/repo_name/release_tag"
# Execute this module as a CLI entry, and download bundle from ngc with latest version:
python -m monai.bundle download --name <bundle_name> --source "ngc" --bundle_dir "./"
# Execute this module as a CLI entry, and download bundle via URL:
python -m monai.bundle download --name <bundle_name> --url <url>
Expand All @@ -190,18 +239,27 @@ def download(
Args:
name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
for example: "spleen_ct_segmentation", "prostate_mri_anatomy" in the model-zoo:
for example:
"spleen_ct_segmentation", "prostate_mri_anatomy" in model-zoo:
https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1.
version: version name of the target bundle to download, like: "0.1.0".
"monai_brats_mri_segmentation" in ngc:
https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai.
version: version name of the target bundle to download, like: "0.1.0". If `None`, will download
the latest version.
bundle_dir: target directory to store the downloaded data.
Default is `bundle` subfolder under `torch.hub.get_dir()`.
source: storage location name. This argument is used when `url` is `None`.
"github" is currently the only supported value.
repo: repo name. This argument is used when `url` is `None`.
If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
it should be "ngc" or "github".
repo: repo name. This argument is used when `url` is `None` and `source` is "github".
If used, it should be in the form of "repo_owner/repo_name/release_tag".
url: url to download the data. If not `None`, data will be downloaded directly
and `source` will not be checked.
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles
have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to
maintain the consistency between these two sources, remove prefix is necessary.
Therefore, if specified, downloaded folder name will remove the prefix.
progress: whether to display a progress bar.
args_file: a JSON or YAML file to provide default values for all the args in this function.
so that the command line inputs can be simplified.
Expand All @@ -215,17 +273,20 @@ def download(
source=source,
repo=repo,
url=url,
remove_prefix=remove_prefix,
progress=progress,
)

_log_input_summary(tag="download", args=_args)
source_, repo_, progress_, name_, version_, bundle_dir_, url_ = _pop_args(
_args, "source", "repo", "progress", name=None, version=None, bundle_dir=None, url=None
source_, progress_, remove_prefix_, repo_, name_, version_, bundle_dir_, url_ = _pop_args(
_args, "source", "progress", remove_prefix=None, repo=None, name=None, version=None, bundle_dir=None, url=None
)

bundle_dir_ = _process_bundle_dir(bundle_dir_)
if name_ is not None and version_ is not None:
name_ = "_v".join([name_, version_])
if repo_ is None:
repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
if len(repo_.split("/")) != 3:
raise ValueError("repo should be in the form of `repo_owner/repo_name/release_tag`.")

if url_ is not None:
if name_ is not None:
Expand All @@ -234,14 +295,27 @@ def download(
filepath = bundle_dir_ / f"{_basename(url_)}"
download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_)
extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True)
elif source_ == "github":
if name_ is None:
raise ValueError(f"To download from source: Github, `name` must be provided, got {name_}.")
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_)
else:
raise NotImplementedError(
f"Currently only download from provided URL in `url` or Github is implemented, got source: {source_}."
)
if name_ is None:
raise ValueError(f"To download from source: {source_}, `name` must be provided.")
if version_ is None:
version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_)
if source_ == "github":
if version_ is not None:
name_ = "_v".join([name_, version_])
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_)
elif source_ == "ngc":
_download_from_ngc(
download_path=bundle_dir_,
filename=name_,
version=version_,
remove_prefix=remove_prefix_,
progress=progress_,
)
else:
raise NotImplementedError(
f"Currently only download from `url`, source 'github' or 'ngc' are implemented, got source: {source_}."
)


def load(
Expand All @@ -250,8 +324,8 @@ def load(
model_file: Optional[str] = None,
load_ts_module: bool = False,
bundle_dir: Optional[PathLike] = None,
source: str = "github",
repo: str = "Project-MONAI/model-zoo/hosting_storage_v1",
source: str = download_source,
repo: Optional[str] = None,
progress: bool = True,
device: Optional[str] = None,
key_in_ckpt: Optional[str] = None,
Expand All @@ -263,18 +337,25 @@ def load(
Load model weights or TorchScript module of a bundle.
Args:
name: bundle name, for example: "spleen_ct_segmentation", "prostate_mri_anatomy" in the model-zoo:
name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
for example:
"spleen_ct_segmentation", "prostate_mri_anatomy" in model-zoo:
https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1.
version: version name of the target bundle to download, like: "0.1.0".
"monai_brats_mri_segmentation" in ngc:
https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai.
version: version name of the target bundle to download, like: "0.1.0". If `None`, will download
the latest version.
model_file: the relative path of the model weights or TorchScript module within bundle.
If `None`, "models/model.pt" or "models/model.ts" will be used.
load_ts_module: a flag to specify if loading the TorchScript module.
bundle_dir: directory the weights/TorchScript module will be loaded from.
Default is `bundle` subfolder under `torch.hub.get_dir()`.
source: storage location name. This argument is used when `model_file` is not existing locally and need to be
downloaded first. "github" is currently the only supported value.
repo: repo name. This argument is used when `model_file` is not existing locally and need to be
downloaded first. If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
downloaded first.
In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
it should be "ngc" or "github".
repo: repo name. This argument is used when `url` is `None` and `source` is "github".
If used, it should be in the form of "repo_owner/repo_name/release_tag".
progress: whether to display a progress bar when downloading.
device: target device of returned weights or module, if `None`, prefer to "cuda" if existing.
key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model
Expand Down Expand Up @@ -421,7 +502,7 @@ def get_bundle_versions(

bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token)
if bundle_name not in bundles_info:
raise ValueError(f"bundle: {bundle_name} is not existing.")
raise ValueError(f"bundle: {bundle_name} is not existing in repo: {repo}.")
bundle_info = bundles_info[bundle_name]
all_versions = sorted(bundle_info.keys())

Expand Down
36 changes: 36 additions & 0 deletions tests/ngc_mmar_loading.py → tests/ngc_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,50 @@

import os
import sys
import tempfile
import unittest

import torch
from parameterized import parameterized

from monai.apps import check_hash
from monai.apps.mmars import MODEL_DESC, load_from_mmar
from monai.bundle import download
from monai.config import print_debug_info
from monai.networks.utils import copy_model_state
from tests.utils import skip_if_downloading_fails, skip_if_quick, skip_if_windows

TEST_CASE_NGC_1 = [
"spleen_ct_segmentation",
"0.3.7",
None,
"monai_spleen_ct_segmentation",
"models/model.pt",
"b418a2dc8672ce2fd98dc255036e7a3d",
]
TEST_CASE_NGC_2 = [
"monai_spleen_ct_segmentation",
"0.3.7",
"monai_",
"spleen_ct_segmentation",
"models/model.pt",
"b418a2dc8672ce2fd98dc255036e7a3d",
]


@skip_if_windows
class TestNgcBundleDownload(unittest.TestCase):
@parameterized.expand([TEST_CASE_NGC_1, TEST_CASE_NGC_2])
@skip_if_quick
def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download_name, file_path, hash_val):
with skip_if_downloading_fails():
with tempfile.TemporaryDirectory() as tempdir:
download(
name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix
)
full_file_path = os.path.join(tempdir, download_name, file_path)
self.assertTrue(os.path.exists(full_file_path))
self.assertTrue(check_hash(filepath=full_file_path, val=hash_val))


@unittest.skip("deprecating mmar tests")
Expand Down
Loading

0 comments on commit 67d84d3

Please sign in to comment.