Commit

misc fixes
williamberman committed Jul 12, 2023
1 parent e90baeb commit 8fdc934
Showing 3 changed files with 30 additions and 15 deletions.
muse/modeling_transformer.py: 15 changes (6 additions, 9 deletions)
@@ -1253,21 +1253,18 @@ def custom_forward(*inputs):
         logits = self.mlm_layer(hidden_states)
 
         if labels is not None:
-            cross_entropy_per_image = F.cross_entropy(
+            reduction = "none" if loss_weight is not None else "mean"
+            loss = F.cross_entropy(
                 logits.view(-1, self.output_size),
                 labels.view(-1),
                 ignore_index=-100,
                 label_smoothing=label_smoothing,
-                reduction="none",
+                reduction=reduction,
             )
-
-            if loss_weight is None:
-                loss = cross_entropy_per_image.mean()
-            else:
+            if loss_weight is not None:
                 loss_weight = loss_weight.view(-1)
-                loss = ((cross_entropy_per_image * loss_weight).sum(dim=-1) / loss_weight.sum(dim=-1)).mean()
-
-            return logits, loss, cross_entropy_per_image
+                loss = ((loss * loss_weight).sum(dim=-1) / loss_weight.sum(dim=-1)).mean()
+            return logits, loss
         return logits
 
     def _set_gradient_checkpointing(self, module, value=False):
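With this change the model no longer returns cross_entropy_per_image; the per-bucket logging recomputes it from logits instead (see muse/training_utils.py below). For orientation, here is a minimal standalone sketch of the new loss path, using dummy tensors (the batch size, token count, and vocab_size are illustrative assumptions, not values from the repo):

import torch
import torch.nn.functional as F

vocab_size = 8                                 # stands in for self.output_size
logits = torch.randn(2, 4, vocab_size)         # (batch, tokens, vocab)
labels = torch.randint(0, vocab_size, (2, 4))
loss_weight = torch.rand(2, 4)                 # per-token weights, e.g. from the mask schedule

# Mirrors the new branch: keep per-token losses only when weights are supplied,
# otherwise fall back to a plain mean reduction.
reduction = "none" if loss_weight is not None else "mean"
loss = F.cross_entropy(
    logits.view(-1, vocab_size),
    labels.view(-1),
    ignore_index=-100,
    label_smoothing=0.0,
    reduction=reduction,
)
if loss_weight is not None:
    loss_weight = loss_weight.view(-1)
    loss = ((loss * loss_weight).sum(dim=-1) / loss_weight.sum(dim=-1)).mean()
print(loss)  # scalar weighted mean over all tokens in the batch

Since the per-token vector is flattened to 1-D, the sum(dim=-1) / sum(dim=-1) step reduces to a single weighted mean over every token in the batch, and the trailing .mean() is a no-op on the resulting scalar.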
muse/training_utils.py: 10 changes (9 additions, 1 deletion)
@@ -343,7 +343,15 @@ def image_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
     return entropy_by_masked_bucket
 
 
-def cross_entropy_per_percent_masked_bucket(cross_entropy_per_image, input_ids, mask_id):
+def cross_entropy_per_percent_masked_bucket(logits, labels, input_ids, mask_id, output_size, label_smoothing):
+    cross_entropy_per_image = F.cross_entropy(
+        logits.view(-1, output_size),
+        labels.view(-1),
+        ignore_index=-100,
+        label_smoothing=label_smoothing,
+        reduction="none",
+    )
+
     total_buckets = 10
     masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)
 
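The rest of cross_entropy_per_percent_masked_bucket is collapsed above. Purely for intuition, here is a small hypothetical helper (not code from the diff) showing one way a per-image metric can be averaged within masking-percentage buckets; the actual bucketing is done by input_ids_to_masked_buckets and the remainder of this function:

import torch

def average_per_bucket(values, bucket_ids, total_buckets=10):
    # Hypothetical helper: mean of `values` within each bucket index.
    # `values` and `bucket_ids` are 1-D tensors of the same length.
    averages = {}
    for bucket in range(total_buckets):
        selected = values[bucket_ids == bucket]
        if selected.numel() > 0:
            averages[bucket] = selected.mean().item()
    return averages

# Dummy data: 6 images, bucket index = rough fraction of tokens masked.
values = torch.tensor([2.1, 1.8, 0.9, 1.2, 3.0, 2.4])
bucket_ids = torch.tensor([0, 0, 3, 3, 9, 9])
print(average_per_bucket(values, bucket_ids))  # roughly {0: 1.95, 3: 1.05, 9: 2.7}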
training/train_muse.py: 20 changes (15 additions, 5 deletions)
@@ -300,6 +300,7 @@ def main():
     model_cls = MaskGitTransformer if config.model.get("architecture", "transformer") == "transformer" else MaskGiTUViT
     model = model_cls(**config.model.transformer)
     mask_id = model.config.mask_token_id
+    output_size = model.output_size
 
     # Create EMA
     if config.training.get("use_ema", False):
@@ -575,7 +576,7 @@ def prepare_inputs_and_labels(
                 )
                 loss = soft_target_cross_entropy(logits, labels, soft_targets)
             else:
-                logits, loss, cross_entropy_per_image = model(
+                logits, loss = model(
                     input_ids=input_ids,
                     encoder_hidden_states=encoder_hidden_states,
                     labels=labels,
@@ -663,7 +664,16 @@ def prepare_inputs_and_labels(
                 and ((global_step + 1) % config.experiment.log_cross_entropy_every == 0)
                 and accelerator.is_main_process
             ):
-                log_cross_entropy(cross_entropy_per_image, input_ids, mask_id, accelerator, global_step + 1)
+                log_cross_entropy(
+                    logits,
+                    labels,
+                    input_ids,
+                    mask_id,
+                    output_size,
+                    config.training.label_smoothing,
+                    accelerator,
+                    global_step + 1,
+                )
 
             if (
                 ("log_token_probability_distributions_every" in config.experiment)
@@ -749,7 +759,7 @@ def validate_model(model, eval_dataloader, accelerator, global_step, prepare_inp
         input_ids, encoder_hidden_states, labels, _, _, loss_weight = prepare_inputs_and_labels(
             pixel_values, input_ids
         )
-        _, loss, _ = model(
+        _, loss = model(
             input_ids=input_ids, encoder_hidden_states=encoder_hidden_states, labels=labels, loss_weight=loss_weight
         )
         eval_loss += loss.mean()
@@ -902,9 +912,9 @@ def log_image_entropy(logits, input_ids, mask_id, accelerator, global_step):
 
 
 @torch.no_grad()
-def log_cross_entropy(cross_entropy_per_image, input_ids, mask_id, accelerator, global_step):
+def log_cross_entropy(logits, labels, input_ids, mask_id, output_size, label_smoothing, accelerator, global_step):
     cross_entropy_per_percent_masked_bucket = muse.training_utils.cross_entropy_per_percent_masked_bucket(
-        cross_entropy_per_image, input_ids, mask_id
+        logits, labels, input_ids, mask_id, output_size, label_smoothing
     )
 
     cross_entropy_log = {}
