diff --git a/test/models/flava/test_flava.py b/test/models/flava/test_flava.py
index e096b4347..349e001f2 100644
--- a/test/models/flava/test_flava.py
+++ b/test/models/flava/test_flava.py
@@ -79,7 +79,7 @@ def test_forward_pretraining(self):
             sum(
                 value if value is not None else 0 for value in output.losses.values()
             ).item(),
-            20.4199,
+            21.0189,
             places=4,
         )
 
@@ -103,7 +103,7 @@ def test_forward_pretraining(self):
             sum(
                 value if value is not None else 0 for value in output.losses.values()
             ).item(),
-            9.3403,
+            8.9271,
             places=4,
         )
 
@@ -128,7 +128,7 @@ def test_forward_pretraining(self):
             sum(
                 value if value is not None else 0 for value in output.losses.values()
             ).item(),
-            10.8777,
+            10.1290,
             places=4,
         )
 
diff --git a/torchmultimodal/models/flava/flava_model.py b/torchmultimodal/models/flava/flava_model.py
index 9c74f527b..8a2af13de 100644
--- a/torchmultimodal/models/flava/flava_model.py
+++ b/torchmultimodal/models/flava/flava_model.py
@@ -185,9 +185,8 @@ def flava_model_for_pretraining(
     # TODO: Add parameters for loss here
 ):
     model = flava_model(**flava_model_kwargs)
-
-    codebook = DalleVAEEncoder(image_size=codebook_image_size)
     losses = FLAVAPretrainingLoss()
+    codebook = DalleVAEEncoder(image_size=codebook_image_size)
 
     flava = FLAVAForPreTraining(
         model=model,
diff --git a/torchmultimodal/modules/losses/flava.py b/torchmultimodal/modules/losses/flava.py
index c26467a88..d40e9afaf 100644
--- a/torchmultimodal/modules/losses/flava.py
+++ b/torchmultimodal/modules/losses/flava.py
@@ -380,6 +380,21 @@ def forward(
         outputs = FLAVAPretrainingLossOutput()
         pos_mask = None
 
+        if (
+            image_sequence is not None
+            and text_sequence is not None
+            and self.contrastive_loss_weight > 0
+        ):
+            outputs.global_contrastive_output = self.contrastive_loss(
+                image_sequence,
+                text_sequence,
+                pos_mask,
+            )
+            outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
+            outputs.losses.global_contrastive_loss = (
+                outputs.global_contrastive_output.loss
+            )
+
         # Check multimodal_masked_sequence to make sure this is unimodal case
         # This specific case can though be backpropagated directly as MIM is independent of
         # text, but that is a research question :)
@@ -461,19 +476,4 @@ def forward(
         outputs.mmm_image_output.loss *= self.mmm_image_loss_weight
         outputs.losses.mmm_image_loss = outputs.mmm_image_output.loss
 
-        if (
-            image_sequence is not None
-            and text_sequence is not None
-            and self.contrastive_loss_weight > 0
-        ):
-            outputs.global_contrastive_output = self.contrastive_loss(
-                image_sequence,
-                text_sequence,
-                pos_mask,
-            )
-            outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
-            outputs.losses.global_contrastive_loss = (
-                outputs.global_contrastive_output.loss
-            )
-
         return outputs