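"""Apply StreamDiffusion img2img to every frame of a video.

Extracts frames with ffmpeg, stylizes them through a pipelined
StreamDiffusion stream (LCM-LoRA + TAESD, accelerated with stable-fast),
and re-encodes the results to MP4.
"""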
import os

import ffmpeg
import fire
import PIL.Image
import torch
from diffusers import AutoencoderTiny, LCMScheduler, StableDiffusionPipeline
from tqdm import tqdm

from streamdiffusion import StreamDiffusion
from streamdiffusion.acceleration.sfast import accelerate_with_stable_fast
from streamdiffusion.image_utils import pil2tensor, postprocess_image


def extract_frames(video_path: str, output_dir: str):
    # Dump every frame of the video as zero-padded PNGs (0001.png, 0002.png, ...).
    os.makedirs(output_dir, exist_ok=True)
    ffmpeg.input(video_path).output(f"{output_dir}/%04d.png").run()


def get_frame_rate(video_path: str) -> float:
    probe = ffmpeg.probe(video_path)
    video_info = next(s for s in probe["streams"] if s["codec_type"] == "video")
    # "r_frame_rate" is a fraction string such as "30000/1001"; evaluate it
    # instead of truncating at the numerator, which would misreport NTSC video.
    num, den = video_info["r_frame_rate"].split("/")
    return int(num) / int(den)


def main(input: str, output: str, scale: float = 1.0):
    if os.path.exists(output):
        raise ValueError("Output directory already exists")
    frame_rate = get_frame_rate(input)
    extract_frames(input, os.path.join(output, "frames"))
    images = sorted(os.listdir(os.path.join(output, "frames")))
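
    # Build the SD 1.5 stack for few-step inference: the LCM scheduler plus
    # the fused LCM-LoRA let the UNet denoise in very few steps, and TAESD
    # stands in for the full VAE to make decoding cheap. ./model.safetensors
    # is assumed to be a local SD 1.5 checkpoint.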
    pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file("./model.safetensors").to(
        device=torch.device("cuda"),
        dtype=torch.float16,
    )
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(device=pipe.device, dtype=pipe.dtype)
    pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
    pipe.fuse_lora()

    # Derive the working resolution from the first extracted frame.
    sample_image = PIL.Image.open(os.path.join(output, "frames", images[0]))
    width = int(sample_image.width * scale)
    height = int(sample_image.height * scale)

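    # Wrap the pipeline for pipelined streaming inference. [40, 49] are the
    # timestep indices, out of the 50-step schedule configured in prepare()
    # below, at which denoising actually runs, so each frame costs two UNet
    # evaluations.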
    stream = StreamDiffusion(
        pipe,
        [40, 49],
        torch_dtype=torch.float16,
        width=width,
        height=height,
    )
    stream = accelerate_with_stable_fast(stream)
    stream.prepare(
        "Girl with panda ears wearing a hood",
        num_inference_steps=50,
        generator=torch.manual_seed(2),
    )

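    # Warm-up: the stream denoises batch_size frames at staggered timesteps,
    # so outputs lag inputs by batch_size - 1 calls. Prime the pipeline with
    # copies of the first frame and ignore what it returns.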
    for _ in range(stream.batch_size - 1):
        stream(
            pil2tensor(sample_image.resize((width, height)))
            .detach()
            .clone()
            .to(device=stream.device, dtype=stream.dtype)
        )

    # Each stream() call returns the result for the frame submitted
    # batch_size - 1 calls earlier. Pad the tail with copies of the last
    # frame to flush the pipeline, skip the warm-up outputs, and save each
    # result under the name of the frame it belongs to.
    num_warmup = stream.batch_size - 1
    for idx, image_path in enumerate(tqdm(images + [images[-1]] * num_warmup)):
        pil_image = PIL.Image.open(os.path.join(output, "frames", image_path))
        pil_image = pil_image.resize((width, height))
        input_tensor = pil2tensor(pil_image)
        output_x = stream(input_tensor.detach().clone().to(device=stream.device, dtype=stream.dtype))
        if idx >= num_warmup:
            output_image = postprocess_image(output_x, output_type="pil")[0]
            output_image.save(os.path.join(output, images[idx - num_warmup]))

    # Re-encode the stylized frames into an H.264 MP4 at the source frame rate.
    output_video_path = os.path.join(output, "output.mp4")

    ffmpeg.input(os.path.join(output, "%04d.png"), framerate=frame_rate).output(
        output_video_path, crf=17, pix_fmt="yuv420p", vcodec="libx264"
    ).run()


if __name__ == "__main__":
    fire.Fire(main)
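
# Example invocation (script name and paths are hypothetical):
#   python main.py --input input.mp4 --output ./out --scale 0.5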