-
Notifications
You must be signed in to change notification settings - Fork 864
/
cli_vae_demo.py
140 lines (112 loc) · 6.14 KB
/
cli_vae_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
This script is designed to demonstrate how to use the CogVideoX-2b VAE model for video encoding and decoding.
It allows you to encode a video into a latent representation, decode it back into a video, or perform both operations sequentially.
Before running the script, clone the CogVideoX Hugging Face model repository and replace the
`{your local diffusers path}` placeholder in the `--model_path` values below with the path to the cloned repository.
Command 1: Encoding Video
Encodes the video located at ../resources/videos/1.mp4 using the CogVideoX-2b VAE model.
Memory Usage: ~18GB of GPU memory for encoding.
If you do not have enough GPU memory, we provide a pre-encoded tensor file (encoded.pt) in the resources folder,
and you can still run the decoding command.
$ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --video_path ../resources/videos/1.mp4 --mode encode
Command 2: Decoding Video
Decodes the latent representation stored in encoded.pt back into a video.
Memory Usage: ~4GB of GPU memory for decoding.
$ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --encoded_path ./encoded.pt --mode decode
Command 3: Encoding and Decoding Video
Encodes the video located at ../resources/videos/1.mp4 and then immediately decodes it.
Memory Usage: 34GB for encoding + 19GB for decoding (sequentially).
$ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --video_path ../resources/videos/1.mp4 --mode both
"""
import argparse
import os

import imageio
import numpy as np
import torch
from diffusers import AutoencoderKLCogVideoX
from torchvision import transforms
def encode_video(model_path, video_path, dtype, device):
    """
    Loads a pre-trained AutoencoderKLCogVideoX model and encodes the video frames.
    Parameters:
    - model_path (str): The path to the pre-trained model.
    - video_path (str): The path to the video file (read with imageio's "ffmpeg" backend).
    - dtype (torch.dtype): The data type for computation.
    - device (str or torch.device): The device to use for computation (e.g., "cuda" or "cpu").
    Returns:
    - torch.Tensor: The encoded video frames (one sample drawn from the encoder output),
      located on `device`.
    Raises:
    - ValueError: If no frames could be read from `video_path`.
    """
    model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
    # diffusers' sliced/tiled VAE execution, enabled to keep encoding memory down.
    model.enable_slicing()
    model.enable_tiling()

    to_tensor = transforms.ToTensor()  # HWC uint8 frame -> CHW float tensor in [0, 1]
    # The context manager closes the ffmpeg reader even if reading a frame fails midway.
    with imageio.get_reader(video_path, "ffmpeg") as video_reader:
        frames = [to_tensor(frame) for frame in video_reader]
    if not frames:
        raise ValueError(f"No frames could be read from {video_path!r}.")

    # Stack to (frames, C, H, W), reorder to (C, frames, H, W), add a batch axis:
    # (1, C, frames, H, W). Device move and dtype cast happen in a single copy.
    # NOTE(review): pixel values stay in [0, 1] here (and save_video clips to [0, 1]);
    # confirm this matches the range the VAE was trained on.
    frames_tensor = torch.stack(frames).permute(1, 0, 2, 3).unsqueeze(0).to(device=device, dtype=dtype)

    with torch.no_grad():
        # encode()'s first output is presumably the latent distribution; draw one sample from it.
        encoded_frames = model.encode(frames_tensor)[0].sample()
    return encoded_frames
def decode_video(model_path, encoded_tensor_path, dtype, device):
    """
    Loads a pre-trained AutoencoderKLCogVideoX model and decodes the encoded video frames.
    Parameters:
    - model_path (str): The path to the pre-trained model.
    - encoded_tensor_path (str): The path to the encoded tensor file (as written by torch.save).
    - dtype (torch.dtype): The data type for computation.
    - device (str or torch.device): The device to use for computation (e.g., "cuda" or "cpu").
    Returns:
    - torch.Tensor: The decoded video frames.
    """
    model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
    # The encode mode saves the tensor while it still lives on its compute device (e.g. CUDA).
    # map_location loads it directly onto `device`, so a file produced on a GPU can be
    # decoded with --device cpu. weights_only restricts unpickling to tensor data.
    encoded_frames = torch.load(encoded_tensor_path, map_location=device, weights_only=True).to(dtype=dtype)
    with torch.no_grad():
        decoded_frames = model.decode(encoded_frames).sample
    return decoded_frames
def save_video(tensor, output_path, fps=8):
    """
    Saves the video frames to `output_path`/output.mp4.
    Parameters:
    - tensor (torch.Tensor): The video frames' tensor. Only the first batch item is written;
      it is indexed as (channels, frames, height, width), matching the layout encode_video builds.
      Values are clipped to [0, 1] before conversion to 8-bit pixels.
    - output_path (str): The directory in which output.mp4 is written.
    - fps (int): Frame rate of the written video (defaults to 8, the previous fixed value).
    """
    # Upcast to float32 before .numpy(): NumPy cannot represent bfloat16.
    # (C, F, H, W) -> (F, H, W, C), i.e. one HWC image per frame.
    frames = tensor[0].cpu().to(dtype=torch.float32).permute(1, 2, 3, 0).numpy()
    frames = (np.clip(frames, 0, 1) * 255).astype(np.uint8)
    # The context manager finalizes and closes the file even if appending a frame fails.
    with imageio.get_writer(os.path.join(output_path, "output.mp4"), fps=fps) as writer:
        for frame in frames:
            writer.append_data(frame)
if __name__ == "__main__":
    # Accepted --dtype values. Previously any value other than "float16" (including typos)
    # silently fell back to bfloat16.
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}

    parser = argparse.ArgumentParser(description="CogVideoX encode/decode demo")
    parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model")
    parser.add_argument("--video_path", type=str, help="The path to the video file (for encoding)")
    parser.add_argument("--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)")
    parser.add_argument("--output_path", type=str, default=".", help="The path to save the output file")
    parser.add_argument(
        "--mode", type=str, choices=["encode", "decode", "both"], required=True, help="Mode: encode, decode, or both"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=sorted(dtype_map),
        help="The data type for computation (e.g., 'float16' or 'bfloat16')",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
    )
    args = parser.parse_args()

    # parser.error() prints usage and exits; unlike `assert`, it is not stripped under `python -O`.
    if args.mode in ("encode", "both") and not args.video_path:
        parser.error("Video path must be provided for encoding.")
    if args.mode == "decode" and not args.encoded_path:
        parser.error("Encoded tensor path must be provided for decoding.")

    device = torch.device(args.device)
    dtype = dtype_map[args.dtype]

    # torch.save / the video writer fail if the output directory does not exist yet.
    os.makedirs(args.output_path, exist_ok=True)
    encoded_file = os.path.join(args.output_path, "encoded.pt")
    video_file = os.path.join(args.output_path, "output.mp4")

    if args.mode == "encode":
        encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
        torch.save(encoded_output, encoded_file)
        # Report the file path (previously this interpolated the encoded tensor itself).
        print(f"Finished encoding the video to a tensor, save it to a file at {encoded_file}")
    elif args.mode == "decode":
        decoded_output = decode_video(args.model_path, args.encoded_path, dtype, device)
        save_video(decoded_output, args.output_path)
        print(f"Finished decoding the video and saved it to a file at {video_file}")
    elif args.mode == "both":
        encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
        torch.save(encoded_output, encoded_file)
        # Decode from the saved file so "both" exercises the same path as "decode".
        decoded_output = decode_video(args.model_path, encoded_file, dtype, device)
        save_video(decoded_output, args.output_path)
        print(f"Finished encoding and decoding the video; saved {encoded_file} and {video_file}")