From e90baeb5cb00dae031c56356ab074c9ce5874e6c Mon Sep 17 00:00:00 2001 From: Will Berman Date: Wed, 12 Jul 2023 22:29:40 +0000 Subject: [PATCH] add image entropy --- muse/training_utils.py | 4 ++-- training/train_muse.py | 37 ++++++++++++++++++++++++++++++------- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/muse/training_utils.py b/muse/training_utils.py index ac6a6f21..b5b87cda 100644 --- a/muse/training_utils.py +++ b/muse/training_utils.py @@ -296,7 +296,7 @@ def load_state_dict(self, state_dict: dict) -> None: # calculates entropy over each pixel distribution -def entropy_per_percent_masked_bucket(logits, input_ids, mask_id): +def pixel_entropy_per_percent_masked_bucket(logits, input_ids, mask_id): # only calculated entropy over image tokens that were masked in the original image masked_tokens = input_ids == mask_id num_masked_pixels = masked_tokens.sum(-1) @@ -324,7 +324,7 @@ def entropy_per_percent_masked_bucket(logits, input_ids, mask_id): def image_entropy_per_percent_masked_bucket(logits, input_ids, mask_id): # only calculated entropy over image tokens that were masked in the original image masked_tokens = input_ids == mask_id - num_masked_pixels = masked_tokens.sum(-1) + num_masked_pixels = masked_tokens.sum(-1, keepdim=True) pixel_probs = F.softmax(logits, dim=-1) pixel_probs[~masked_tokens] = 0 diff --git a/training/train_muse.py b/training/train_muse.py index 5400efc4..8d4c69a7 100644 --- a/training/train_muse.py +++ b/training/train_muse.py @@ -645,11 +645,18 @@ def prepare_inputs_and_labels( data_time_m.reset() if ( - ("log_entropy_every" in config.experiment) - and ((global_step + 1) % config.experiment.log_entropy_every == 0) + ("log_pixel_entropy_every" in config.experiment) + and ((global_step + 1) % config.experiment.log_pixel_entropy_every == 0) and accelerator.is_main_process ): - log_entropy(logits, input_ids, mask_id, accelerator, global_step + 1) + log_pixel_entropy(logits, input_ids, mask_id, accelerator, global_step + 1) + + if ( + ("log_image_entropy_every" in config.experiment) + and ((global_step + 1) % config.experiment.log_image_entropy_every == 0) + and accelerator.is_main_process + ): + log_image_entropy(logits, input_ids, mask_id, accelerator, global_step + 1) if ( ("log_cross_entropy_every" in config.experiment) @@ -863,19 +870,35 @@ def log_grad_norm(model, accelerator, global_step): @torch.no_grad() -def log_entropy(logits, input_ids, mask_id, accelerator, global_step): - entropy_per_percent_masked_bucket = muse.training_utils.entropy_per_percent_masked_bucket( +def log_pixel_entropy(logits, input_ids, mask_id, accelerator, global_step): + pixel_entropy_per_percent_masked_bucket = muse.training_utils.pixel_entropy_per_percent_masked_bucket( + logits, input_ids, mask_id + ) + + entropy_log = {} + + for bucket, bucket_entropy in enumerate(pixel_entropy_per_percent_masked_bucket): + bucket_entropy = bucket_entropy.item() + if bucket_entropy != 0: + entropy_log[f"bucket {bucket}"] = bucket_entropy + + accelerator.log({"pixel_entropy": entropy_log}, step=global_step) + + +@torch.no_grad() +def log_image_entropy(logits, input_ids, mask_id, accelerator, global_step): + image_entropy_per_percent_masked_bucket = muse.training_utils.image_entropy_per_percent_masked_bucket( logits, input_ids, mask_id ) entropy_log = {} - for bucket, bucket_entropy in enumerate(entropy_per_percent_masked_bucket): + for bucket, bucket_entropy in enumerate(image_entropy_per_percent_masked_bucket): bucket_entropy = bucket_entropy.item() if bucket_entropy != 0: entropy_log[f"bucket {bucket}"] = bucket_entropy - accelerator.log({"entropy": entropy_log}, step=global_step) + accelerator.log({"image_entropy": entropy_log}, step=global_step) @torch.no_grad()