Skip to content

Commit 6f13b8d

Browse files
KumoLiu and ericspod authored
Avoid breaking change in creating BundleWorkflow (#6950)
Fixes # .

### Description
Avoid breaking changes introduced by #6835:
- when creating `BundleWorkflow`
- when using the `load` API, add `return_state_dict` when `model` and `net_name` are both `None`.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: KumoLiu <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
1 parent 8ccde11 commit 6f13b8d

File tree

4 files changed

+101
-27
lines changed

4 files changed

+101
-27
lines changed

monai/bundle/scripts.py

+55-22
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from monai.apps.mmars.mmars import _get_all_ngc_models
3030
from monai.apps.utils import _basename, download_url, extractall, get_logger
31+
from monai.bundle.config_item import ConfigComponent
3132
from monai.bundle.config_parser import ConfigParser
3233
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
3334
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
@@ -247,7 +248,7 @@ def _process_bundle_dir(bundle_dir: PathLike | None = None) -> Path:
247248
return Path(bundle_dir)
248249

249250

250-
@deprecated_arg_default("source", "github", "monaihosting", since="1.3", replaced="1.4")
251+
@deprecated_arg_default("source", "github", "monaihosting", since="1.3", replaced="1.5")
251252
def download(
252253
name: str | None = None,
253254
version: str | None = None,
@@ -375,8 +376,9 @@ def download(
375376
)
376377

377378

378-
@deprecated_arg("net_name", since="1.3", removed="1.4", msg_suffix="please use ``model`` instead.")
379-
@deprecated_arg("net_kwargs", since="1.3", removed="1.3", msg_suffix="please use ``model`` instead.")
379+
@deprecated_arg("net_name", since="1.3", removed="1.5", msg_suffix="please use ``model`` instead.")
380+
@deprecated_arg("net_kwargs", since="1.3", removed="1.5", msg_suffix="please use ``model`` instead.")
381+
@deprecated_arg("return_state_dict", since="1.3", removed="1.5")
380382
def load(
381383
name: str,
382384
model: torch.nn.Module | None = None,
@@ -395,8 +397,10 @@ def load(
395397
workflow_name: str | BundleWorkflow | None = None,
396398
args_file: str | None = None,
397399
copy_model_args: dict | None = None,
400+
return_state_dict: bool = True,
401+
net_override: dict | None = None,
398402
net_name: str | None = None,
399-
**net_override: Any,
403+
**net_kwargs: Any,
400404
) -> object | tuple[torch.nn.Module, dict, dict] | Any:
401405
"""
402406
Load model weights or TorchScript module of a bundle.
@@ -441,7 +445,12 @@ def load(
441445
workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
442446
args_file: a JSON or YAML file to provide default values for all the args in "download" function.
443447
copy_model_args: other arguments for the `monai.networks.copy_model_state` function.
444-
net_override: id-value pairs to override the parameters in the network of the bundle.
448+
return_state_dict: whether to return state dict, if True, return state_dict, else a corresponding network
449+
from `_workflow.network_def` will be instantiated and load the achieved weights.
450+
net_override: id-value pairs to override the parameters in the network of the bundle, default to `None`.
451+
net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights.
452+
This argument only works when loading weights.
453+
net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`.
445454
446455
Returns:
447456
1. If `load_ts_module` is `False` and `model` is `None`,
@@ -452,9 +461,15 @@ def load(
452461
3. If `load_ts_module` is `True`, return a triple that include a TorchScript module,
453462
the corresponding metadata dict, and extra files dict.
454463
please check `monai.data.load_net_with_metadata` for more details.
464+
4. If `return_state_dict` is True, return model weights, only used for compatibility
465+
when `model` and `net_name` are all `None`.
455466
456467
"""
468+
if return_state_dict and (model is not None or net_name is not None):
469+
warnings.warn("Incompatible values: model and net_name are all specified, return state dict instead.")
470+
457471
bundle_dir_ = _process_bundle_dir(bundle_dir)
472+
net_override = {} if net_override is None else net_override
458473
copy_model_args = {} if copy_model_args is None else copy_model_args
459474

460475
if device is None:
@@ -466,7 +481,7 @@ def load(
466481
if remove_prefix:
467482
name = _remove_ngc_prefix(name, prefix=remove_prefix)
468483
full_path = os.path.join(bundle_dir_, name, model_file)
469-
if not os.path.exists(full_path) or model is None:
484+
if not os.path.exists(full_path):
470485
download(
471486
name=name,
472487
version=version,
@@ -477,34 +492,52 @@ def load(
477492
progress=progress,
478493
args_file=args_file,
479494
)
480-
train_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json"
481-
if train_config_file.is_file():
482-
_net_override = {f"network_def#{key}": value for key, value in net_override.items()}
483-
_workflow = create_workflow(
484-
workflow_name=workflow_name,
485-
args_file=args_file,
486-
config_file=str(train_config_file),
487-
workflow_type=workflow_type,
488-
**_net_override,
489-
)
490-
else:
491-
_workflow = None
492495

493496
# loading with `torch.jit.load`
494497
if load_ts_module is True:
495498
return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
496499
# loading with `torch.load`
497500
model_dict = torch.load(full_path, map_location=torch.device(device))
501+
498502
if not isinstance(model_dict, Mapping):
499503
warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
500504
model_dict = get_state_dict(model_dict)
501505

502-
if model is None and _workflow is None:
506+
if return_state_dict:
503507
return model_dict
504-
model = _workflow.network_def if model is None else model
505-
model.to(device)
506508

507-
copy_model_state(dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args)
509+
_workflow = None
510+
if model is None and net_name is None:
511+
bundle_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json"
512+
if bundle_config_file.is_file():
513+
_net_override = {f"network_def#{key}": value for key, value in net_override.items()}
514+
_workflow = create_workflow(
515+
workflow_name=workflow_name,
516+
args_file=args_file,
517+
config_file=str(bundle_config_file),
518+
workflow_type=workflow_type,
519+
**_net_override,
520+
)
521+
else:
522+
warnings.warn(f"Cannot find the config file: {bundle_config_file}, return state dict instead.")
523+
return model_dict
524+
if _workflow is not None:
525+
if not hasattr(_workflow, "network_def"):
526+
warnings.warn("No available network definition in the bundle, return state dict instead.")
527+
return model_dict
528+
else:
529+
model = _workflow.network_def
530+
elif net_name is not None:
531+
net_kwargs["_target_"] = net_name
532+
configer = ConfigComponent(config=net_kwargs)
533+
model = configer.instantiate() # type: ignore
534+
535+
model.to(device) # type: ignore
536+
537+
copy_model_state(
538+
dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args # type: ignore
539+
)
540+
508541
return model
509542

510543

monai/bundle/workflows.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ class BundleWorkflow(ABC):
4343
or "infer", "inference", "eval", "evaluation" for a inference workflow,
4444
other unsupported string will raise a ValueError.
4545
default to `None` for common workflow.
46+
workflow: specifies the workflow type: "train" or "training" for a training workflow,
47+
or "infer", "inference", "eval", "evaluation" for a inference workflow,
48+
other unsupported string will raise a ValueError.
49+
default to `None` for common workflow.
4650
4751
"""
4852

@@ -56,7 +60,8 @@ class BundleWorkflow(ABC):
5660
new_name="workflow_type",
5761
msg_suffix="please use `workflow_type` instead.",
5862
)
59-
def __init__(self, workflow_type: str | None = None):
63+
def __init__(self, workflow_type: str | None = None, workflow: str | None = None):
64+
workflow_type = workflow if workflow is not None else workflow_type
6065
if workflow_type is None:
6166
self.properties = copy(MetaProperties)
6267
self.workflow_type = None
@@ -198,6 +203,10 @@ class ConfigWorkflow(BundleWorkflow):
198203
or "infer", "inference", "eval", "evaluation" for a inference workflow,
199204
other unsupported string will raise a ValueError.
200205
default to `None` for common workflow.
206+
workflow: specifies the workflow type: "train" or "training" for a training workflow,
207+
or "infer", "inference", "eval", "evaluation" for a inference workflow,
208+
other unsupported string will raise a ValueError.
209+
default to `None` for common workflow.
201210
override: id-value pairs to override or add the corresponding config content.
202211
e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``
203212
@@ -221,8 +230,10 @@ def __init__(
221230
final_id: str = "finalize",
222231
tracking: str | dict | None = None,
223232
workflow_type: str | None = None,
233+
workflow: str | None = None,
224234
**override: Any,
225235
) -> None:
236+
workflow_type = workflow if workflow is not None else workflow_type
226237
super().__init__(workflow_type=workflow_type)
227238
if config_file is not None:
228239
_config_files = ensure_tuple(config_file)

tests/ngc_bundle_download.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,12 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download
8383
self.assertTrue(check_hash(filepath=full_file_path, val=hash_val))
8484

8585
model = load(
86-
name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix
86+
name=bundle_name,
87+
source="ngc",
88+
version=version,
89+
bundle_dir=tempdir,
90+
remove_prefix=remove_prefix,
91+
return_state_dict=False,
8792
)
8893
assert_allclose(
8994
model.state_dict()[TESTCASE_WEIGHTS["key"]],

tests/test_bundle_download.py

+28-3
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
146146
source="github",
147147
progress=False,
148148
device=device,
149+
return_state_dict=True,
149150
)
150151

151152
# prepare network
@@ -174,21 +175,44 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
174175
bundle_dir=tempdir,
175176
progress=False,
176177
device=device,
177-
net_name=model_name,
178178
source="github",
179+
return_state_dict=False,
179180
)
180181
model_2.eval()
181182
output_2 = model_2.forward(input_tensor)
182183
assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False)
183184

185+
# test compatibility with return_state_dict=True.
186+
model_3 = load(
187+
name=bundle_name,
188+
model_file=model_file,
189+
bundle_dir=tempdir,
190+
progress=False,
191+
device=device,
192+
net_name=model_name,
193+
source="github",
194+
return_state_dict=False,
195+
**net_args,
196+
)
197+
model_3.eval()
198+
output_3 = model_3.forward(input_tensor)
199+
assert_allclose(output_3, expected_output, atol=1e-4, rtol=1e-4, type_test=False)
200+
184201
@parameterized.expand([TEST_CASE_7])
185202
@skip_if_quick
186203
def test_load_weights_with_net_override(self, bundle_name, device, net_override):
187204
with skip_if_downloading_fails():
188205
# download bundle, and load weights from the downloaded path
189206
with tempfile.TemporaryDirectory() as tempdir:
190207
# load weights
191-
model = load(name=bundle_name, bundle_dir=tempdir, source="monaihosting", progress=False, device=device)
208+
model = load(
209+
name=bundle_name,
210+
bundle_dir=tempdir,
211+
source="monaihosting",
212+
progress=False,
213+
device=device,
214+
return_state_dict=False,
215+
)
192216

193217
# prepare data and test
194218
input_tensor = torch.rand(1, 1, 96, 96, 96).to(device)
@@ -209,7 +233,8 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override)
209233
source="monaihosting",
210234
progress=False,
211235
device=device,
212-
**net_override,
236+
return_state_dict=False,
237+
net_override=net_override,
213238
)
214239

215240
# prepare data and test

0 commit comments

Comments
 (0)