Change --load_best=true to fail if checkpoint does not exist.
Also adjust it to support loading a checkpoint with an arbitrary name in the
checkpoint directory, even if last.ckpt doesn't exist.
favyen2 committed Feb 5, 2025
1 parent 689b7ca commit b624da2
Showing 4 changed files with 124 additions and 55 deletions.
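As context for the diff below, here is a minimal standalone sketch (not part of the commit; the function and filenames are illustrative) of the selection rule that --load_best=true now implements: among the .ckpt files in the checkpoint directory, the highest parsed epoch=X wins, names without a parseable epoch rank at 0, and last.ckpt is used only as a final fallback.

def pick_best(names: list[str]) -> str:
    """Illustrative re-implementation of the --load_best selection rule."""
    best_name, best_epochs = None, None
    for name in names:
        if not name.endswith(".ckpt"):
            continue
        # last.ckpt ranks below everything; names without epoch=X rank at 0.
        epochs = -100 if name == "last.ckpt" else 0
        for part in name[: -len(".ckpt")].split("-"):
            kv = part.split("=")
            if len(kv) == 2 and kv[0] == "epoch":
                epochs = int(kv[1])
        if best_name is None or epochs > best_epochs:
            best_name, best_epochs = name, epochs
    if best_name is None:
        raise ValueError("load_best enabled but no checkpoint is available")
    return best_name

assert pick_best(["last.ckpt", "best.ckpt"]) == "best.ckpt"
assert pick_best(["last.ckpt", "epoch=4.ckpt", "epoch=9-step=1000.ckpt"]) == "epoch=9-step=1000.ckpt"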
4 changes: 2 additions & 2 deletions docs/landsat_vessels.md
@@ -21,10 +21,10 @@ First, download the detector and classifier checkpoints to the `RSLP_PREFIX` directory.

cd rslearn_projects
mkdir -p project_data/projects/landsat_vessels/data_20240924_model_20240924_imagenet_patch512_flip_03/checkpoints/
-wget https://storage.googleapis.com/ai2-rslearn-projects-data/landsat_vessels/detector/best.ckpt -O project_data/projects/landsat_vessels/data_20240924_model_20240924_imagenet_patch512_flip_03/checkpoints/last.ckpt
+wget https://storage.googleapis.com/ai2-rslearn-projects-data/landsat_vessels/detector/best.ckpt -O project_data/projects/landsat_vessels/data_20240924_model_20240924_imagenet_patch512_flip_03/checkpoints/best.ckpt

mkdir -p project_data/projects/rslearn-landsat-recheck/phase123_20240919_01_copy/checkpoints/
-wget https://storage.googleapis.com/ai2-rslearn-projects-data/landsat_vessels/classifer/best.ckpt -O project_data/projects/rslearn-landsat-recheck/phase123_20240919_01_copy/checkpoints/last.ckpt
+wget https://storage.googleapis.com/ai2-rslearn-projects-data/landsat_vessels/classifer/best.ckpt -O project_data/projects/rslearn-landsat-recheck/phase123_20240919_01_copy/checkpoints/best.ckpt

The easiest way to apply the model is to use the prediction pipeline in `rslp/landsat_vessels/predict_pipeline.py`. You can download the Landsat scene files, e.g. from USGS EarthExplorer or AWS, and then create a configuration file for the prediction pipeline; here is an example:

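The example configuration itself is collapsed in this diff view. As a rough sketch only (the keys below are illustrative guesses modeled on the Sentinel-2 example later in this commit, not the actual schema), such a file can be generated like this:

import json

# Hypothetical keys: scene_path is invented; json_path and crop_path mirror
# the Sentinel-2 prediction example in docs/sentinel2_vessels.md below.
config = {
    "scene_path": "scenes/LC08_L1TP_123456_20240101/",  # downloaded Landsat scene
    "json_path": "out.json",        # where detections are written
    "crop_path": "output_crops/",   # where crops around detections are written
}
with open("landsat_predict_config.json", "w") as f:
    json.dump(config, f, indent=2)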
4 changes: 2 additions & 2 deletions docs/sentinel2_vessels.md
@@ -20,7 +20,7 @@ First, download the model checkpoint to the `RSLP_PREFIX` directory.

cd rslearn_projects
mkdir -p project_data/projects/sentinel2_vessels/data_20240927_satlaspretrain_patch512_00/checkpoints/
-wget https://storage.googleapis.com/ai2-rslearn-projects-data/sentinel2_vessels/best.ckpt -O project_data/projects/sentinel2_vessels/data_20240927_satlaspretrain_patch512_00/checkpoints/last.ckpt
+wget https://storage.googleapis.com/ai2-rslearn-projects-data/sentinel2_vessels/best.ckpt -O project_data/projects/sentinel2_vessels/data_20240927_satlaspretrain_patch512_00/checkpoints/best.ckpt

The easiest way to apply the model is to use the prediction pipeline in
`rslp/sentinel2_vessels/predict_pipeline.py`. It accepts a Sentinel-2 scene ID and
@@ -29,7 +29,7 @@ automatically downloads the scene images from a

mkdir output_crops
mkdir scratch_dir
-python -m rslp.main sentinel2_vessels predict '["scene_id": "S2A_MSIL1C_20180904T110621_N0206_R137_T30UYD_20180904T133425", "json_path": "out.json", "crop_path": "output_crops/"]' scratch_dir/
+python -m rslp.main sentinel2_vessels predict '[{"scene_id": "S2A_MSIL1C_20180904T110621_N0206_R137_T30UYD_20180904T133425", "json_path": "out.json", "crop_path": "output_crops/"}]' scratch_dir/

Then, `out.json` will contain a JSON list of detected ships while `output_crops` will
contain corresponding crops centered around those ships (showing the RGB B4/B3/B2
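As a small follow-on sketch (assuming only what the text above states, namely that out.json holds a JSON list of detected ships), the results can be inspected with a few lines of Python:

import json

# out.json is written by the predict invocation shown above.
with open("out.json") as f:
    detections = json.load(f)  # a JSON list, one entry per detected ship
print(f"{len(detections)} vessels detected; crops are under output_crops/")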
146 changes: 95 additions & 51 deletions rslp/lightning_cli.py
Expand Up @@ -168,6 +168,95 @@ def add_arguments_to_parser(self, parser: jsonargparse.ArgumentParser) -> None:
            default=False,
        )

+    def _get_checkpoint_path(
+        self, checkpoint_dir: UPath, load_best: bool = False, autoresume: bool = False
+    ) -> str | None:
+        """Get path to checkpoint to load from, or None to not restore checkpoint.
+
+        With --load_best=true, we load the best-performing checkpoint. An error is
+        thrown if the checkpoint doesn't exist.
+
+        With --autoresume=true, we load last.ckpt if it exists, but proceed with
+        default initialization otherwise.
+
+        Otherwise, we do not restore any existing checkpoint (i.e., we use default
+        initialization), and throw an error if there is an existing checkpoint.
+
+        When training, it is suggested to use no option (don't expect to restart
+        training) or --autoresume=true (if restart is expected, e.g. due to
+        preemption). For inference, it is suggested to use --load_best=true.
+
+        Args:
+            checkpoint_dir: the directory where checkpoints are stored.
+            load_best: whether to load the best performing checkpoint and require a
+                checkpoint to exist.
+            autoresume: whether to load the checkpoint if it exists but proceed even
+                if it does not.
+
+        Returns:
+            the path to the checkpoint for setting c.ckpt_path, or None if no
+            checkpoint should be restored.
+        """
+        if load_best:
+            # Checkpoints should be either:
+            # - last.ckpt
+            # - of the form "A=B-C=D-....ckpt" with one key being epoch=X
+            # So we want the one with the highest epoch, and only use last.ckpt if
+            # it's the only option.
+            # User should set save_top_k=1 so there's just one, otherwise we won't
+            # actually know which one is the best.
+            best_checkpoint = None
+            best_epochs = None
+            for option in checkpoint_dir.iterdir():
+                if not option.name.endswith(".ckpt"):
+                    continue
+
+                # Try to see what epoch this checkpoint is at.
+                # If it is some other format, then set it 0 so we only use it if
+                # it's the only option.
+                # If it is last.ckpt then we set it -100 to only use it if there is
+                # not even another format like "best.ckpt".
+                extracted_epochs = 0
+                if option.name == "last.ckpt":
+                    extracted_epochs = -100
+
+                parts = option.name.split(".ckpt")[0].split("-")
+                for part in parts:
+                    kv_parts = part.split("=")
+                    if len(kv_parts) != 2:
+                        continue
+                    if kv_parts[0] != "epoch":
+                        continue
+                    extracted_epochs = int(kv_parts[1])
+
+                if best_checkpoint is None or extracted_epochs > best_epochs:
+                    best_checkpoint = option
+                    best_epochs = extracted_epochs
+
+            if best_checkpoint is None:
+                raise ValueError(
+                    f"load_best enabled but no checkpoint is available in {checkpoint_dir}"
+                )
+
+            # Cache the checkpoint so we only need to download once in case we
+            # reuse it later. We only cache with --load_best since this is the
+            # only scenario where the chosen checkpoint won't change.
+            return get_cached_checkpoint(best_checkpoint)
+
+        elif autoresume:
+            # Don't cache the checkpoint here since last.ckpt could change if
+            # the model is trained further.
+            last_checkpoint_path = checkpoint_dir / "last.ckpt"
+            if last_checkpoint_path.exists():
+                return str(last_checkpoint_path)
+            else:
+                return None
+
+        else:
+            last_checkpoint_path = checkpoint_dir / "last.ckpt"
+            if last_checkpoint_path.exists():
+                raise ValueError("autoresume is off but checkpoint already exists")
+            else:
+                return None
    def before_instantiate_classes(self) -> None:
        """Called before Lightning class initialization."""
        super().before_instantiate_classes()
@@ -260,57 +349,12 @@ def before_instantiate_classes(self) -> None:
            )
            c.trainer.callbacks.append(upload_wandb_callback)

-        # Check if there is an existing checkpoint.
-        # If so, and autoresume/load_best are disabled, we should throw error.
-        # If autoresume is enabled, then we should resume from last.ckpt.
-        # If load_best is enabled, then we should try to identify the best checkpoint.
-        # We still use last.ckpt to see if checkpoint exists since last.ckpt should
-        # always be written.
-        if (checkpoint_dir / "last.ckpt").exists():
-            if c.load_best:
-                # Checkpoints should be either:
-                # - last.ckpt
-                # - of the form "A=B-C=D-....ckpt" with one key being epoch=X
-                # So we want the one with the highest epoch, and only use last.ckpt
-                # if it's the only option.
-                # User should set save_top_k=1 so there's just one, otherwise we
-                # won't actually know which one is the best.
-                best_checkpoint = None
-                best_epochs = None
-                for option in checkpoint_dir.iterdir():
-                    if not option.name.endswith(".ckpt"):
-                        continue
-
-                    # Try to see what epochs this checkpoint is at.
-                    # If it is last.ckpt or some other format, then set it 0 so we
-                    # only use it if it's the only option.
-                    extracted_epochs = 0
-                    parts = option.name.split(".ckpt")[0].split("-")
-                    for part in parts:
-                        kv_parts = part.split("=")
-                        if len(kv_parts) != 2:
-                            continue
-                        if kv_parts[0] != "epoch":
-                            continue
-                        extracted_epochs = int(kv_parts[1])
-
-                    if best_checkpoint is None or extracted_epochs > best_epochs:
-                        best_checkpoint = option
-                        best_epochs = extracted_epochs
-
-                # Cache the checkpoint so we only need to download once in case we
-                # reuse it later.
-                c.ckpt_path = get_cached_checkpoint(best_checkpoint)
-
-            elif c.autoresume:
-                # Don't cache the checkpoint here since last.ckpt could change if
-                # the model is trained further.
-                c.ckpt_path = str(checkpoint_dir / "last.ckpt")
-
-            else:
-                raise ValueError("autoresume is off but checkpoint already exists")
-
-            logger.info(f"found checkpoint to resume from at {c.ckpt_path}")
+        checkpoint_path = self._get_checkpoint_path(
+            checkpoint_dir, load_best=c.load_best, autoresume=c.autoresume
+        )
+        if checkpoint_path is not None:
+            logger.info(f"found checkpoint to resume from at {checkpoint_path}")
+            c.ckpt_path = checkpoint_path

        wandb_id = launcher_lib.download_wandb_id(
            c.rslp_project, c.rslp_experiment, run_id
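To summarize the new helper's contract, here is a hypothetical call pattern; `cli` stands in for an instance of the CLI subclass defined in this file, and the directory contents are invented:

from upath import UPath

ckpt_dir = UPath("/tmp/ckpts")  # assume it contains only epoch=9-step=1000.ckpt

cli._get_checkpoint_path(ckpt_dir, load_best=True)
# -> cached local path to epoch=9-step=1000.ckpt

cli._get_checkpoint_path(ckpt_dir, autoresume=True)
# -> None: last.ckpt is absent, so training proceeds from default initialization

cli._get_checkpoint_path(ckpt_dir)
# -> None as well; the ValueError fires only when last.ckpt already exists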
25 changes: 25 additions & 0 deletions tests/integration/test_lightning_cli.py
@@ -0,0 +1,25 @@
+import pathlib
+import shutil
+from typing import Any
+
+import pytest
+from upath import UPath
+
+from rslp.utils.rslearn import run_model_predict
+
+
+def test_error_if_no_checkpoint(tmp_path: pathlib.Path, monkeypatch: Any) -> None:
+    """Verify that an error is raised if --load_best=true with no checkpoint."""
+    # We need to use some config for this, so here we use the landsat_vessels one.
+    model_config_fname = "data/landsat_vessels/config_detector.yaml"
+    ds_config_fname = "data/landsat_vessels/predict_dataset_config.json"
+
+    # Copy the config.json so that the dataset is valid.
+    shutil.copyfile(ds_config_fname, tmp_path / "config.json")
+    (tmp_path / "windows" / "default").mkdir(parents=True)
+
+    # Overwrite RSLP_PREFIX to ensure the checkpoint won't exist.
+    monkeypatch.setenv("RSLP_PREFIX", str(tmp_path))
+
+    with pytest.raises(FileNotFoundError):
+        run_model_predict(model_config_fname, UPath(tmp_path))
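The new test can be run on its own with the standard pytest entrypoint:

python -m pytest tests/integration/test_lightning_cli.py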
