Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update instageo-model #14

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions instageo/model/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pip install . #run from instageo root
- `learning_rate`: Initial learning rate.
- `num_epochs`: Number of training epochs.
- `batch_size`: Batch size for training and validation.
- `mode`: Select one of training or evaluation mode.
- `mode`: Select training, evaluation or stats mode.
See `configs/config.yaml` for more.

2. **Dataset Preparation:** Prepare your geospatial data using the InstaGeo Chip Creator or similar and place it in the specified `root_dir`. Ensure that the csv file for each dataset has `Input` and `Label` columns corresponding to the path of the image and label relative to the `root_dir`. Additionally, ensure the data is compatible with `InstaGeoDataset`
Expand Down Expand Up @@ -221,5 +221,6 @@ When the saved checkpoint is evaluated on the test set, you should have results

## Customization

- Modify the `bands`, `mean`, and `std` lists in `configs/config.yaml` to match your dataset's characteristics.
- Use the `stats` mode to compute the `mean`, and `std` lists of your dataset.
- Modify the `bands`, `mean`, and `std` lists in `configs/config.yaml` to match your dataset's characteristics and improve its normalization.
- Implement additional data augmentation strategies in `process_and_augment`.
24 changes: 20 additions & 4 deletions instageo/model/configs/locust.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,29 @@ model:
dataloader:
# 3*(Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2)
bands: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
mean: [0.14245495, 0.13921481, 0.12434631, 0.31420089, 0.20743526, 0.12046503]
std: [0.04036231, 0.04186983, 0.05267646, 0.0822221, 0.06834774, 0.05294205]
mean:
[
623.2724609375,
1247.657958984375,
1772.24169921875,
2371.256103515625,
2862.867431640625,
2357.759765625,
]
std:
[
2182.050048828125,
2248.420654296875,
2302.53515625,
2372.204345703125,
2398.52685546875,
2292.96435546875,
]
img_size: 224
temporal_dim: 3
replace_label: null
replace_label: [-9999, -1]
ZeLynxy marked this conversation as resolved.
Show resolved Hide resolved
reduce_to_zero: False
no_data_value: -1
no_data_value: -9999
constant_multiplier: 1.0

test:
Expand Down
8 changes: 0 additions & 8 deletions instageo/model/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,14 +244,6 @@ def get_raster_data(
data = src.read()
if (not is_label) and bands:
data = data[bands, ...]
# For some reasons, some few HLS tiles are not scaled in v2.0.
# In the following lines, we find and scale them
bands = []
for band in data:
if band.max() > 10:
band *= 0.0001
bands.append(band)
data = np.stack(bands, axis=0)
return data


Expand Down
7 changes: 6 additions & 1 deletion instageo/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,12 @@ def __init__(
if freeze_backbone:
for param in model.parameters():
param.requires_grad = False
_ = model.load_state_dict(checkpoint, strict=False)
filtered_checkpoint_state_dict = {
key[len("encoder.") :]: value
for key, value in checkpoint.items()
if key.startswith("encoder.")
}
_ = model.load_state_dict(filtered_checkpoint_state_dict)

self.prithvi_100M_backbone = model

Expand Down
60 changes: 60 additions & 0 deletions instageo/model/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,38 @@ def compute_metrics(
}


def compute_mean_std(data_loader: DataLoader) -> Tuple[List[float], List[float]]:
"""Compute the mean and standard deviation of a dataset.

Args:
data_loader (DataLoader): PyTorch DataLoader.

Returns:
mean (list): List of means for each channel.
std (list): List of standard deviations for each channel.
"""
mean = 0.0
var = 0.0
nb_samples = 0

for data, _ in data_loader:
# Reshape data to (B, C, T*H*W)
batch_samples = data.size(0)
data = data.view(batch_samples, data.size(1), -1)

nb_samples += batch_samples

# Sum over batch, height and width
mean += data.mean(2).sum(0)

var += data.var(2, unbiased=False).sum(0)

mean /= nb_samples
var /= nb_samples
std = torch.sqrt(var)
return mean.tolist(), std.tolist() # type:ignore


@hydra.main(config_path="configs", version_base=None, config_name="config")
def main(cfg: DictConfig) -> None:
"""Runner Entry Point.
Expand Down Expand Up @@ -446,6 +478,34 @@ def main(cfg: DictConfig) -> None:
test_filepath = cfg.test_filepath
checkpoint_path = cfg.checkpoint_path

if cfg.mode == "stats":
train_dataset = InstaGeoDataset(
filename=train_filepath,
input_root=root_dir,
preprocess_func=partial(
process_and_augment,
mean=[0] * len(MEAN),
std=[1] * len(STD),
temporal_size=TEMPORAL_SIZE,
im_size=IM_SIZE,
),
bands=BANDS,
replace_label=cfg.dataloader.replace_label,
reduce_to_zero=cfg.dataloader.reduce_to_zero,
no_data_value=cfg.dataloader.no_data_value,
constant_multiplier=cfg.dataloader.constant_multiplier,
)
train_loader = create_dataloader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=1,
)
mean, std = compute_mean_std(train_loader)
print(mean)
print(std)
exit(0)

if cfg.mode == "train":
check_required_flags(["root_dir", "train_filepath", "valid_filepath"], cfg)
train_dataset = InstaGeoDataset(
Expand Down
Loading