diff --git a/models/turbine_models/custom_models/sdxl_inference/vae.py b/models/turbine_models/custom_models/sdxl_inference/vae.py
index 772589a3b..047caf46f 100644
--- a/models/turbine_models/custom_models/sdxl_inference/vae.py
+++ b/models/turbine_models/custom_models/sdxl_inference/vae.py
@@ -91,7 +91,7 @@ def decode_inp(self, inp):
         inp = inp / 0.13025
         x = self.vae.decode(inp, return_dict=False)[0]
         x = (x / 2 + 0.5).clamp(0, 1)
-        return x
+        return x.round()
 
     def encode_inp(self, inp):
         latents = self.vae.encode(inp).latent_dist.sample()
diff --git a/models/turbine_models/custom_models/sdxl_inference/vae_runner.py b/models/turbine_models/custom_models/sdxl_inference/vae_runner.py
index 9050096e5..253a3268d 100644
--- a/models/turbine_models/custom_models/sdxl_inference/vae_runner.py
+++ b/models/turbine_models/custom_models/sdxl_inference/vae_runner.py
@@ -37,9 +37,9 @@
     "--batch_size", type=int, default=1, help="Batch size for inference"
 )
 parser.add_argument(
-    "--height", type=int, default=512, help="Height of Stable Diffusion"
+    "--height", type=int, default=1024, help="Height of Stable Diffusion"
 )
-parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")
+parser.add_argument("--width", type=int, default=1024, help="Width of Stable Diffusion")
 parser.add_argument("--variant", type=str, default="decode")
@@ -58,51 +58,44 @@ class VaeModel(torch.nn.Module):
         def __init__(
             self,
             hf_model_name,
-            base_vae=False,
             custom_vae="",
-            low_cpu_mem_usage=False,
-            hf_auth_token="",
         ):
             super().__init__()
             self.vae = None
-            if custom_vae == "":
+            if custom_vae in ["", None]:
                 self.vae = AutoencoderKL.from_pretrained(
                     hf_model_name,
                     subfolder="vae",
-                    low_cpu_mem_usage=low_cpu_mem_usage,
-                    hf_auth_token=hf_auth_token,
                 )
             elif not isinstance(custom_vae, dict):
-                self.vae = AutoencoderKL.from_pretrained(
-                    custom_vae,
-                    subfolder="vae",
-                    low_cpu_mem_usage=low_cpu_mem_usage,
-                    hf_auth_token=hf_auth_token,
-                )
+                try:
+                    # custom HF repo with no vae subfolder
+                    self.vae = AutoencoderKL.from_pretrained(
+                        custom_vae,
+                    )
+                except:
+                    # some larger repo with vae subfolder
+                    self.vae = AutoencoderKL.from_pretrained(
+                        custom_vae,
+                        subfolder="vae",
+                    )
             else:
+                # custom vae as a HF state dict
                 self.vae = AutoencoderKL.from_pretrained(
                     hf_model_name,
                     subfolder="vae",
-                    low_cpu_mem_usage=low_cpu_mem_usage,
-                    hf_auth_token=hf_auth_token,
                 )
                 self.vae.load_state_dict(custom_vae)
-            self.base_vae = base_vae
-
-        def decode_inp(self, input):
-            with torch.no_grad():
-                if not self.base_vae:
-                    input = 1 / 0.18215 * input
-            x = self.vae.decode(input, return_dict=False)[0]
-            x = (x / 2 + 0.5).clamp(0, 1)
-            if self.base_vae:
-                return x
-            x = x * 255.0
+
+        def decode_inp(self, inp):
+            inp = inp / 0.13025
+            x = self.vae.decode(inp, return_dict=False)[0]
+            x = (x / 2 + 0.5).clamp(0, 1)
             return x.round()
 
         def encode_inp(self, inp):
             latents = self.vae.encode(inp).latent_dist.sample()
-            return 0.18215 * latents
+            return 0.13025 * latents
 
     vae_model = VaeModel(
         hf_model_name,
@@ -144,9 +137,7 @@ def encode_inp(self, inp):
     print("generating torch output: ")
     from turbine_models.custom_models.sd_inference import utils
 
-    torch_output = run_torch_vae(
-        args.hf_model_name, args.hf_auth_token, args.variant, example_input
-    )
+    torch_output = run_torch_vae(args.hf_model_name, args.variant, example_input)
     print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)
     err = utils.largest_error(torch_output, turbine_results)
     print("Largest Error: ", err)
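
For context, the scaling change above follows the SDXL VAE convention (`scaling_factor = 0.13025`, versus 0.18215 for SD 1.x): `encode_inp` multiplies sampled latents by the factor, and `decode_inp` divides it back out before decoding. Below is a minimal sketch of that round trip, assuming `torch` and `diffusers` are installed; the `stabilityai/stable-diffusion-xl-base-1.0` repo id is used for illustration and is not named in this patch:

```python
import torch
from diffusers import AutoencoderKL

# Assumed repo id for illustration; any SDXL checkpoint with a vae subfolder works.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
)

with torch.no_grad():
    # 1024x1024 RGB input, matching the runner's new default height/width.
    img = torch.rand(1, 3, 1024, 1024)
    # encode_inp: sample latents, then apply the SDXL scaling factor.
    latents = 0.13025 * vae.encode(img).latent_dist.sample()
    # decode_inp: undo the scaling, decode, then map [-1, 1] output to [0, 1].
    dec = vae.decode(latents / 0.13025, return_dict=False)[0]
    dec = (dec / 2 + 0.5).clamp(0, 1)

print(latents.shape)  # torch.Size([1, 4, 128, 128]) -- the VAE downsamples by 8
print(dec.shape)      # torch.Size([1, 3, 1024, 1024])
```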