Skip to content

Commit

Permalink
Use rslearn functions to create window and layer paths
Browse files Browse the repository at this point in the history
  • Loading branch information
favyen2 committed Feb 6, 2025
1 parent 2620713 commit 2b6c0ed
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 35 deletions.
1 change: 1 addition & 0 deletions rslp/landsat_vessels/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

# Landsat config
LANDSAT_LAYER_NAME = "landsat"
OUTPUT_LAYER_NAME = "output"
LANDSAT_RESOLUTION = 15
LANDSAT_SOURCE = "landsat"

Expand Down
79 changes: 44 additions & 35 deletions rslp/landsat_vessels/predict_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
import numpy as np
import rasterio
import rasterio.features
import shapely
from PIL import Image
from rslearn.const import WGS84_PROJECTION
from rslearn.data_sources import Item, data_source_from_config
from rslearn.data_sources.aws_landsat import LandsatOliTirs
from rslearn.dataset import Dataset, Window, WindowLayerData
from rslearn.utils import Projection
from rslearn.utils.get_utm_ups_crs import get_utm_ups_projection
from rslearn.utils.raster_format import GeotiffRasterFormat
from rslearn.utils.vector_format import GeojsonVectorFormat
from typing_extensions import TypedDict
from upath import UPath

Expand All @@ -31,6 +32,7 @@
LANDSAT_RESOLUTION,
LANDSAT_SOURCE,
LOCAL_FILES_DATASET_CONFIG,
OUTPUT_LAYER_NAME,
)
from rslp.log_utils import get_logger
from rslp.utils.filter import NearInfraFilter
Expand Down Expand Up @@ -82,11 +84,12 @@ def get_vessel_detections(
"""
# Create a window for applying detector.
group = "default"
window_path = ds_path / "windows" / group / "default"
window_name = "default"
window_path = Window.get_window_root(ds_path, group, window_name)
window = Window(
path=window_path,
group=group,
name="default",
name=window_name,
projection=projection,
bounds=bounds,
time_range=time_range,
Expand All @@ -111,27 +114,27 @@ def get_vessel_detections(
),
)
materialize_dataset(ds_path, materialize_pipeline_args)
assert (window_path / "layers" / LANDSAT_LAYER_NAME / "B8" / "geotiff.tif").exists()

# Sanity check that the layer is completed.
if not window.is_layer_completed(LANDSAT_LAYER_NAME):
raise ValueError("landsat layer did not get materialized")

# Run object detector.
run_model_predict(DETECT_MODEL_CONFIG, ds_path)

# Read the detections.
output_fname = window_path / "layers" / "output" / "data.geojson"
layer_dir = window.get_layer_dir(OUTPUT_LAYER_NAME)
features = GeojsonVectorFormat().decode_vector(layer_dir)
detections: list[VesselDetection] = []
with output_fname.open() as f:
feature_collection = json.load(f)
for feature in feature_collection["features"]:
shp = shapely.geometry.shape(feature["geometry"])
col = int(shp.centroid.x)
row = int(shp.centroid.y)
for feature in features:
geometry = feature.geometry
score = feature["properties"]["score"]

detection = VesselDetection(
source=LANDSAT_SOURCE,
col=col,
row=row,
projection=projection,
col=int(geometry.shp.centroid.x),
row=int(geometry.shp.centroid.y),
projection=geometry.projection,
score=score,
)
if item:
Expand Down Expand Up @@ -167,11 +170,10 @@ def run_classifier(

# Create windows for applying classifier.
group = "classify_predict"
window_paths: list[UPath] = []
windows: list[Window] = []
for detection in detections:
window_name = f"{detection.col}_{detection.row}"
window_path = ds_path / "windows" / group / window_name
detection.metadata["crop_window_dir"] = window_path
bounds = [
detection.col - CLASSIFY_WINDOW_SIZE // 2,
detection.row - CLASSIFY_WINDOW_SIZE // 2,
Expand All @@ -187,7 +189,8 @@ def run_classifier(
time_range=time_range,
)
window.save()
window_paths.append(window_path)
windows.append(window)
detection.metadata["crop_window"] = window

if item:
layer_data = WindowLayerData(LANDSAT_LAYER_NAME, [[item.serialize()]])
Expand All @@ -207,21 +210,20 @@ def run_classifier(
)
materialize_dataset(ds_path, materialize_pipeline_args)

for window_path in window_paths:
assert (
window_path / "layers" / LANDSAT_LAYER_NAME / "B8" / "geotiff.tif"
).exists()
# Verify that no window is unmaterialized.
for window in windows:
if not window.is_layer_completed(LANDSAT_LAYER_NAME):
raise ValueError(f"window {window.name} does not have materialized Landsat")

# Run classification model.
run_model_predict(CLASSIFY_MODEL_CONFIG, ds_path, groups=[group])

# Read the results.
good_detections = []
for detection, window_path in zip(detections, window_paths):
output_fname = window_path / "layers" / "output" / "data.geojson"
with output_fname.open() as f:
feature_collection = json.load(f)
category = feature_collection["features"][0]["properties"]["label"]
for detection, window in zip(detections, windows):
layer_dir = window.get_layer_dir(OUTPUT_LAYER_NAME)
features = GeojsonVectorFormat().decode_vector(layer_dir, window.bounds)
category = features[0].properties["label"]
if category == "correct":
good_detections.append(detection)

Expand Down Expand Up @@ -413,6 +415,7 @@ def predict_pipeline(
json_data = []
geojson_features = []
near_infra_filter = NearInfraFilter(infra_distance_threshold=INFRA_THRESHOLD_KM)
raster_format = GeotiffRasterFormat()
infra_detections = 0
for idx, detection in enumerate(detections):
# Apply near infra filter (True -> filter out, False -> keep)
Expand All @@ -427,16 +430,22 @@ def predict_pipeline(
# - rgb.png: true color with pan-sharpening. The RGB is from B4, B3, and B2
# respectively while B8 is used for pan-sharpening.
images = {}
crop_window_dir = detection.metadata["crop_window_dir"]
if crop_window_dir is None:
raise ValueError("Crop window directory is None")
crop_window: Window = detection.metadata["crop_window"]
if crop_window is None:
raise ValueError("Crop window is None")
for band in ["B2", "B3", "B4", "B8"]:
image_fname = (
crop_window_dir / "layers" / LANDSAT_LAYER_NAME / band / "geotiff.tif"
)
with image_fname.open("rb") as f:
with rasterio.open(f) as src:
images[band] = src.read(1)
raster_dir = crop_window.get_raster_dir(LANDSAT_LAYER_NAME, [band])

# Different bands are in different resolutions so get the bounds from the
# raster since anyway we just want to read the whole raster.
band_bounds = raster_format.get_raster_bounds(raster_dir)
image = raster_format.decode_raster(raster_dir, band_bounds)
if image.shape[0] != 1:
raise ValueError(
f"expected single-band image for {band} but got {image.shape[0]} bands"
)

images[band] = image[0, :, :]

# Apply simple pan-sharpening for the RGB.
# This is just linearly scaling RGB bands to add up to B8, which is captured at
Expand Down

0 comments on commit 2b6c0ed

Please sign in to comment.