Skip to content

Commit

Permalink
fix: update URLS to download Prithvi Segmentation Model weights and c…
Browse files Browse the repository at this point in the history
…onfig (#11)

* fix: update URLS to download Prithvi Segmentation Model weights and config

* Update model.py

---------

Co-authored-by: Ibrahim Salihu Yusuf <[email protected]>
  • Loading branch information
BioGeek and Alikerin authored Jan 9, 2025
1 parent e125c97 commit c2c8630
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# blacken-docs: Checks docs follow black format standard.
# pydocstyle: Checking docstring style.

default_stages: ["commit", "commit-msg", "push"]
default_stages: ["pre-commit", "commit-msg", "pre-push"]
default_language_version:
python: python3.10

Expand Down
9 changes: 4 additions & 5 deletions instageo/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ def __init__(
super().__init__()
weights_dir = Path.home() / ".instageo" / "prithvi"
weights_dir.mkdir(parents=True, exist_ok=True)
weights_path = weights_dir / "Prithvi_100M.pt"
cfg_path = weights_dir / "Prithvi_100M_config.yaml"
weights_path = weights_dir / "Prithvi_EO_V1_100M.pt"
cfg_path = weights_dir / "config.yaml"
download_file(
"https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/resolve/main/Prithvi_100M.pt?download=true", # noqa
"https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-1.0-100M/resolve/main/Prithvi_EO_V1_100M.pt?download=true", # noqa
weights_path,
)
download_file(
"https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/raw/main/Prithvi_100M_config.yaml", # noqa
"https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/raw/main/config.yaml", # noqa
cfg_path,
)
checkpoint = torch.load(weights_path, map_location="cpu")
Expand All @@ -160,7 +160,6 @@ def __init__(
if freeze_backbone:
for param in model.parameters():
param.requires_grad = False
del checkpoint["pos_embed"]
_ = model.load_state_dict(checkpoint, strict=False)

self.prithvi_100M_backbone = model
Expand Down

0 comments on commit c2c8630

Please sign in to comment.