Skip to content

Commit

Permalink
Merge pull request #14 from instadeepai/13-update-instageo-model-to-c…
Browse files Browse the repository at this point in the history
…onform-with-new-instageo-data-pipeline

fix: update instageo-model
  • Loading branch information
rym-oualha authored Jan 15, 2025
2 parents c2c8630 + 86730a7 commit c169f8c
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 15 deletions.
5 changes: 3 additions & 2 deletions instageo/model/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pip install . #run from instageo root
- `learning_rate`: Initial learning rate.
- `num_epochs`: Number of training epochs.
- `batch_size`: Batch size for training and validation.
- `mode`: Select one of training or evaluation mode.
- `mode`: Select training, evaluation or stats mode.
See `configs/config.yaml` for more.

2. **Dataset Preparation:** Prepare your geospatial data using the InstaGeo Chip Creator or similar and place it in the specified `root_dir`. Ensure that the csv file for each dataset has `Input` and `Label` columns corresponding to the path of the image and label relative to the `root_dir`. Additionally, ensure the data is compatible with `InstaGeoDataset`
Expand Down Expand Up @@ -221,5 +221,6 @@ When the saved checkpoint is evaluated on the test set, you should have results

## Customization

- Modify the `bands`, `mean`, and `std` lists in `configs/config.yaml` to match your dataset's characteristics.
- Use the `stats` mode to compute the `mean` and `std` lists of your dataset.
- Modify the `bands`, `mean`, and `std` lists in `configs/config.yaml` to match your dataset's characteristics and improve its normalization.
- Implement additional data augmentation strategies in `process_and_augment`.
24 changes: 20 additions & 4 deletions instageo/model/configs/locust.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,29 @@ model:
dataloader:
# 3*(Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2)
bands: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
mean: [0.14245495, 0.13921481, 0.12434631, 0.31420089, 0.20743526, 0.12046503]
std: [0.04036231, 0.04186983, 0.05267646, 0.0822221, 0.06834774, 0.05294205]
mean:
[
623.2724609375,
1247.657958984375,
1772.24169921875,
2371.256103515625,
2862.867431640625,
2357.759765625,
]
std:
[
2182.050048828125,
2248.420654296875,
2302.53515625,
2372.204345703125,
2398.52685546875,
2292.96435546875,
]
img_size: 224
temporal_dim: 3
replace_label: null
replace_label: [-9999, -1]
reduce_to_zero: False
no_data_value: -1
no_data_value: -9999
constant_multiplier: 1.0

test:
Expand Down
8 changes: 0 additions & 8 deletions instageo/model/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,14 +244,6 @@ def get_raster_data(
data = src.read()
if (not is_label) and bands:
data = data[bands, ...]
# For some reasons, some few HLS tiles are not scaled in v2.0.
# In the following lines, we find and scale them
bands = []
for band in data:
if band.max() > 10:
band *= 0.0001
bands.append(band)
data = np.stack(bands, axis=0)
return data


Expand Down
7 changes: 6 additions & 1 deletion instageo/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,12 @@ def __init__(
if freeze_backbone:
for param in model.parameters():
param.requires_grad = False
_ = model.load_state_dict(checkpoint, strict=False)
filtered_checkpoint_state_dict = {
key[len("encoder.") :]: value
for key, value in checkpoint.items()
if key.startswith("encoder.")
}
_ = model.load_state_dict(filtered_checkpoint_state_dict)

self.prithvi_100M_backbone = model

Expand Down
60 changes: 60 additions & 0 deletions instageo/model/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,38 @@ def compute_metrics(
}


def compute_mean_std(data_loader: DataLoader) -> Tuple[List[float], List[float]]:
    """Compute the per-channel mean and standard deviation of a dataset.

    Statistics are accumulated as running sums of values and squared values,
    so the result is the exact population mean/std over every element in the
    dataset. (Averaging per-sample variances, as a naive implementation does,
    drops the between-sample variance of the per-sample means and therefore
    underestimates the std; this accumulation also supports samples of
    differing spatial sizes.)

    Args:
        data_loader (DataLoader): PyTorch DataLoader yielding ``(data, label)``
            batches where ``data`` has shape ``(B, C, ...)``.

    Returns:
        mean (list): List of means for each channel.
        std (list): List of standard deviations for each channel.
    """
    channel_sum = torch.tensor(0.0)
    channel_sq_sum = torch.tensor(0.0)
    num_elements = 0

    for data, _ in data_loader:
        # Flatten every non-channel dimension: (B, C, T*H*W)
        data = data.view(data.size(0), data.size(1), -1)
        channel_sum = channel_sum + data.sum(dim=(0, 2))
        channel_sq_sum = channel_sq_sum + (data**2).sum(dim=(0, 2))
        num_elements += data.size(0) * data.size(2)

    mean = channel_sum / num_elements
    # Population variance E[x^2] - E[x]^2; clamp guards against tiny
    # negative values from floating-point round-off.
    var = torch.clamp(channel_sq_sum / num_elements - mean**2, min=0.0)
    std = torch.sqrt(var)
    return mean.tolist(), std.tolist()  # type: ignore


@hydra.main(config_path="configs", version_base=None, config_name="config")
def main(cfg: DictConfig) -> None:
"""Runner Entry Point.
Expand Down Expand Up @@ -446,6 +478,34 @@ def main(cfg: DictConfig) -> None:
test_filepath = cfg.test_filepath
checkpoint_path = cfg.checkpoint_path

if cfg.mode == "stats":
train_dataset = InstaGeoDataset(
filename=train_filepath,
input_root=root_dir,
preprocess_func=partial(
process_and_augment,
mean=[0] * len(MEAN),
std=[1] * len(STD),
temporal_size=TEMPORAL_SIZE,
im_size=IM_SIZE,
),
bands=BANDS,
replace_label=cfg.dataloader.replace_label,
reduce_to_zero=cfg.dataloader.reduce_to_zero,
no_data_value=cfg.dataloader.no_data_value,
constant_multiplier=cfg.dataloader.constant_multiplier,
)
train_loader = create_dataloader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=1,
)
mean, std = compute_mean_std(train_loader)
print(mean)
print(std)
exit(0)

if cfg.mode == "train":
check_required_flags(["root_dir", "train_filepath", "valid_filepath"], cfg)
train_dataset = InstaGeoDataset(
Expand Down

0 comments on commit c169f8c

Please sign in to comment.