Skip to content

Commit

Permalink
pass loggers to trainer (#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
jazcollins authored Aug 20, 2024
1 parent 52088bf commit 1d4e6fa
Showing 1 changed file with 9 additions and 14 deletions.
23 changes: 9 additions & 14 deletions diffusion/evaluation/clean_fid_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@

import clip
import torch
import wandb
from cleanfid import fid
from composer import ComposerModel, Trainer
from composer.core import get_precision_context
from composer.loggers import LoggerDestination, WandBLogger
from composer.loggers import LoggerDestination
from composer.utils import dist
from torch.utils.data import DataLoader
from torchmetrics.multimodal import CLIPScore
Expand Down Expand Up @@ -91,19 +90,15 @@ def __init__(self,
self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}
self.sdxl = model.sdxl

# Init loggers
if self.loggers and dist.get_local_rank() == 0:
for logger in self.loggers:
if isinstance(logger, WandBLogger):
wandb.init(**logger._init_kwargs)

# Load the model
Trainer(model=self.model,
load_path=self.load_path,
load_weights_only=True,
load_strict_model_weights=load_strict_model_weights,
eval_dataloader=self.eval_dataloader,
seed=self.seed)
trainer = Trainer(model=self.model,
load_path=self.load_path,
load_weights_only=True,
load_strict_model_weights=load_strict_model_weights,
eval_dataloader=self.eval_dataloader,
seed=self.seed,
loggers=self.loggers)
self.trainer = trainer

# Move CLIP metric to device
self.device = dist.get_local_rank()
Expand Down

0 comments on commit 1d4e6fa

Please sign in to comment.