Skip to content

Commit

Permalink
Update prithvi warnings
Browse files Browse the repository at this point in the history
Signed-off-by: Benedikt Blumenstiel <[email protected]>
  • Loading branch information
blumenstiel committed Feb 4, 2025
1 parent ff939cf commit 89c233a
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions terratorch/models/backbones/prithvi_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,6 @@ def _create_prithvi(
prithvi_model_class = PrithviMAE
checkpoint_filter_wrapper_fn = checkpoint_filter_fn_mae

if pretrained:
assert variant in pretrained_weights, (f"No pre-trained model found for variant {variant} "
f"(pretrained models: {pretrained_weights.keys()})")

model = prithvi_model_class(**model_args)

if pretrained:
Expand All @@ -224,6 +220,9 @@ def _create_prithvi(
if loaded_keys.unexpected_keys:
logger.warning(f"Missing keys in ckpt_path {ckpt_path}: {loaded_keys.missing_keys}")
else:
assert variant in pretrained_weights, (f"No pre-trained model found for variant {variant} "
f"(pretrained models: {pretrained_weights.keys()})")

try:
# Download config.json to count model downloads
_ = hf_hub_download(repo_id=pretrained_weights[variant]["hf_hub_id"], filename="config.json")
Expand All @@ -236,6 +235,8 @@ def _create_prithvi(
except RuntimeError as e:
logger.error(f"Failed to load the pre-trained weights for {variant}.")
raise e
elif ckpt_path is not None:
logger.warning(f"ckpt_path is provided but pretrained is set to False, ignoring ckpt_path {ckpt_path}.")

model.model_bands = model_bands
model.pretrained_bands = pretrained_bands
Expand Down

0 comments on commit 89c233a

Please sign in to comment.