Skip to content

Commit

Permalink
Merge branch 'main' into add_prestartup_script
Browse files (browse the repository at this point in the history)
  • Loading branch information
hjchen2 authored Nov 14, 2024
2 parents b92151e + 424c81a commit a93150f
Showing 1 changed file with 33 additions and 6 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,13 +20,18 @@
diffusers_0240_v = version.parse("0.24.0")
diffusers_0251_v = version.parse("0.25.1")
diffusers_0263_v = version.parse("0.26.3")
diffusers_0280_v = version.parse("0.28.0")
diffusers_version = version.parse(importlib.metadata.version("diffusers"))

import numpy as np
import PIL.Image
import torch

from diffusers.image_processor import VaeImageProcessor

if diffusers_version >= diffusers_0280_v:
from diffusers.video_processor import VideoProcessor
else:
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from diffusers.utils import BaseOutput, logging
Expand Down Expand Up @@ -100,7 +105,14 @@ def __init__(
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
if diffusers_version >= diffusers_0280_v:
self.video_processor = VideoProcessor(
do_resize=True, vae_scale_factor=self.vae_scale_factor
)
else:
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor
)

self.fast_unet = FastUNetSpatioTemporalConditionModel(self.unet)

Expand Down Expand Up @@ -174,18 +186,26 @@ def __call__(
fps = fps - 1

# 4. Encode input image using VAE
if diffusers_version > diffusers_0251_v:
if diffusers_version <= diffusers_0251_v:
image = self.image_processor.preprocess(image, height=height, width=width)
noise = randn_tensor(
image.shape, generator=generator, device=image.device, dtype=image.dtype
)
elif diffusers_version < diffusers_0280_v:
image = self.image_processor.preprocess(
image, height=height, width=width
).to(device)
noise = randn_tensor(
image.shape, generator=generator, device=device, dtype=image.dtype
)
else:
image = self.image_processor.preprocess(image, height=height, width=width)
image = self.video_processor.preprocess(
image, height=height, width=width
).to(device)
noise = randn_tensor(
image.shape, generator=generator, device=image.device, dtype=image.dtype
image.shape, generator=generator, device=device, dtype=image.dtype
)

image = image + noise_aug_strength * noise

needs_upcasting = (
Expand Down Expand Up @@ -359,7 +379,14 @@ def __call__(
if needs_upcasting:
self.vae.to(dtype=torch.float16)
frames = self.decode_latents(latents, num_frames, decode_chunk_size)
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
if diffusers_version < diffusers_0280_v:
frames = tensor2vid(
frames, self.image_processor, output_type=output_type
)
else:
frames = self.video_processor.postprocess_video(
video=frames, output_type=output_type
)
else:
frames = latents

Expand Down

0 comments on commit a93150f

Please sign in to comment.