diff --git a/instageo/model/README.md b/instageo/model/README.md
index 3432fd5..08db1a8 100644
--- a/instageo/model/README.md
+++ b/instageo/model/README.md
@@ -38,7 +38,7 @@ pip install . #run from instageo root
   - `learning_rate`: Initial learning rate.
   - `num_epochs`: Number of training epochs.
   - `batch_size`: Batch size for training and validation.
-  - `mode`: Select one of training or evaluation mode.
+  - `mode`: Select one of training, evaluation, or stats mode.
   See `configs/config.yaml` for more.
 
 2. **Dataset Preparation:** Prepare your geospatial data using the InstaGeo Chip Creator or similar and place it in the specified `root_dir`. Ensure that the csv file for each dataset has `Input` and `Label` columns corresponding to the path of the image and label relative to the `root_dir`. Additionally, ensure the data is compatible with `InstaGeoDataset`
@@ -221,5 +221,6 @@ When the saved checkpoint is evaluated on the test set, you should have results
 
 ## Customization
 
-- Modify the `bands`, `mean`, and `std` lists in `configs/config.yaml` to match your dataset's characteristics.
+- Use the `stats` mode to compute the `mean` and `std` lists for your dataset.
+- Modify the `bands`, `mean`, and `std` lists in `configs/config.yaml` to match your dataset's characteristics and improve input normalization.
 - Implement additional data augmentation strategies in `process_and_augment`.
diff --git a/instageo/model/configs/locust.yaml b/instageo/model/configs/locust.yaml
index cabd5cc..5e75880 100644
--- a/instageo/model/configs/locust.yaml
+++ b/instageo/model/configs/locust.yaml
@@ -21,13 +21,29 @@ model:
 dataloader:
   # 3*(Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2)
   bands: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
-  mean: [0.14245495, 0.13921481, 0.12434631, 0.31420089, 0.20743526, 0.12046503]
-  std: [0.04036231, 0.04186983, 0.05267646, 0.0822221, 0.06834774, 0.05294205]
+  mean:
+    [
+      623.2724609375,
+      1247.657958984375,
+      1772.24169921875,
+      2371.256103515625,
+      2862.867431640625,
+      2357.759765625,
+    ]
+  std:
+    [
+      2182.050048828125,
+      2248.420654296875,
+      2302.53515625,
+      2372.204345703125,
+      2398.52685546875,
+      2292.96435546875,
+    ]
   img_size: 224
   temporal_dim: 3
-  replace_label: null
+  replace_label: [-9999, -1]
   reduce_to_zero: False
-  no_data_value: -1
+  no_data_value: -9999
   constant_multiplier: 1.0
 
 test:
diff --git a/instageo/model/dataloader.py b/instageo/model/dataloader.py
index a0f2d06..b06c06c 100644
--- a/instageo/model/dataloader.py
+++ b/instageo/model/dataloader.py
@@ -244,14 +244,6 @@ def get_raster_data(
         data = src.read()
         if (not is_label) and bands:
             data = data[bands, ...]
-        # For some reasons, some few HLS tiles are not scaled in v2.0.
-        # In the following lines, we find and scale them
-        bands = []
-        for band in data:
-            if band.max() > 10:
-                band *= 0.0001
-            bands.append(band)
-        data = np.stack(bands, axis=0)
     return data
 
 
diff --git a/instageo/model/model.py b/instageo/model/model.py
index fe2b5a4..6a22ccf 100644
--- a/instageo/model/model.py
+++ b/instageo/model/model.py
@@ -160,7 +160,12 @@ def __init__(
         if freeze_backbone:
             for param in model.parameters():
                 param.requires_grad = False
 
-        _ = model.load_state_dict(checkpoint, strict=False)
+        filtered_checkpoint_state_dict = {
+            key[len("encoder.") :]: value
+            for key, value in checkpoint.items()
+            if key.startswith("encoder.")
+        }
+        _ = model.load_state_dict(filtered_checkpoint_state_dict)
 
         self.prithvi_100M_backbone = model
diff --git a/instageo/model/run.py b/instageo/model/run.py
index 8d2d217..821f8a2 100644
--- a/instageo/model/run.py
+++ b/instageo/model/run.py
@@ -418,6 +418,38 @@ def compute_metrics(
     }
 
 
+def compute_mean_std(data_loader: DataLoader) -> Tuple[List[float], List[float]]:
+    """Compute the mean and standard deviation of a dataset.
+
+    Args:
+        data_loader (DataLoader): PyTorch DataLoader.
+
+    Returns:
+        mean (list): List of means for each channel.
+        std (list): List of standard deviations for each channel.
+    """
+    mean = 0.0
+    var = 0.0
+    nb_samples = 0
+
+    for data, _ in data_loader:
+        # Reshape data to (B, C, T*H*W)
+        batch_samples = data.size(0)
+        data = data.view(batch_samples, data.size(1), -1)
+
+        nb_samples += batch_samples
+
+        # Accumulate per-sample channel means and variances over the batch
+        mean += data.mean(2).sum(0)
+
+        var += data.var(2, unbiased=False).sum(0)
+
+    mean /= nb_samples
+    var /= nb_samples
+    std = torch.sqrt(var)
+    return mean.tolist(), std.tolist()  # type:ignore
+
+
 @hydra.main(config_path="configs", version_base=None, config_name="config")
 def main(cfg: DictConfig) -> None:
     """Runner Entry Point.
@@ -446,6 +478,34 @@ def main(cfg: DictConfig) -> None:
     test_filepath = cfg.test_filepath
     checkpoint_path = cfg.checkpoint_path
 
+    if cfg.mode == "stats":
+        train_dataset = InstaGeoDataset(
+            filename=train_filepath,
+            input_root=root_dir,
+            preprocess_func=partial(
+                process_and_augment,
+                mean=[0] * len(MEAN),
+                std=[1] * len(STD),
+                temporal_size=TEMPORAL_SIZE,
+                im_size=IM_SIZE,
+            ),
+            bands=BANDS,
+            replace_label=cfg.dataloader.replace_label,
+            reduce_to_zero=cfg.dataloader.reduce_to_zero,
+            no_data_value=cfg.dataloader.no_data_value,
+            constant_multiplier=cfg.dataloader.constant_multiplier,
+        )
+        train_loader = create_dataloader(
+            train_dataset,
+            batch_size=batch_size,
+            shuffle=True,
+            num_workers=1,
+        )
+        mean, std = compute_mean_std(train_loader)
+        print(mean)
+        print(std)
+        exit(0)
+
     if cfg.mode == "train":
         check_required_flags(["root_dir", "train_filepath", "valid_filepath"], cfg)
         train_dataset = InstaGeoDataset(
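For reference, the aggregation performed by the new `compute_mean_std` helper can be reproduced outside the training pipeline. The sketch below is a minimal, self-contained illustration on synthetic data; the tensor shapes, value range, and the comparison against directly computed global statistics are assumptions made for the example and are not part of the patch or the InstaGeo code base.

```python
# Standalone sketch (not part of the patch): mirrors the per-channel
# mean/std aggregation used by compute_mean_std on synthetic data.
import torch
from torch.utils.data import DataLoader, TensorDataset

C, T, H, W = 6, 3, 8, 8  # channels, temporal steps, height, width (illustrative)
data = torch.rand(32, C, T, H, W) * 3000.0  # fake reflectance-like values
labels = torch.zeros(32)
loader = DataLoader(TensorDataset(data, labels), batch_size=8)

mean = torch.zeros(C)
var = torch.zeros(C)
n_samples = 0
for batch, _ in loader:
    flat = batch.view(batch.size(0), batch.size(1), -1)  # (B, C, T*H*W)
    n_samples += batch.size(0)
    mean += flat.mean(2).sum(0)  # accumulate per-sample channel means
    var += flat.var(2, unbiased=False).sum(0)  # accumulate per-sample channel variances
mean /= n_samples
std = torch.sqrt(var / n_samples)

# Direct global statistics for comparison. Because the loop averages
# within-sample variances, `std` ignores the spread of per-sample means
# and can be slightly smaller than `global_std`.
global_mean = data.permute(1, 0, 2, 3, 4).reshape(C, -1).mean(1)
global_std = data.permute(1, 0, 2, 3, 4).reshape(C, -1).std(1, unbiased=False)
print(mean.tolist(), std.tolist())
print(global_mean.tolist(), global_std.tolist())
```

The per-channel values printed by the `stats` mode are what go into the `mean` and `std` lists of the dataloader config, as with the updated values in `locust.yaml`. Averaging per-sample variances gives a lower bound on the global standard deviation, but for normalization purposes the difference is typically negligible.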