Fix prithvi #398

Merged
merged 5 commits on Feb 4, 2025
2 changes: 1 addition & 1 deletion examples/confs/sen1floods11_vit_local_ckpt.yaml
@@ -86,7 +86,7 @@ model:
model_args:
backbone: prithvi_eo_v2_300
backbone_pretrained: false
backbone_ckpt_path: examples/Prithvi_100M.pt
backbone_ckpt_path: examples/Prithvi_EO_V2_300M.pt
backbone_drop_path: 0.1
backbone_bands:
- BLUE
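The config change above points `backbone_ckpt_path` at a checkpoint that actually matches the `prithvi_eo_v2_300` backbone. Below is a minimal sketch for sanity-checking such a local checkpoint before referencing it in the YAML; the path comes from the config above, and the key inspection is purely illustrative:

```python
import torch

# Path taken from the YAML above; adjust to wherever the checkpoint actually lives.
ckpt_path = "examples/Prithvi_EO_V2_300M.pt"

# weights_only=True mirrors how the backbone factory loads checkpoints in this PR.
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
print(f"{len(state_dict)} tensors; sample keys: {list(state_dict)[:5]}")
```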
64 changes: 36 additions & 28 deletions terratorch/models/backbones/prithvi_mae.py
@@ -90,11 +90,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):


def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
""" This is the torch version of *get_1d_sincos_pos_embed_from_grid()*. However,
it was modified to cast omega values to pos.dtype which must be float (and not int as in
regular positional embeddings). This was required in order to allow for native FSDP mixed
precision support: modify omega to appropriate dtype (pos carries the correct float dtype),
instead of manually forcing float32.
""" Modified torch version of *get_1d_sincos_pos_embed_from_grid()*.

embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,) - must be float dtype!
@@ -145,7 +141,7 @@ def __init__(
self.input_size = input_size
self.patch_size = patch_size
self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
assert self.grid_size >= [1,1,1], "Patch size is bigger than input size."
assert self.grid_size >= [1, 1, 1], "Patch size is bigger than input size."
self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
self.flatten = flatten

@@ -251,7 +247,6 @@ def __init__(self,
self.in_chans = in_chans
self.num_frames = num_frames
self.embed_dim = embed_dim
self.out_channels = [embed_dim] * depth
self.img_size = to_2tuple(img_size)
if isinstance(patch_size, int):
patch_size = (1, patch_size, patch_size)
@@ -263,6 +258,7 @@
in_chans=in_chans,
embed_dim=embed_dim,
)
self.out_channels = [embed_dim * self.patch_embed.grid_size[0]] * depth

# Optional temporal and location embedding
coords_encoding = coords_encoding or []
@@ -336,24 +332,30 @@ def random_masking(self, sequence, mask_ratio, noise=None):

return sequence_unmasked, mask, ids_restore

def interpolate_pos_encoding(self, x, t, w, h):
def interpolate_pos_encoding(self, t, w, h):
"""
Adapted from:
- transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding,
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194
"""
if x.shape[1] == self.pos_embed.shape[1] and w == h:
# No interpolation needed
return self.pos_embed

class_pos_embed = self.pos_embed[:, :1]
patch_pos_embed = self.pos_embed[:, 1:]
t_patches = t // self.patch_embed.patch_size[0]
w_patches = w // self.patch_embed.patch_size[1]
h_patches = h // self.patch_embed.patch_size[2]
if [t_patches, w_patches, h_patches] == self.patch_embed.grid_size:
# No interpolation needed
return self.pos_embed
if t_patches != self.patch_embed.grid_size[0]:
# Re-compute pos embedding to handle changed num_frames
grid_size = (t_patches, *self.patch_embed.grid_size[1:])
pos_embed = get_3d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size, add_cls_token=True)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
else:
grid_size = self.patch_embed.grid_size
pos_embed = self.pos_embed

class_pos_embed, patch_pos_embed = pos_embed[:, :1], pos_embed[:, 1:]

n_sqrt = int((patch_pos_embed.shape[1] / t_patches) ** 0.5)
patch_pos_embed = patch_pos_embed.reshape(t_patches, n_sqrt, n_sqrt, self.embed_dim).permute(0, 3, 1, 2)
patch_pos_embed = patch_pos_embed.reshape(*grid_size, self.embed_dim).permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
@@ -378,7 +380,7 @@ def forward(
# embed patches
x = self.patch_embed(x)

pos_embed = self.interpolate_pos_encoding(x, t, h, w)
pos_embed = self.interpolate_pos_encoding(t, h, w)
# add pos embed w/o cls token
x = x + pos_embed[:, 1:, :]

@@ -419,7 +421,7 @@ def forward_features(
# embed patches
x = self.patch_embed(x)

pos_embed = self.interpolate_pos_encoding(x, t, h, w)
pos_embed = self.interpolate_pos_encoding(t, h, w)
# add pos embed w/o cls token
x = x + pos_embed[:, 1:, :]

@@ -526,24 +528,30 @@ def initialize_weights(self):
torch.nn.init.normal_(self.mask_token, std=0.02)
self.apply(_init_weights)

def interpolate_pos_encoding(self, x, t, w, h):
def interpolate_pos_encoding(self, t, w, h):
"""
Adapted from:
- transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding,
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194
"""
if x.shape[1] == self.decoder_pos_embed.shape[1] and w == h:
# No interpolation needed
return self.decoder_pos_embed

class_pos_embed = self.decoder_pos_embed[:, :1]
patch_pos_embed = self.decoder_pos_embed[:, 1:]
t_patches = t // self.patch_size[0]
w_patches = w // self.patch_size[1]
h_patches = h // self.patch_size[2]
if [t_patches, w_patches, h_patches] == self.grid_size:
# No interpolation needed
return self.pos_embed
if t_patches != self.grid_size[0]:
# Re-compute pos embedding to handle changed num_frames
grid_size = (t_patches, *self.grid_size[1:])
decoder_pos_embed = get_3d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], grid_size, add_cls_token=True)
decoder_pos_embed = torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
else:
grid_size = self.grid_size
decoder_pos_embed = self.decoder_pos_embed

class_pos_embed, patch_pos_embed = decoder_pos_embed[:, :1], decoder_pos_embed[:, 1:]

n_sqrt = int((patch_pos_embed.shape[1] / t_patches) ** 0.5)
patch_pos_embed = patch_pos_embed.reshape(t_patches, n_sqrt, n_sqrt, self.decoder_embed_dim).permute(0, 3, 1, 2)
patch_pos_embed = patch_pos_embed.reshape(*grid_size, self.decoder_embed_dim).permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
@@ -574,7 +582,7 @@ def forward(

# add pos embed
t, h, w = input_size[-3:]
decoder_pos_embed = self.interpolate_pos_encoding(x, t, w, h)
decoder_pos_embed = self.interpolate_pos_encoding(t, w, h)
cls_token = cls_token + decoder_pos_embed[:, :1, :]
x = x + decoder_pos_embed[:, 1:, :]

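For context on the positional-embedding changes above: `interpolate_pos_encoding` now recomputes the 3D sin/cos table when the number of temporal patches changes and skips interpolation whenever the patch grid already matches. The sketch below (not the library code itself) shows the 1D sin/cos building block the docstring refers to, with `omega` cast to `pos.dtype` instead of being forced to float32, which is what allows native FSDP mixed precision:

```python
import torch

def sincos_1d(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
    """1D sin/cos positional embedding; pos must be a float tensor of shape (M,)."""
    assert embed_dim % 2 == 0, "embed_dim must be even"
    # omega inherits pos.dtype rather than being hard-coded to float32,
    # so the embedding follows the autocast/FSDP mixed-precision dtype.
    omega = torch.arange(embed_dim // 2, dtype=pos.dtype, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega                       # (D/2,)
    out = pos.reshape(-1)[:, None] * omega[None, :]  # (M, D/2) outer product
    return torch.cat([torch.sin(out), torch.cos(out)], dim=1)  # (M, D)

# Example: 8 float positions embedded into 16 dimensions.
print(sincos_1d(16, torch.arange(8, dtype=torch.float32)).shape)  # torch.Size([8, 16])
```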
53 changes: 30 additions & 23 deletions terratorch/models/backbones/prithvi_vit.py
@@ -207,30 +207,39 @@ def _create_prithvi(
prithvi_model_class = PrithviMAE
checkpoint_filter_wrapper_fn = checkpoint_filter_fn_mae

if pretrained:
assert variant in pretrained_weights, (f"No pre-trained model found for variant {variant} "
f"(pretrained models: {pretrained_weights.keys()})")

model = prithvi_model_class(**model_args)

if ckpt_path is not None:
# Load model from checkpoint
state_dict = torch.load(ckpt_path, map_location="cpu")
state_dict = checkpoint_filter_wrapper_fn(state_dict, model, pretrained_bands, model_bands)
model.load_state_dict(state_dict, strict=False)
elif pretrained:
try:
# Download config.json to count model downloads
_ = hf_hub_download(repo_id=pretrained_weights[variant]["hf_hub_id"], filename="config.json")
# Load model from Hugging Face
pretrained_path = hf_hub_download(repo_id=pretrained_weights[variant]["hf_hub_id"],
filename=pretrained_weights[variant]["hf_hub_filename"])
state_dict = torch.load(pretrained_path, map_location="cpu")
if pretrained:
if ckpt_path is not None:
# Load model from checkpoint
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
state_dict = checkpoint_filter_wrapper_fn(state_dict, model, pretrained_bands, model_bands)
model.load_state_dict(state_dict, strict=True)
except RuntimeError as e:
logger.error(f"Failed to load the pre-trained weights for {variant}.")
raise e
loaded_keys = model.load_state_dict(state_dict, strict=False)
if loaded_keys.missing_keys:
logger.warning(f"Missing keys in ckpt_path {ckpt_path}: {loaded_keys.missing_keys}")
if loaded_keys.unexpected_keys:
logger.warning(f"Missing keys in ckpt_path {ckpt_path}: {loaded_keys.missing_keys}")
else:
assert variant in pretrained_weights, (f"No pre-trained model found for variant {variant} "
f"(pretrained models: {pretrained_weights.keys()})")

try:
# Download config.json to count model downloads
_ = hf_hub_download(repo_id=pretrained_weights[variant]["hf_hub_id"], filename="config.json")
# Load model from Hugging Face
pretrained_path = hf_hub_download(repo_id=pretrained_weights[variant]["hf_hub_id"],
filename=pretrained_weights[variant]["hf_hub_filename"])
state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True)
state_dict = checkpoint_filter_wrapper_fn(state_dict, model, pretrained_bands, model_bands)
model.load_state_dict(state_dict, strict=True)
except RuntimeError as e:
logger.error(f"Failed to load the pre-trained weights for {variant}.")
raise e
elif ckpt_path is not None:
logger.warning(f"ckpt_path is provided but pretrained is set to False, ignoring ckpt_path {ckpt_path}.")

model.model_bands = model_bands
model.pretrained_bands = pretrained_bands

assert encoder_only or "out_indices" not in kwargs, "out_indices provided for a MAE model."
if encoder_only:
@@ -243,8 +252,6 @@ def forward_filter_indices(*args, **kwargs):

model.forward = forward_filter_indices
model.out_indices = out_indices
model.model_bands = model_bands
model.pretrained_bands = pretrained_bands

return model

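To summarize the checkpoint-loading precedence introduced in `_create_prithvi`: a local `ckpt_path` now takes priority when `pretrained=True`, the Hugging Face download remains the fallback, and a `ckpt_path` passed with `pretrained=False` is ignored with a warning. A hedged usage sketch follows; the factory name and import path are assumptions based on the `prithvi_eo_v2_300` backbone name in the YAML config, not verified against the public API:

```python
# Hypothetical factory/import; shown only to illustrate the precedence rules above.
from terratorch.models.backbones.prithvi_vit import prithvi_eo_v2_300

# 1) pretrained=True with a local ckpt_path: the local file wins and is loaded with
#    strict=False; missing/unexpected keys are logged as warnings.
model = prithvi_eo_v2_300(pretrained=True, ckpt_path="examples/Prithvi_EO_V2_300M.pt")

# 2) pretrained=True without ckpt_path: weights for the variant are fetched from the
#    Hugging Face Hub and loaded with strict=True.
model = prithvi_eo_v2_300(pretrained=True)

# 3) pretrained=False with a ckpt_path: the path is ignored with a warning and the
#    model keeps its random initialization.
model = prithvi_eo_v2_300(pretrained=False, ckpt_path="examples/Prithvi_EO_V2_300M.pt")
```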