🗃️ Increase input channels from 12 to 13
The datacube has 13 channels, namely 10 from Sentinel-2's 10m and 20m resolution bands, 2 from Sentinel-1's VV and VH polarizations, and 1 from the Copernicus DEM.
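For reference, here is a minimal sketch of how those 13 channels could be enumerated. The band names and their ordering are illustrative assumptions (the Sentinel-2 names are that mission's standard 10m/20m bands); the commit itself only changes the channel count:

```python
# Hypothetical band layout for the 13-channel datacube; the names and
# ordering below are assumptions for illustration, not set by this commit.
SENTINEL2_BANDS = [  # the ten 10m and 20m resolution optical bands
    "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12",
]
SENTINEL1_BANDS = ["VV", "VH"]  # the two radar backscatter polarizations
DEM_BAND = ["DEM"]  # one elevation channel from the Copernicus DEM

DATACUBE_BANDS = SENTINEL2_BANDS + SENTINEL1_BANDS + DEM_BAND
assert len(DATACUBE_BANDS) == 13  # 10 + 2 + 1
```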
weiji14 committed Nov 21, 2023
1 parent c7b8e66 commit 2ce108a
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/datamodule.py
@@ -9,7 +9,7 @@
 # %%
 class RandomDataset(torch.utils.data.Dataset):
     """
-    Torch Dataset that returns tensors of size (12, 256, 256) with random
+    Torch Dataset that returns tensors of size (13, 256, 256) with random
     values.
     """

@@ -20,7 +20,7 @@ def __len__(self):
         return 2048

     def __getitem__(self, idx: int):
-        return torch.randn(12, 256, 256)
+        return torch.randn(13, 256, 256)


 class BaseDataModule(L.LightningDataModule):
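To sanity-check the new shape end to end, the dataset can be wrapped in a plain DataLoader. A minimal sketch, assuming RandomDataset takes no required constructor arguments and reusing the batch size of 32 asserted in src/model.py below:

```python
import torch

from src.datamodule import RandomDataset  # the Dataset changed above

# Default collate stacks 32 samples of (13, 256, 256) into one tensor.
dataloader = torch.utils.data.DataLoader(RandomDataset(), batch_size=32)

batch = next(iter(dataloader))
assert batch.shape == torch.Size([32, 13, 256, 256])  # BCHW
```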
8 changes: 4 additions & 4 deletions src/model.py
@@ -58,7 +58,7 @@ def __init__(self, lr: float = 0.001, mask_ratio: float = 0.75):
             layer_norm_eps=1e-12,
             image_size=256,  # default was 224
             patch_size=32,  # default was 16
-            num_channels=12,  # default was 3
+            num_channels=13,  # default was 3
             qkv_bias=True,
             decoder_num_attention_heads=16,
             decoder_hidden_size=512,
@@ -90,7 +90,7 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
         - https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/vit_mae/modeling_vit_mae.py#L948-L1010
         """
         x: torch.Tensor = batch
-        # x: torch.Tensor = torch.randn(32, 12, 256, 256)  # BCHW
+        # x: torch.Tensor = torch.randn(32, 13, 256, 256)  # BCHW

         # Forward encoder
         outputs_encoder: dict = self(x)
@@ -101,7 +101,7 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
             ids_restore=outputs_encoder.ids_restore,
         )
         # output shape (batch_size, num_patches, patch_size*patch_size*num_channels)
-        assert outputs_decoder.logits.shape == torch.Size([32, 64, 12288])
+        assert outputs_decoder.logits.shape == torch.Size([32, 64, 13312])

         # Log training loss and metrics
         loss: torch.Tensor = self.vit.forward_loss(
@@ -113,7 +113,7 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
             on_step=True,
             on_epoch=True,
             prog_bar=True,
-            logger=True,
+            # logger=True,
         )

         return loss
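The updated assertion value follows from the patchify arithmetic spelled out in the comment above (batch_size, num_patches, patch_size*patch_size*num_channels). A quick worked check using the config values from this diff:

```python
batch_size = 32
image_size, patch_size, num_channels = 256, 32, 13

# 256x256 images cut into 32x32 patches -> an 8x8 grid of 64 patches.
num_patches = (image_size // patch_size) ** 2  # 64

# Each flattened patch holds 32 * 32 * 13 = 13312 values
# (it was 32 * 32 * 12 = 12288 with 12 channels).
patch_dim = patch_size * patch_size * num_channels

assert (batch_size, num_patches, patch_dim) == (32, 64, 13312)
```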
