From e06d3cfa64711e79e0906281bda8407df3bd36f1 Mon Sep 17 00:00:00 2001 From: Ross Morgan-Linial Date: Wed, 29 Mar 2023 16:42:56 -0700 Subject: [PATCH] Suppress safety checker warnings --- scripts/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/trainer.py b/scripts/trainer.py index 92f1b91..199b23e 100644 --- a/scripts/trainer.py +++ b/scripts/trainer.py @@ -472,7 +472,7 @@ def main(): safety_checker=None, vae=AutoencoderKL.from_pretrained(args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,subfolder=None if args.pretrained_vae_name_or_path else "vae" ,safe_serialization=True), torch_dtype=torch_dtype, - + requires_safety_checker=False, ) pipeline.set_progress_bar_config(disable=True) pipeline.to(accelerator.device) @@ -934,6 +934,7 @@ def mid_train_playground(step): safety_checker=None, torch_dtype=weight_dtype, local_files_only=False, + requires_safety_checker=False, ) pipeline.scheduler = scheduler if is_xformers_available() and args.attention=='xformers': @@ -1087,6 +1088,7 @@ def save_and_sample_weights(step,context='checkpoint',save_model=True): safety_checker=None, torch_dtype=weight_dtype, local_files_only=False, + requires_safety_checker=False, ) pipeline.scheduler = scheduler if is_xformers_available() and args.attention=='xformers':