diff --git a/muse/modeling_transformer.py b/muse/modeling_transformer.py
index b03c031d..950196d1 100644
--- a/muse/modeling_transformer.py
+++ b/muse/modeling_transformer.py
@@ -1253,21 +1253,18 @@ def custom_forward(*inputs):
         logits = self.mlm_layer(hidden_states)
         if labels is not None:
-            cross_entropy_per_image = F.cross_entropy(
+            reduction = "none" if loss_weight is not None else "mean"
+            loss = F.cross_entropy(
                 logits.view(-1, self.output_size),
                 labels.view(-1),
                 ignore_index=-100,
                 label_smoothing=label_smoothing,
-                reduction="none",
+                reduction=reduction,
             )
-
-            if loss_weight is None:
-                loss = cross_entropy_per_image.mean()
-            else:
+            if loss_weight is not None:
                 loss_weight = loss_weight.view(-1)
-                loss = ((cross_entropy_per_image * loss_weight).sum(dim=-1) / loss_weight.sum(dim=-1)).mean()
-
-            return logits, loss, cross_entropy_per_image
+                loss = ((loss * loss_weight).sum(dim=-1) / loss_weight.sum(dim=-1)).mean()
+            return logits, loss

         return logits

     def _set_gradient_checkpointing(self, module, value=False):
diff --git a/muse/training_utils.py b/muse/training_utils.py
index b5b87cda..954cf127 100644
--- a/muse/training_utils.py
+++ b/muse/training_utils.py
@@ -343,7 +343,15 @@ def image_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
     return entropy_by_masked_bucket


-def cross_entropy_per_percent_masked_bucket(cross_entropy_per_image, input_ids, mask_id):
+def cross_entropy_per_percent_masked_bucket(logits, labels, input_ids, mask_id, output_size, label_smoothing):
+    cross_entropy_per_image = F.cross_entropy(
+        logits.view(-1, output_size),
+        labels.view(-1),
+        ignore_index=-100,
+        label_smoothing=label_smoothing,
+        reduction="none",
+    )
+
     total_buckets = 10
     masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

diff --git a/training/train_muse.py b/training/train_muse.py
index 8d4c69a7..fe3d5c47 100644
--- a/training/train_muse.py
+++ b/training/train_muse.py
@@ -300,6 +300,7 @@ def main():
     model_cls = MaskGitTransformer if config.model.get("architecture", "transformer") == "transformer" else MaskGiTUViT
     model = model_cls(**config.model.transformer)
     mask_id = model.config.mask_token_id
+    output_size = model.output_size

     # Create EMA
     if config.training.get("use_ema", False):
@@ -575,7 +576,7 @@ def prepare_inputs_and_labels(
                     )
                     loss = soft_target_cross_entropy(logits, labels, soft_targets)
                 else:
-                    logits, loss, cross_entropy_per_image = model(
+                    logits, loss = model(
                         input_ids=input_ids,
                         encoder_hidden_states=encoder_hidden_states,
                         labels=labels,
@@ -663,7 +664,16 @@ def prepare_inputs_and_labels(
                 and ((global_step + 1) % config.experiment.log_cross_entropy_every == 0)
                 and accelerator.is_main_process
             ):
-                log_cross_entropy(cross_entropy_per_image, input_ids, mask_id, accelerator, global_step + 1)
+                log_cross_entropy(
+                    logits,
+                    labels,
+                    input_ids,
+                    mask_id,
+                    output_size,
+                    config.training.label_smoothing,
+                    accelerator,
+                    global_step + 1,
+                )

             if (
                 ("log_token_probability_distributions_every" in config.experiment)
@@ -749,7 +759,7 @@ def validate_model(model, eval_dataloader, accelerator, global_step, prepare_inputs_and_labels):
             input_ids, encoder_hidden_states, labels, _, _, loss_weight = prepare_inputs_and_labels(
                 pixel_values, input_ids
             )
-            _, loss, _ = model(
+            _, loss = model(
                 input_ids=input_ids, encoder_hidden_states=encoder_hidden_states, labels=labels, loss_weight=loss_weight
             )
             eval_loss += loss.mean()
@@ -902,9 +912,9 @@ def log_image_entropy(logits, input_ids, mask_id, accelerator, global_step):


 @torch.no_grad()
-def log_cross_entropy(cross_entropy_per_image, input_ids, mask_id, accelerator, global_step):
+def log_cross_entropy(logits, labels, input_ids, mask_id, output_size, label_smoothing, accelerator, global_step):
     cross_entropy_per_percent_masked_bucket = muse.training_utils.cross_entropy_per_percent_masked_bucket(
-        cross_entropy_per_image, input_ids, mask_id
+        logits, labels, input_ids, mask_id, output_size, label_smoothing
     )

     cross_entropy_log = {}
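
For reference, below is a minimal standalone sketch of the two reduction paths the rewritten forward chooses between; it is not part of the diff, and all shapes, sizes, and the seed are invented for illustration. One behavioral detail worth noting: with ignore_index=-100, reduction="mean" averages over non-ignored targets only, whereas the removed reduction="none" path followed by .mean() averaged over all positions, so the unweighted loss value shifts whenever labels contain -100.

# Standalone sketch, not part of the diff above. Shapes are assumptions:
# a batch of 2 images, 8 token positions each, a vocabulary of 16 codes.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 8, 16)
labels = torch.randint(0, 16, (2, 8))
labels[:, :3] = -100  # unmasked positions are ignored, as in the training code

# Unweighted path (loss_weight is None): reduction="mean" averages the
# per-token losses over non-ignored targets only.
mean_loss = F.cross_entropy(logits.view(-1, 16), labels.view(-1), ignore_index=-100)

# Weighted path: keep per-token losses, then take a normalized weighted sum.
# Ignored positions contribute 0 to the numerator, but their weights still
# appear in the denominator, matching the arithmetic in the diff.
per_token = F.cross_entropy(
    logits.view(-1, 16), labels.view(-1), ignore_index=-100, reduction="none"
)
loss_weight = torch.rand(2, 8).view(-1)
weighted_loss = (per_token * loss_weight).sum() / loss_weight.sum()

print(f"mean reduction: {mean_loss:.4f}, weighted: {weighted_loss:.4f}")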