diff --git a/basicsr/utils/logger.py b/basicsr/utils/logger.py index 73553dc66..2bcd20ccd 100644 --- a/basicsr/utils/logger.py +++ b/basicsr/utils/logger.py @@ -128,6 +128,7 @@ def init_wandb_logger(opt): import wandb logger = get_root_logger() + entity = opt['logger']['wandb'].get('entity', None) project = opt['logger']['wandb']['project'] resume_id = opt['logger']['wandb'].get('resume_id') if resume_id: @@ -138,7 +139,7 @@ def init_wandb_logger(opt): wandb_id = wandb.util.generate_id() resume = 'never' - wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) + wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, entity=entity, project=project, sync_tensorboard=True) logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')