diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 9b9b4f96..2561e12f 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -207,10 +207,6 @@ def _create_prithvi( prithvi_model_class = PrithviMAE checkpoint_filter_wrapper_fn = checkpoint_filter_fn_mae - if pretrained: - assert variant in pretrained_weights, (f"No pre-trained model found for variant {variant} " - f"(pretrained models: {pretrained_weights.keys()})") - model = prithvi_model_class(**model_args) if pretrained: @@ -224,6 +220,9 @@ def _create_prithvi( if loaded_keys.unexpected_keys: logger.warning(f"Missing keys in ckpt_path {ckpt_path}: {loaded_keys.missing_keys}") else: + assert variant in pretrained_weights, (f"No pre-trained model found for variant {variant} " + f"(pretrained models: {pretrained_weights.keys()})") + try: # Download config.json to count model downloads _ = hf_hub_download(repo_id=pretrained_weights[variant]["hf_hub_id"], filename="config.json") @@ -236,6 +235,8 @@ def _create_prithvi( except RuntimeError as e: logger.error(f"Failed to load the pre-trained weights for {variant}.") raise e + elif ckpt_path is not None: + logger.warning(f"ckpt_path is provided but pretrained is set to False, ignoring ckpt_path {ckpt_path}.") model.model_bands = model_bands model.pretrained_bands = pretrained_bands