Commit e90baeb

add image entropy
williamberman committed Jul 12, 2023
1 parent a5379f5 commit e90baeb
Showing 2 changed files with 32 additions and 9 deletions.
muse/training_utils.py (4 changes: 2 additions & 2 deletions)
@@ -296,7 +296,7 @@ def load_state_dict(self, state_dict: dict) -> None:


 # calculates entropy over each pixel distribution
-def entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
+def pixel_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
     # only calculates entropy over image tokens that were masked in the original image
     masked_tokens = input_ids == mask_id
     num_masked_pixels = masked_tokens.sum(-1)
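The hunk above only shows the head of the renamed function. For reference, here is a minimal self-contained sketch of what "pixel entropy per percent-masked bucket" can compute, assuming ten buckets over the mask ratio; the helper name, bucket count, and exact reductions are assumptions, not the repository's code:

```python
# Hypothetical sketch; bucket count and exact reductions are assumptions.
import torch
import torch.nn.functional as F


def pixel_entropy_sketch(logits, input_ids, mask_id, total_buckets=10):
    # logits: (batch, seq_len, vocab); input_ids: (batch, seq_len)
    masked_tokens = input_ids == mask_id          # (batch, seq_len) bool
    num_masked_pixels = masked_tokens.sum(-1)     # (batch,)

    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)
    entropy_per_pixel = -(probs * log_probs).sum(-1)  # (batch, seq_len)

    # predictions at unmasked positions are ignored
    entropy_per_pixel[~masked_tokens] = 0
    entropy_per_image = entropy_per_pixel.sum(-1) / num_masked_pixels.clamp(min=1)

    # bucket each image by the fraction of its tokens that are masked
    percent_masked = num_masked_pixels.float() / input_ids.shape[-1]
    buckets = (percent_masked * total_buckets).long().clamp(max=total_buckets - 1)

    # mean entropy per bucket (empty buckets stay at zero)
    sums = torch.zeros(total_buckets).scatter_add_(0, buckets, entropy_per_image)
    counts = torch.zeros(total_buckets).scatter_add_(0, buckets, torch.ones_like(entropy_per_image))
    return sums / counts.clamp(min=1)
```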
@@ -324,7 +324,7 @@ def entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
 def image_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
     # only calculates entropy over image tokens that were masked in the original image
     masked_tokens = input_ids == mask_id
-    num_masked_pixels = masked_tokens.sum(-1)
+    num_masked_pixels = masked_tokens.sum(-1, keepdim=True)

     pixel_probs = F.softmax(logits, dim=-1)
     pixel_probs[~masked_tokens] = 0
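The `keepdim=True` change matters for the context lines that follow it: the zeroed per-pixel distributions get summed into one distribution per image and normalized by the masked-token count, and that division only broadcasts correctly if the count keeps its reduced dimension. A small shape illustration (tensor sizes invented for the example):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 16, 8)          # (batch, seq_len, vocab)
input_ids = torch.randint(0, 8, (2, 16))
mask_id = 0
masked_tokens = input_ids == mask_id     # (batch, seq_len)

pixel_probs = F.softmax(logits, dim=-1)
pixel_probs[~masked_tokens] = 0

# aggregate the masked positions into one distribution per image
summed = pixel_probs.sum(1)                          # (batch, vocab)
count = masked_tokens.sum(-1, keepdim=True)          # (batch, 1)
image_probs = summed / count                         # broadcasts over vocab

# without keepdim the count has shape (batch,), which would be broadcast
# against the vocab dimension instead of the batch dimension: a shape
# error here, or silently wrong math whenever vocab == batch
```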
training/train_muse.py (37 changes: 30 additions & 7 deletions)
@@ -645,11 +645,18 @@ def prepare_inputs_and_labels(
             data_time_m.reset()

             if (
-                ("log_entropy_every" in config.experiment)
-                and ((global_step + 1) % config.experiment.log_entropy_every == 0)
+                ("log_pixel_entropy_every" in config.experiment)
+                and ((global_step + 1) % config.experiment.log_pixel_entropy_every == 0)
                 and accelerator.is_main_process
             ):
-                log_entropy(logits, input_ids, mask_id, accelerator, global_step + 1)
+                log_pixel_entropy(logits, input_ids, mask_id, accelerator, global_step + 1)
+
+            if (
+                ("log_image_entropy_every" in config.experiment)
+                and ((global_step + 1) % config.experiment.log_image_entropy_every == 0)
+                and accelerator.is_main_process
+            ):
+                log_image_entropy(logits, input_ids, mask_id, accelerator, global_step + 1)

             if (
                 ("log_cross_entropy_every" in config.experiment)
@@ -863,19 +870,35 @@ def log_grad_norm(model, accelerator, global_step):


 @torch.no_grad()
-def log_entropy(logits, input_ids, mask_id, accelerator, global_step):
-    entropy_per_percent_masked_bucket = muse.training_utils.entropy_per_percent_masked_bucket(
+def log_pixel_entropy(logits, input_ids, mask_id, accelerator, global_step):
+    pixel_entropy_per_percent_masked_bucket = muse.training_utils.pixel_entropy_per_percent_masked_bucket(
         logits, input_ids, mask_id
     )

     entropy_log = {}

+    for bucket, bucket_entropy in enumerate(pixel_entropy_per_percent_masked_bucket):
+        bucket_entropy = bucket_entropy.item()
+        if bucket_entropy != 0:
+            entropy_log[f"bucket {bucket}"] = bucket_entropy
+
+    accelerator.log({"pixel_entropy": entropy_log}, step=global_step)
+
+
+@torch.no_grad()
+def log_image_entropy(logits, input_ids, mask_id, accelerator, global_step):
+    image_entropy_per_percent_masked_bucket = muse.training_utils.image_entropy_per_percent_masked_bucket(
+        logits, input_ids, mask_id
+    )
+
+    entropy_log = {}
+
-    for bucket, bucket_entropy in enumerate(entropy_per_percent_masked_bucket):
+    for bucket, bucket_entropy in enumerate(image_entropy_per_percent_masked_bucket):
         bucket_entropy = bucket_entropy.item()
         if bucket_entropy != 0:
             entropy_log[f"bucket {bucket}"] = bucket_entropy

-    accelerator.log({"entropy": entropy_log}, step=global_step)
+    accelerator.log({"image_entropy": entropy_log}, step=global_step)


 @torch.no_grad()
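Splitting the old `entropy` metric into `pixel_entropy` and `image_entropy` gives each series its own namespace in the tracker. For illustration only, with invented bucket indices and values, one step might emit:

```python
# Hypothetical usage; assumes a wandb tracker is configured and logged in.
from accelerate import Accelerator

accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers("muse-demo")

# invented values; zero-entropy buckets are skipped by the guard above,
# so the keys within each namespace can be sparse
accelerator.log({"pixel_entropy": {"bucket 3": 2.41, "bucket 7": 4.87}}, step=1000)
accelerator.log({"image_entropy": {"bucket 3": 1.02, "bucket 7": 1.55}}, step=1000)
```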
