-
Notifications
You must be signed in to change notification settings - Fork 10
/
test.py
211 lines (188 loc) · 7.29 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import os
from typing import Dict, List, Literal, Optional
import fire
import numpy as np
import torch
from decord import VideoReader
from PIL import Image
from torchvision import transforms
from torchvision.io import write_video
from tqdm import tqdm
from live2diff.utils.config import load_config
from live2diff.utils.io import read_video_frames, save_videos_grid
from live2diff.utils.wrapper import StreamAnimateDiffusionDepthWrapper
# Absolute path of the directory that contains this script.
CURRENT_DIR = os.path.split(os.path.abspath(__file__))[0]
def _load_video(path: str) -> torch.Tensor:
    """Load ``path`` and scale its pixel values by 1/255.

    ``path`` may be a directory of frames, an ``.mp4`` file or a ``.gif`` file.
    Frames are returned in ``(T, H, W, C)`` layout, the layout ``main`` permutes
    from.

    Raises
    ------
    TypeError
        If ``path`` is none of the supported inputs.
    """
    if os.path.isdir(path):
        # NOTE(review): assumes read_video_frames returns a (T, H, W, C) tensor
        # of 0-255 values -- TODO confirm against live2diff.utils.io.
        return read_video_frames(path) / 255
    if path.endswith(".mp4"):
        reader = VideoReader(path)
        frame_indices = np.arange(len(reader))
        return torch.from_numpy(reader.get_batch(frame_indices).asnumpy() / 255)
    if path.endswith(".gif"):
        frames = []
        # ``with`` releases the file handle PIL keeps open for lazy decoding.
        with Image.open(path) as image:
            for frame_idx in range(image.n_frames):
                image.seek(frame_idx)
                frames.append(np.array(image.convert("RGB")))
        return torch.from_numpy(np.array(frames)) / 255
    raise TypeError(f"Unsupported input format: {path}")


def _to_uint8(frames: torch.Tensor) -> torch.Tensor:
    """Map frames with values in [0, 1] to uint8 values in [0, 255].

    torchvision's ``write_video`` expects uint8 frames; casting float frames
    truncates (``x / 255 * 255`` can land at 36.9999 and become 36) and does
    not clamp out-of-range values, so round and clamp explicitly first.
    """
    return (frames * 255).round().clamp(0, 255).to(torch.uint8)


def _input_copy_path(output: str, ext: str) -> str:
    """Return ``output`` with its trailing ``ext`` replaced by ``"_inp" + ext``.

    Only the suffix is rewritten; ``str.replace`` would also rewrite an ``ext``
    that appears earlier in the path (e.g. a directory named ``clips.mp4``).
    """
    return output[: -len(ext)] + "_inp" + ext


def _resolve(value, cfg, key: str):
    """Return ``value`` unless it is ``None``, else ``cfg.get(key, None)``.

    An explicit ``is None`` test keeps falsy but deliberate CLI values
    (``0``, ``0.0``, ``""``) from being silently replaced by the config value.
    """
    return cfg.get(key, None) if value is None else value


def main(
    input: str,
    config_path: str,
    prompt: Optional[str] = None,
    prompt_template: Optional[str] = None,
    output: str = os.path.join("outputs", "output.mp4"),
    dreambooth_path: Optional[str] = None,
    lora_dict: Optional[Dict[str, float]] = None,
    height: int = 512,
    width: int = 512,
    max_frames: int = -1,
    num_inference_steps: Optional[int] = None,
    t_index_list: Optional[List[int]] = None,
    strength: Optional[float] = None,
    acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt",
    enable_similar_image_filter: bool = False,
    few_step_model_type: str = "lcm",
    enable_tiny_vae: bool = True,
    fps: int = 16,
    save_input: bool = True,
    seed: int = 42,
):
    """
    Translate an input video with ``StreamAnimateDiffusionDepthWrapper`` and save it.

    Parameters
    ----------
    input : str
        A directory of video frames, an ``.mp4`` file, or a ``.gif`` file.
    config_path : str
        Path to the inference config file. It is passed to the wrapper and also
        supplies fallback values for ``prompt``, ``prompt_template``,
        ``num_inference_steps``, ``strength`` and ``t_index_list``.
    prompt : str, optional
        The prompt to generate frames from. If None, the config's ``prompt``.
    prompt_template : str, optional
        Template for a specific DreamBooth / LoRA. If None, the config's
        ``prompt_template``. When set, it must contain ``{}`` and the prompt
        used for inference is ``prompt_template.format(prompt)``.
    output : str, optional
        Output path; must end with ``.mp4`` or ``.gif``.
        By default ``outputs/output.mp4``.
    dreambooth_path : str, optional
        Forwarded to the wrapper, by default None.
    lora_dict : Optional[Dict[str, float]], optional
        The LoRAs to load, by default None.
        Keys are the LoRA names and values are the LoRA scales.
        Example: `python main.py --lora_dict='{"LoRA_1" : 0.5 , "LoRA_2" : 0.7 ,...}'`
    height : int, optional
        Output height, rounded down to a multiple of 8, by default 512.
    width : int, optional
        Output width, rounded down to a multiple of 8, by default 512.
    max_frames : int, optional
        If > 0, keep at most this many frames (counted after the first two
        input frames are dropped). By default -1 (keep all).
    num_inference_steps : int, optional
        Forwarded to the wrapper. If None, the config's value.
    t_index_list : List[int], optional
        Forwarded to the wrapper. If None, the config's value.
    strength : float, optional
        Forwarded to the wrapper. If None, the config's value.
    acceleration : Literal["none", "xformers", "tensorrt"]
        Acceleration backend forwarded to the wrapper, by default "tensorrt".
    enable_similar_image_filter : bool, optional
        Whether to enable the similar image filter (threshold 0.98),
        by default False.
    few_step_model_type : str, optional
        Forwarded to the wrapper, by default "lcm".
    enable_tiny_vae : bool, optional
        Forwarded to the wrapper as ``use_tiny_vae``, by default True.
    fps : int, optional
        The fps of the saved video(s), by default 16.
    save_input : bool, optional
        Whether to also save the preprocessed input video, by default True.
        It is written next to ``output`` with ``_inp`` inserted before the
        extension (``out.mp4`` -> ``out_inp.mp4``).
    seed : int, optional
        Seed forwarded to the wrapper, by default 42 (-1 is documented to
        select a random seed; handled inside the wrapper).

    Raises
    ------
    TypeError
        If ``output`` does not end with ``.mp4``/``.gif`` (checked before any
        work is done), or if ``input`` is not a directory, ``.mp4`` or ``.gif``.
    AssertionError
        If the resolved ``prompt_template`` does not contain ``"{}"``.
    """
    # Validate the output format up front so a bad path fails before inference.
    if not output.endswith((".mp4", ".gif")):
        raise TypeError(f"Unsupported output format: {output}")

    video = _load_video(input)
    # The first two input frames are always discarded (reason not visible here).
    video = video[2:]
    # Clip before resizing so frames that will be dropped are never resized.
    if max_frames > 0:
        video = video[: min(max_frames, len(video))]
        print(f"Clipping video to {len(video)} frames.")

    # Round the target size down to a multiple of 8.
    height = int(height // 8 * 8)
    width = int(width // 8 * 8)
    trans = transforms.Compose(
        [
            transforms.Resize(min(height, width), antialias=True),
            transforms.CenterCrop((height, width)),
        ]
    )
    # torchvision transforms work on (..., C, H, W); frames are stored (T, H, W, C).
    video = trans(video.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

    cfg = load_config(config_path)
    print("Inference Config:")
    print(cfg)

    # CLI arguments take precedence; ``None`` means "use the config value".
    prompt = _resolve(prompt, cfg, "prompt")
    prompt_template = _resolve(prompt_template, cfg, "prompt_template")
    if prompt_template is not None:
        assert "{}" in prompt_template, '"{}" must be contained in "prompt_template".'
        prompt = prompt_template.format(prompt)
        print(f'Convert input prompt to "{prompt}".')

    num_inference_steps = _resolve(num_inference_steps, cfg, "num_inference_steps")
    strength = _resolve(strength, cfg, "strength")
    t_index_list = _resolve(t_index_list, cfg, "t_index_list")

    stream = StreamAnimateDiffusionDepthWrapper(
        few_step_model_type=few_step_model_type,
        config_path=config_path,
        cfg_type="none",
        dreambooth_path=dreambooth_path,
        lora_dict=lora_dict,
        strength=strength,
        num_inference_steps=num_inference_steps,
        t_index_list=t_index_list,
        frame_buffer_size=1,
        width=width,
        height=height,
        acceleration=acceleration,
        do_add_noise=True,
        output_type="pt",
        enable_similar_image_filter=enable_similar_image_filter,
        similar_image_filter_threshold=0.98,
        use_denoising_batch=True,
        use_tiny_vae=enable_tiny_vae,
        seed=seed,
    )

    # The first ``num_warmup`` frames go through ``stream.prepare`` in one call;
    # every later frame is fed to the stream one at a time.
    num_warmup = 8
    warmup_frames = video[:num_warmup].permute(0, 3, 1, 2)
    warmup_results = stream.prepare(
        warmup_frames=warmup_frames,
        prompt=prompt,
        guidance_scale=1,
    )
    video_result = torch.zeros(video.shape[0], height, width, 3)
    video_result[:num_warmup] = warmup_results.cpu().float()

    # The result of step ``i`` is stored at index ``i - skip_frames``: the stream
    # is treated as lagging by ``batch_size - 1`` frames, so the first
    # ``skip_frames`` streamed outputs are discarded and the last
    # ``skip_frames`` rows of ``video_result`` are never filled.
    skip_frames = stream.batch_size - 1
    for i in tqdm(range(num_warmup, video.shape[0])):
        output_image = stream(video[i].permute(2, 0, 1))
        if i - num_warmup >= skip_frames:
            video_result[i - skip_frames] = output_image.permute(1, 2, 0)
    # Guard the trim: ``x[:-0]`` is ``x[:0]`` and would empty the result when
    # ``batch_size == 1``.
    if skip_frames > 0:
        video_result = video_result[:-skip_frames]

    save_root = os.path.dirname(output)
    if save_root:
        os.makedirs(save_root, exist_ok=True)

    if output.endswith(".mp4"):
        write_video(output, _to_uint8(video_result), fps=fps)
        if save_input:
            write_video(_input_copy_path(output, ".mp4"), _to_uint8(video), fps=fps)
    else:  # ".gif" -- the format was validated at the top of this function
        save_videos_grid(
            video_result.permute(3, 0, 1, 2)[None, ...],
            output,
            rescale=False,
            fps=fps,
        )
        if save_input:
            save_videos_grid(
                video.permute(3, 0, 1, 2)[None, ...],
                _input_copy_path(output, ".gif"),
                rescale=False,
                fps=fps,
            )

    print("Inference time ema: ", stream.stream.inference_time_ema)
    inference_time_list = np.array(stream.stream.inference_time_list)
    # Guard against an empty list (presumably when no frame was streamed past
    # the warm-up); mean()/std() of an empty array print nan with a warning.
    if inference_time_list.size > 0:
        print(f"Inference time mean & std: {inference_time_list.mean()} +/- {inference_time_list.std()}")
    if hasattr(stream.stream, "depth_time_ema"):
        print("Depth time ema: ", stream.stream.depth_time_ema)
    print(f'Video saved to "{output}".')
if __name__ == "__main__":
    # Expose ``main``'s parameters as command-line flags via python-fire.
    fire.Fire(component=main)