[FLAVA] Change ordering on contrastive loss initialization
ghstack-source-id: fb5a658cb28155d50f219186b798511fb0060bdb
Pull Request resolved: #105
ankitade committed Jun 21, 2022
1 parent 0349375 commit b7b0166
Showing 3 changed files with 19 additions and 20 deletions.
6 changes: 3 additions & 3 deletions test/models/flava/test_flava.py
@@ -79,7 +79,7 @@ def test_forward_pretraining(self):
sum(
value if value is not None else 0 for value in output.losses.values()
).item(),
-20.4199,
+21.0189,
places=4,
)

@@ -103,7 +103,7 @@ def test_forward_pretraining(self):
sum(
value if value is not None else 0 for value in output.losses.values()
).item(),
-9.3403,
+8.9271,
places=4,
)

@@ -128,7 +128,7 @@ def test_forward_pretraining(self):
sum(
value if value is not None else 0 for value in output.losses.values()
).item(),
-10.8777,
+10.1290,
places=4,
)

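The three expected sums above changed, most likely because the construction-order swap in the next file shifts which random draws each module receives under the tests' seed (the hard-coded expected values suggest the RNG is seeded). A minimal, self-contained sketch of that effect; the layers and sizes are illustrative stand-ins, not FLAVA modules:

import torch
from torch import nn


def build_a(order: str) -> nn.Linear:
    # Construct two randomly initialized layers under the same seed and
    # return the one named "a"; only the construction order differs.
    torch.manual_seed(0)
    if order == "a_first":
        a = nn.Linear(4, 4)
        nn.Linear(4, 4)
    else:
        nn.Linear(4, 4)
        a = nn.Linear(4, 4)
    return a


# Same seed, different construction order -> different weights for "a",
# hence different loss values downstream.
print(torch.allclose(build_a("a_first").weight, build_a("b_first").weight))  # False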
3 changes: 1 addition & 2 deletions torchmultimodal/models/flava/flava_model.py
@@ -185,9 +185,8 @@ def flava_model_for_pretraining(
# TODO: Add parameters for loss here
):
model = flava_model(**flava_model_kwargs)
-
-codebook = DalleVAEEncoder(image_size=codebook_image_size)
losses = FLAVAPretrainingLoss()
+codebook = DalleVAEEncoder(image_size=codebook_image_size)

flava = FLAVAForPreTraining(
model=model,
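With this swap, the pretraining losses (including the contrastive loss) are instantiated before the DALL-E VAE codebook encoder. For context, a hedged usage sketch of the factory edited here; the import path follows the file shown above, and the codebook_image_size value is illustrative rather than a documented default:

from torchmultimodal.models.flava.flava_model import flava_model_for_pretraining

# Builds the FLAVA model, then the pretraining losses, then the codebook
# encoder, and wires them into FLAVAForPreTraining as in the diff above.
flava = flava_model_for_pretraining(codebook_image_size=112)
flava.eval()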
30 changes: 15 additions & 15 deletions torchmultimodal/modules/losses/flava.py
@@ -380,6 +380,21 @@ def forward(
outputs = FLAVAPretrainingLossOutput()
pos_mask = None

+if (
+    image_sequence is not None
+    and text_sequence is not None
+    and self.contrastive_loss_weight > 0
+):
+    outputs.global_contrastive_output = self.contrastive_loss(
+        image_sequence,
+        text_sequence,
+        pos_mask,
+    )
+    outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
+    outputs.losses.global_contrastive_loss = (
+        outputs.global_contrastive_output.loss
+    )

# Check multimodal_masked_sequence to make sure this is unimodal case
# This specific case can though be backpropagated directly as MIM is independent of
# text, but that is a research question :)
@@ -461,19 +476,4 @@ def forward(
outputs.mmm_image_output.loss *= self.mmm_image_loss_weight
outputs.losses.mmm_image_loss = outputs.mmm_image_output.loss

-if (
-    image_sequence is not None
-    and text_sequence is not None
-    and self.contrastive_loss_weight > 0
-):
-    outputs.global_contrastive_output = self.contrastive_loss(
-        image_sequence,
-        text_sequence,
-        pos_mask,
-    )
-    outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
-    outputs.losses.global_contrastive_loss = (
-        outputs.global_contrastive_output.loss
-    )

return outputs
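Both hunks in this file contain the same code: the global contrastive block only moved from the end of forward to just after pos_mask is set, so it now runs before the masked unimodal/multimodal branches. The block follows the same guard-and-weight pattern as the other loss terms: compute a term only when its inputs are present and its weight is positive, scale it, and record it on the output struct. A generic, hypothetical sketch of that pattern; maybe_add_loss, LossOutputs, and their arguments are made-up names, not FLAVA APIs:

from dataclasses import dataclass, field
from typing import Callable, Dict, Optional

import torch
import torch.nn.functional as F


@dataclass
class LossOutputs:
    # Maps loss name -> weighted scalar loss tensor.
    losses: Dict[str, torch.Tensor] = field(default_factory=dict)


def maybe_add_loss(
    outputs: LossOutputs,
    name: str,
    weight: float,
    inputs: Optional[torch.Tensor],
    targets: Optional[torch.Tensor],
    compute_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> None:
    # Compute, weight, and record a term only if its inputs exist and its
    # weight is positive, mirroring the guards in the diff above.
    if inputs is not None and targets is not None and weight > 0:
        outputs.losses[name] = weight * compute_loss(inputs, targets)


outs = LossOutputs()
maybe_add_loss(
    outs,
    "global_contrastive_loss",
    weight=1.0,
    inputs=torch.randn(2, 8),
    targets=torch.randn(2, 8),
    compute_loss=F.mse_loss,
)
print(list(outs.losses))  # ['global_contrastive_loss']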
