diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f829636..c4477d3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ # blacken-docs: Checks docs follow black format standard. # pydocstyle: Checking docstring style. -default_stages: ["commit", "commit-msg", "push"] +default_stages: ["pre-commit", "commit-msg", "pre-push"] default_language_version: python: python3.10 diff --git a/instageo/model/model.py b/instageo/model/model.py index 473dc7e..fe2b5a4 100644 --- a/instageo/model/model.py +++ b/instageo/model/model.py @@ -137,14 +137,14 @@ def __init__( super().__init__() weights_dir = Path.home() / ".instageo" / "prithvi" weights_dir.mkdir(parents=True, exist_ok=True) - weights_path = weights_dir / "Prithvi_100M.pt" - cfg_path = weights_dir / "Prithvi_100M_config.yaml" + weights_path = weights_dir / "Prithvi_EO_V1_100M.pt" + cfg_path = weights_dir / "config.yaml" download_file( - "https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/resolve/main/Prithvi_100M.pt?download=true", # noqa + "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-1.0-100M/resolve/main/Prithvi_EO_V1_100M.pt?download=true", # noqa weights_path, ) download_file( - "https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/raw/main/Prithvi_100M_config.yaml", # noqa + "https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/raw/main/config.yaml", # noqa cfg_path, ) checkpoint = torch.load(weights_path, map_location="cpu") @@ -160,7 +160,6 @@ def __init__( if freeze_backbone: for param in model.parameters(): param.requires_grad = False - del checkpoint["pos_embed"] _ = model.load_state_dict(checkpoint, strict=False) self.prithvi_100M_backbone = model