28
28
29
29
from monai .apps .mmars .mmars import _get_all_ngc_models
30
30
from monai .apps .utils import _basename , download_url , extractall , get_logger
31
+ from monai .bundle .config_item import ConfigComponent
31
32
from monai .bundle .config_parser import ConfigParser
32
33
from monai .bundle .utils import DEFAULT_INFERENCE , DEFAULT_METADATA
33
34
from monai .bundle .workflows import BundleWorkflow , ConfigWorkflow
@@ -247,7 +248,7 @@ def _process_bundle_dir(bundle_dir: PathLike | None = None) -> Path:
247
248
return Path (bundle_dir )
248
249
249
250
250
- @deprecated_arg_default ("source" , "github" , "monaihosting" , since = "1.3" , replaced = "1.4 " )
251
+ @deprecated_arg_default ("source" , "github" , "monaihosting" , since = "1.3" , replaced = "1.5 " )
251
252
def download (
252
253
name : str | None = None ,
253
254
version : str | None = None ,
@@ -375,8 +376,9 @@ def download(
375
376
)
376
377
377
378
378
- @deprecated_arg ("net_name" , since = "1.3" , removed = "1.4" , msg_suffix = "please use ``model`` instead." )
379
- @deprecated_arg ("net_kwargs" , since = "1.3" , removed = "1.3" , msg_suffix = "please use ``model`` instead." )
379
+ @deprecated_arg ("net_name" , since = "1.3" , removed = "1.5" , msg_suffix = "please use ``model`` instead." )
380
+ @deprecated_arg ("net_kwargs" , since = "1.3" , removed = "1.5" , msg_suffix = "please use ``model`` instead." )
381
+ @deprecated_arg ("return_state_dict" , since = "1.3" , removed = "1.5" )
380
382
def load (
381
383
name : str ,
382
384
model : torch .nn .Module | None = None ,
@@ -395,8 +397,10 @@ def load(
395
397
workflow_name : str | BundleWorkflow | None = None ,
396
398
args_file : str | None = None ,
397
399
copy_model_args : dict | None = None ,
400
+ return_state_dict : bool = True ,
401
+ net_override : dict | None = None ,
398
402
net_name : str | None = None ,
399
- ** net_override : Any ,
403
+ ** net_kwargs : Any ,
400
404
) -> object | tuple [torch .nn .Module , dict , dict ] | Any :
401
405
"""
402
406
Load model weights or TorchScript module of a bundle.
@@ -441,7 +445,12 @@ def load(
441
445
workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
442
446
args_file: a JSON or YAML file to provide default values for all the args in "download" function.
443
447
copy_model_args: other arguments for the `monai.networks.copy_model_state` function.
444
- net_override: id-value pairs to override the parameters in the network of the bundle.
448
+ return_state_dict: whether to return state dict, if True, return state_dict, else a corresponding network
449
+ from `_workflow.network_def` will be instantiated and load the achieved weights.
450
+ net_override: id-value pairs to override the parameters in the network of the bundle, default to `None`.
451
+ net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights.
452
+ This argument only works when loading weights.
453
+ net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`.
445
454
446
455
Returns:
447
456
1. If `load_ts_module` is `False` and `model` is `None`,
@@ -452,9 +461,15 @@ def load(
452
461
3. If `load_ts_module` is `True`, return a triple that include a TorchScript module,
453
462
the corresponding metadata dict, and extra files dict.
454
463
please check `monai.data.load_net_with_metadata` for more details.
464
+ 4. If `return_state_dict` is True, return model weights, only used for compatibility
465
+ when `model` and `net_name` are all `None`.
455
466
456
467
"""
468
+ if return_state_dict and (model is not None or net_name is not None ):
469
+ warnings .warn ("Incompatible values: model and net_name are all specified, return state dict instead." )
470
+
457
471
bundle_dir_ = _process_bundle_dir (bundle_dir )
472
+ net_override = {} if net_override is None else net_override
458
473
copy_model_args = {} if copy_model_args is None else copy_model_args
459
474
460
475
if device is None :
@@ -466,7 +481,7 @@ def load(
466
481
if remove_prefix :
467
482
name = _remove_ngc_prefix (name , prefix = remove_prefix )
468
483
full_path = os .path .join (bundle_dir_ , name , model_file )
469
- if not os .path .exists (full_path ) or model is None :
484
+ if not os .path .exists (full_path ):
470
485
download (
471
486
name = name ,
472
487
version = version ,
@@ -477,34 +492,52 @@ def load(
477
492
progress = progress ,
478
493
args_file = args_file ,
479
494
)
480
- train_config_file = bundle_dir_ / name / "configs" / f"{ workflow_type } .json"
481
- if train_config_file .is_file ():
482
- _net_override = {f"network_def#{ key } " : value for key , value in net_override .items ()}
483
- _workflow = create_workflow (
484
- workflow_name = workflow_name ,
485
- args_file = args_file ,
486
- config_file = str (train_config_file ),
487
- workflow_type = workflow_type ,
488
- ** _net_override ,
489
- )
490
- else :
491
- _workflow = None
492
495
493
496
# loading with `torch.jit.load`
494
497
if load_ts_module is True :
495
498
return load_net_with_metadata (full_path , map_location = torch .device (device ), more_extra_files = config_files )
496
499
# loading with `torch.load`
497
500
model_dict = torch .load (full_path , map_location = torch .device (device ))
501
+
498
502
if not isinstance (model_dict , Mapping ):
499
503
warnings .warn (f"the state dictionary from { full_path } should be a dictionary but got { type (model_dict )} ." )
500
504
model_dict = get_state_dict (model_dict )
501
505
502
- if model is None and _workflow is None :
506
+ if return_state_dict :
503
507
return model_dict
504
- model = _workflow .network_def if model is None else model
505
- model .to (device )
506
508
507
- copy_model_state (dst = model , src = model_dict if key_in_ckpt is None else model_dict [key_in_ckpt ], ** copy_model_args )
509
+ _workflow = None
510
+ if model is None and net_name is None :
511
+ bundle_config_file = bundle_dir_ / name / "configs" / f"{ workflow_type } .json"
512
+ if bundle_config_file .is_file ():
513
+ _net_override = {f"network_def#{ key } " : value for key , value in net_override .items ()}
514
+ _workflow = create_workflow (
515
+ workflow_name = workflow_name ,
516
+ args_file = args_file ,
517
+ config_file = str (bundle_config_file ),
518
+ workflow_type = workflow_type ,
519
+ ** _net_override ,
520
+ )
521
+ else :
522
+ warnings .warn (f"Cannot find the config file: { bundle_config_file } , return state dict instead." )
523
+ return model_dict
524
+ if _workflow is not None :
525
+ if not hasattr (_workflow , "network_def" ):
526
+ warnings .warn ("No available network definition in the bundle, return state dict instead." )
527
+ return model_dict
528
+ else :
529
+ model = _workflow .network_def
530
+ elif net_name is not None :
531
+ net_kwargs ["_target_" ] = net_name
532
+ configer = ConfigComponent (config = net_kwargs )
533
+ model = configer .instantiate () # type: ignore
534
+
535
+ model .to (device ) # type: ignore
536
+
537
+ copy_model_state (
538
+ dst = model , src = model_dict if key_in_ckpt is None else model_dict [key_in_ckpt ], ** copy_model_args # type: ignore
539
+ )
540
+
508
541
return model
509
542
510
543
0 commit comments