Skip to content

Commit

Permalink
Add argument to toggle using tables
Browse files Browse the repository at this point in the history
  • Loading branch information
Landanjs committed Sep 11, 2023
1 parent 6da1560 commit da5eeab
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions diffusion/callbacks/log_diffusion_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class LogDiffusionImages(Callback):
to use for evaluation. Default: ``None``.
seed (int, optional): Random seed to use for generation. Set a seed for reproducible generation.
Default: ``1138``.
use_table (bool): Whether to make a table of the images or not. Default: ``False``.
"""

def __init__(self,
Expand All @@ -40,14 +41,16 @@ def __init__(self,
guidance_scale: Optional[float] = 0.0,
text_key: Optional[str] = 'captions',
tokenized_prompts: Optional[torch.LongTensor] = None,
seed: Optional[int] = 1138):
seed: Optional[int] = 1138,
use_table: bool = False):
self.prompts = prompts
self.size = size
self.num_inference_steps = num_inference_steps
self.guidance_scale = guidance_scale
self.text_key = text_key
self.seed = seed
self.tokenized_prompts = tokenized_prompts
self.use_table = use_table

def eval_batch_end(self, state: State, logger: Logger):
# Only log once per eval epoch
Expand Down Expand Up @@ -81,4 +84,4 @@ def eval_batch_end(self, state: State, logger: Logger):

# Log images to wandb
for prompt, image in zip(self.prompts, gen_images):
logger.log_images(images=image, name=prompt, step=state.timestamp.batch.value, use_table=False)
logger.log_images(images=image, name=prompt, step=state.timestamp.batch.value, use_table=self.use_table)

0 comments on commit da5eeab

Please sign in to comment.