[FLAVA] Move masked prediction head to flava_for_pretraining #195

Draft · wants to merge 7 commits into base: gh/ankitade/11/base
148 changes: 139 additions & 9 deletions torchmultimodal/models/flava/model.py
@@ -62,7 +62,7 @@
FLAVA_FOR_PRETRAINED_MAPPING = {
    # This will no longer load with the updated model, but keeping here just in case
    # "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
-    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm.pt",
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm_mp.pt",
}

FLAVA_MODEL_MAPPING = {
@@ -314,6 +314,50 @@ def forward(self, hidden_states: Tensor):
        return logits


class MaskedPredictionHead(nn.Module):
    def __init__(
        self,
        hidden_size: int = 768,
        vocab_size: int = 30522,
        transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
        layer_norm_eps: float = 1e-5,
        use_fp32_layer_norm: bool = True,
        ignore_index: int = -1,
        **kwargs: Any,
    ):
        super().__init__()

        self.dense = nn.Linear(hidden_size, hidden_size)
        self.transform_act_fn = transform_act_fn

        self.layer_norm: nn.LayerNorm
        if use_fp32_layer_norm:
            self.layer_norm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps)
        else:
            self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(vocab_size))

        # Need a link between the two variables so that the bias is
        # correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias
        self.ignore_index = ignore_index

    def forward(self, hidden_states: Tensor, masked_labels: Tensor) -> Tensor:
        masked_tokens = masked_labels.ne(self.ignore_index)
        sequence_output = hidden_states[masked_tokens, :]

        head_output = self.dense(sequence_output)
        head_output = self.transform_act_fn(head_output)
        head_output = self.layer_norm(head_output)
        head_output = self.decoder(head_output)
        return head_output

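For context, a minimal usage sketch of the new head (token ids and shapes below are illustrative): positions whose label equals `ignore_index` are dropped, and the remaining hidden states are projected to vocabulary logits, one row per masked position.

```python
import torch

# Assumes this PR is applied, so the head is importable from the FLAVA model module.
from torchmultimodal.models.flava.model import MaskedPredictionHead

head = MaskedPredictionHead(hidden_size=768, vocab_size=30522, ignore_index=-1)

batch_size, seq_len = 2, 16
hidden_states = torch.randn(batch_size, seq_len, 768)

# -1 marks unmasked positions; any other value is the target token id.
masked_labels = torch.full((batch_size, seq_len), -1)
masked_labels[0, 3] = 1037
masked_labels[1, 7] = 2023

logits = head(hidden_states, masked_labels)
assert logits.shape == (2, 30522)  # one row of vocab logits per masked position
```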

class FLAVAForPreTraining(nn.Module, PretrainedMixin):
    # TODOs:
    # 1. Expose logit scale
@@ -325,12 +369,20 @@ def __init__(
        image_codebook: nn.Module,
        loss: nn.Module,
        itm_head: nn.Module,
        mlm_head: nn.Module,
        mim_head: nn.Module,
        mmm_mlm_head: nn.Module,
        mmm_mim_head: nn.Module,
    ):
        super().__init__()
        self.model = model
        self.image_codebook = image_codebook
        self.loss = loss
        self.itm_head = itm_head
        self.mlm_head = mlm_head
        self.mim_head = mim_head
        self.mmm_mlm_head = mmm_mlm_head
        self.mmm_mim_head = mmm_mim_head

    def encode_image(
        self,
@@ -380,24 +432,83 @@ def forward(
        )
        multimodal_masked_sequence = flava_output.multimodal_masked.last_hidden_state
        itm_logits = None

        image_masked_sequence = flava_output.image_masked.last_hidden_state
        text_masked_sequence = flava_output.text_masked.last_hidden_state
        mlm_head_output = (
            mim_head_output
        ) = mmm_mlm_head_output = mmm_mim_head_output = None
        pos_mask = None
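        # Unimodal MIM/MLM heads: used only when no multimodal masked sequence is
        # present (e.g. image-only or text-only batches).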
        if image_masked_sequence is not None and multimodal_masked_sequence is None:
            # Remove CLS token from image_masked_sequence
            start_index = -image_labels.size(1) if image_labels is not None else 1
            mim_head_output = self.mim_head(
                image_masked_sequence[:, start_index:, :], image_labels
            )

        if text_masked_sequence is not None and multimodal_masked_sequence is None:
            start_index = -mlm_labels.size(1) if mlm_labels is not None else 1
            mlm_head_output = self.mlm_head(
                text_masked_sequence[:, start_index:, :], mlm_labels
            )

        mmm_mlm_labels = mlm_labels
        mmm_mim_labels = image_labels

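        # ITM logits are computed on the full batch; the MMM heads below only see
        # positive (matched) image-text pairs. If the batch has no positives, the
        # single-element fallback broadcasts to an all-True mask.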
        if multimodal_masked_sequence is not None:
            if itm_labels is not None:
                pos_pairs = itm_labels.ne(0)
                pos_mask = torch.where(
                    pos_pairs.any(), pos_pairs, pos_pairs.new([True])
                )
            else:
                pos_mask = torch.ones(
                    multimodal_masked_sequence.size(0),
                    device=multimodal_masked_sequence.device,
                ).bool()
            itm_logits = self.itm_head(multimodal_masked_sequence)

            multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]
            if mlm_labels is not None:
                mmm_mlm_labels = mlm_labels[pos_mask]
            if image_labels is not None:
                mmm_mim_labels = image_labels[pos_mask]

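        # MMM text head: text tokens occupy the tail of the multimodal sequence.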
        if multimodal_masked_sequence is not None:
            start_index = (
                -mmm_mlm_labels.size(1)
                if mmm_mlm_labels is not None
                else -(text_masked_sequence.size(1) - 1)
            )
            sequence_for_text = multimodal_masked_sequence[:, start_index:, :]
            mmm_mlm_head_output = self.mmm_mlm_head(sequence_for_text, mmm_mlm_labels)

        if multimodal_masked_sequence is not None:
            # Starts from 2 because of 2 CLS, one for multimodal encoder and one
            # that comes from image encoder.
            total_indices = (
                mmm_mim_labels.size(1)
                if mmm_mim_labels is not None
                else (image_masked_sequence.size(1) - 1)
            )
            sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
            mmm_mim_head_output = self.mmm_mim_head(sequence_for_image, mmm_mim_labels)

        return self.loss(
            image_sequence=flava_output.image.last_hidden_state,
            text_sequence=flava_output.text.last_hidden_state,
            image_masked_sequence=flava_output.image_masked.last_hidden_state,
            text_masked_sequence=flava_output.text_masked.last_hidden_state,
            multimodal_sequence=flava_output.multimodal.last_hidden_state
            if not skip_unmasked_mm_encoder
            else None,
            multimodal_masked_sequence=flava_output.multimodal_masked.last_hidden_state,
            pos_mask=pos_mask,
            itm_labels=itm_labels,
            mim_labels=image_labels,
            mlm_labels=mlm_labels,
            mmm_mlm_labels=mmm_mlm_labels,
            mmm_mim_labels=mmm_mim_labels,
            projected_image_embeddings=flava_output.projected_image_embeddings,
            projected_text_embeddings=flava_output.projected_text_embeddings,
            itm_logits=itm_logits,
            mlm_head_output=mlm_head_output,
            mim_head_output=mim_head_output,
            mmm_mlm_head_output=mmm_mlm_head_output,
            mmm_mim_head_output=mmm_mim_head_output,
        )

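For intuition, a small sketch (hypothetical sizes) of the slicing used above. Per the comments in the diff, the multimodal masked sequence is assumed to be laid out as [multimodal CLS, image CLS, image tokens, text tokens], so the MMM image head reads indices 2 through 2 + num_image_tokens and the MMM text head reads the trailing text positions.

```python
import torch

batch, num_image_tokens, text_len, hidden = 2, 196, 77, 768
# Assumed layout: [CLS_multimodal, CLS_image, <image tokens>, <text tokens>]
multimodal_masked_sequence = torch.randn(batch, 2 + num_image_tokens + text_len, hidden)

sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + num_image_tokens, :]
sequence_for_text = multimodal_masked_sequence[:, -text_len:, :]

assert sequence_for_image.shape == (batch, num_image_tokens, hidden)
assert sequence_for_text.shape == (batch, text_len, hidden)
```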

@@ -548,17 +659,36 @@ def flava_model(
def flava_model_for_pretraining(
    codebook_image_size: int = 112,
    pretrained_model_key: Optional[str] = None,
    image_vocab_size: int = 8192,
    **flava_model_kwargs: Any,
    # TODO: Add parameters for loss here
) -> FLAVAForPreTraining:
    model = flava_model(**flava_model_kwargs)
    hidden_size = flava_model_kwargs.get("hidden_size") or 768
    text_vocab_size = flava_model_kwargs.get("vocab_size") or 30522
    itm_head = ITMHead(hidden_size)
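    # One MaskedPredictionHead per objective: unimodal MLM/MIM plus their
    # multimodal (MMM) counterparts; text and image heads differ only in vocab size.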
    mlm_head = MaskedPredictionHead(hidden_size=hidden_size, vocab_size=text_vocab_size)
    mim_head = MaskedPredictionHead(
        hidden_size=hidden_size, vocab_size=image_vocab_size
    )
    mmm_mlm_head = MaskedPredictionHead(
        hidden_size=hidden_size, vocab_size=text_vocab_size
    )
    mmm_mim_head = MaskedPredictionHead(
        hidden_size=hidden_size, vocab_size=image_vocab_size
    )
    losses = FLAVAPretrainingLoss()
    codebook = DalleVAEEncoder(image_size=codebook_image_size)

    flava = FLAVAForPreTraining(
-        model=model, image_codebook=codebook, loss=losses, itm_head=itm_head
+        model=model,
+        image_codebook=codebook,
+        loss=losses,
+        itm_head=itm_head,
+        mlm_head=mlm_head,
+        mim_head=mim_head,
+        mmm_mlm_head=mmm_mlm_head,
+        mmm_mim_head=mmm_mim_head,
    )

    if pretrained_model_key is not None:
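The rest of the builder (loading of `pretrained_model_key`) is collapsed in the diff above. For context, a hedged usage sketch of the updated factory, using the module path from the file header and the defaults shown in this diff:

```python
import torch

from torchmultimodal.models.flava.model import flava_model_for_pretraining

# Build FLAVA plus its ITM and masked-prediction heads; pretrained_model_key
# defaults to None, so no pretrained FLAVA weights are fetched here.
flava = flava_model_for_pretraining()

# Per this PR, the masked-prediction heads live on FLAVAForPreTraining itself.
assert isinstance(flava.mlm_head, torch.nn.Module)
assert isinstance(flava.mmm_mim_head, torch.nn.Module)

# To load the released checkpoint referenced in FLAVA_FOR_PRETRAINED_MAPPING:
# flava = flava_model_for_pretraining(pretrained_model_key="flava_full")
```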