forked from modelscope/DiffSynth-Studio
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathExVideo_svd_train.py
364 lines (308 loc) · 14.2 KB
/
ExVideo_svd_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
import torch, json, os, imageio, argparse
from torchvision.transforms import v2
import numpy as np
from einops import rearrange, repeat
import lightning as pl
from diffsynth import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, ContinuousODEScheduler, load_state_dict
from diffsynth.pipelines.svd_video import SVDCLIPImageProcessor
from diffsynth.models.svd_unet import TemporalAttentionBlock
class TextVideoDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, steps_per_epoch=10000, training_shapes=[(128, 1, 128, 512, 512)]):
with open(metadata_path, "r") as f:
metadata = json.load(f)
self.path = [os.path.join(base_path, i["path"]) for i in metadata]
self.steps_per_epoch = steps_per_epoch
self.training_shapes = training_shapes
self.frame_process = []
for max_num_frames, interval, num_frames, height, width in training_shapes:
self.frame_process.append(v2.Compose([
v2.Resize(size=max(height, width), antialias=True),
v2.CenterCrop(size=(height, width)),
v2.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
]))
def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
reader = imageio.get_reader(file_path)
if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
reader.close()
return None
frames = []
for frame_id in range(num_frames):
frame = reader.get_data(start_frame_id + frame_id * interval)
frame = torch.tensor(frame, dtype=torch.float32)
frame = rearrange(frame, "H W C -> 1 C H W")
frame = frame_process(frame)
frames.append(frame)
reader.close()
frames = torch.concat(frames, dim=0)
frames = rearrange(frames, "T C H W -> C T H W")
return frames
def load_video(self, file_path, training_shape_id):
data = {}
max_num_frames, interval, num_frames, height, width = self.training_shapes[training_shape_id]
frame_process = self.frame_process[training_shape_id]
start_frame_id = torch.randint(0, max_num_frames - (num_frames - 1) * interval, (1,))[0]
frames = self.load_frames_using_imageio(file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process)
if frames is None:
return None
else:
data[f"frames_{training_shape_id}"] = frames
return data
def __getitem__(self, index):
video_data = {}
for training_shape_id in range(len(self.training_shapes)):
while True:
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
video_file = self.path[data_id]
try:
data = self.load_video(video_file, training_shape_id)
except:
data = None
if data is not None:
break
video_data.update(data)
return video_data
def __len__(self):
return self.steps_per_epoch
class MotionBucketManager:
def __init__(self):
self.thresholds = [
0.000000000, 0.012205946, 0.015117834, 0.018080613, 0.020614484, 0.021959992, 0.024088068, 0.026323952,
0.028277775, 0.029968588, 0.031836554, 0.033596724, 0.035121530, 0.037200287, 0.038914755, 0.040696491,
0.042368013, 0.044265781, 0.046311017, 0.048243891, 0.050294187, 0.052142400, 0.053634230, 0.055612389,
0.057594258, 0.059410289, 0.061283995, 0.063603796, 0.065192916, 0.067146860, 0.069066539, 0.070390493,
0.072588451, 0.073959745, 0.075889029, 0.077695683, 0.079783581, 0.082162730, 0.084092639, 0.085958421,
0.087700523, 0.089684933, 0.091688842, 0.093335517, 0.094987206, 0.096664011, 0.098314710, 0.100262381,
0.101984538, 0.103404313, 0.105280340, 0.106974818, 0.109028399, 0.111164779, 0.113065213, 0.114362158,
0.116407216, 0.118063427, 0.119524263, 0.121835820, 0.124242283, 0.126202747, 0.128989249, 0.131672353,
0.133417681, 0.135567948, 0.137313649, 0.139189199, 0.140912935, 0.143525436, 0.145718485, 0.148315132,
0.151039496, 0.153218940, 0.155252382, 0.157651082, 0.159966752, 0.162195817, 0.164811596, 0.167341709,
0.170251891, 0.172651157, 0.175550997, 0.178372145, 0.181039348, 0.183565900, 0.186599866, 0.190071866,
0.192574754, 0.195026234, 0.198099136, 0.200210452, 0.202522039, 0.205410406, 0.208610669, 0.211623028,
0.214723110, 0.218520239, 0.222194016, 0.225363150, 0.229384825, 0.233422622, 0.237012610, 0.240735114,
0.243622541, 0.247465774, 0.252190471, 0.257356376, 0.261856794, 0.266556412, 0.271076709, 0.277361482,
0.281250387, 0.286582440, 0.291158527, 0.296712339, 0.303008437, 0.311793238, 0.318485111, 0.326999635,
0.332138240, 0.341770738, 0.354188830, 0.365194678, 0.379234344, 0.401538879, 0.416078776, 0.440871328,
]
def get_motion_score(self, frames):
score = frames.std(dim=2).mean(dim=[1, 2, 3]).tolist()
return score
def get_bucket_id(self, motion_score):
for bucket_id in range(len(self.thresholds) - 1):
if self.thresholds[bucket_id + 1] > motion_score:
return bucket_id
return len(self.thresholds) - 1
def __call__(self, frames):
scores = self.get_motion_score(frames)
bucket_ids = [self.get_bucket_id(score) for score in scores]
return bucket_ids
class LightningModel(pl.LightningModule):
def __init__(self, learning_rate=1e-5, svd_ckpt_path=None, add_positional_conv=128, contrast_enhance_scale=1.01):
super().__init__()
state_dict = load_state_dict(svd_ckpt_path)
self.image_encoder = SVDImageEncoder().to(dtype=torch.float16, device=self.device)
self.image_encoder.load_state_dict(SVDImageEncoder.state_dict_converter().from_civitai(state_dict))
self.image_encoder.eval()
self.image_encoder.requires_grad_(False)
self.unet = SVDUNet(add_positional_conv=add_positional_conv).to(dtype=torch.float16, device=self.device)
self.unet.load_state_dict(SVDUNet.state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
self.unet.train()
self.unet.requires_grad_(False)
for block in self.unet.blocks:
if isinstance(block, TemporalAttentionBlock):
block.requires_grad_(True)
self.vae_encoder = SVDVAEEncoder().to(dtype=torch.float16, device=self.device)
self.vae_encoder.load_state_dict(SVDVAEEncoder.state_dict_converter().from_civitai(state_dict))
self.vae_encoder.eval()
self.vae_encoder.requires_grad_(False)
self.noise_scheduler = ContinuousODEScheduler(num_inference_steps=1000)
self.learning_rate = learning_rate
self.motion_bucket_manager = MotionBucketManager()
self.contrast_enhance_scale = contrast_enhance_scale
def encode_image_with_clip(self, image):
image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224))
image = (image + 1.0) / 2.0
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.dtype)
std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.dtype)
image = (image - mean) / std
image_emb = self.image_encoder(image)
return image_emb
def encode_video_with_vae(self, video):
video = video.to(device=self.device, dtype=self.dtype)
video = video.unsqueeze(0)
latents = self.vae_encoder.encode_video(video)
latents = rearrange(latents[0], "C T H W -> T C H W")
return latents
def tensor2video(self, frames):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
return frames
def calculate_loss(self, frames):
with torch.no_grad():
# Call video encoder
latents = self.encode_video_with_vae(frames)
image_emb_vae = repeat(latents[0] / self.vae_encoder.scaling_factor, "C H W -> T C H W", T=frames.shape[1])
image_emb_clip = self.encode_image_with_clip(frames[:,0].unsqueeze(0))
# Call scheduler
timestep = torch.randint(0, len(self.noise_scheduler.timesteps), (1,))[0]
timestep = self.noise_scheduler.timesteps[timestep]
noise = torch.randn_like(latents)
noisy_latents = self.noise_scheduler.add_noise(latents, noise, timestep)
# Prepare positional id
fps = 30
motion_bucket_id = self.motion_bucket_manager(frames.unsqueeze(0))[0]
noise_aug_strength = 0
add_time_id = torch.tensor([[fps-1, motion_bucket_id, noise_aug_strength]], device=self.device)
# Calculate loss
latents_input = torch.cat([noisy_latents, image_emb_vae], dim=1)
model_pred = self.unet(latents_input, timestep, image_emb_clip, add_time_id, use_gradient_checkpointing=True)
latents_output = self.noise_scheduler.step(model_pred.float(), timestep, noisy_latents.float(), to_final=True)
loss = torch.nn.functional.mse_loss(latents_output, latents.float() * self.contrast_enhance_scale, reduction="mean")
# Re-weighting
reweighted_loss = loss * self.noise_scheduler.training_weight(timestep)
return loss, reweighted_loss
def training_step(self, batch, batch_idx):
# Loss
frames = batch["frames_0"][0]
loss, reweighted_loss = self.calculate_loss(frames)
# Record log
self.log("train_loss", loss, prog_bar=True)
self.log("reweighted_train_loss", reweighted_loss, prog_bar=True)
return reweighted_loss
def configure_optimizers(self):
trainable_modules = []
for block in self.unet.blocks:
if isinstance(block, TemporalAttentionBlock):
trainable_modules += block.parameters()
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.unet.named_parameters()))
trainable_param_names = [named_param[0] for named_param in trainable_param_names]
checkpoint["trainable_param_names"] = trainable_param_names
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_path",
type=str,
default=None,
required=True,
help="Path to pretrained model. For example, `models/stable_video_diffusion/svd_xt.safetensors`.",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
required=False,
help="Path to checkpoint, in case your training program is stopped unexpectedly and you want to resume.",
)
parser.add_argument(
"--dataset_path",
type=str,
default=None,
required=True,
help="The path of the Dataset.",
)
parser.add_argument(
"--output_path",
type=str,
default="./",
help="Path to save the model.",
)
parser.add_argument(
"--steps_per_epoch",
type=int,
default=500,
help="Number of steps per epoch.",
)
parser.add_argument(
"--num_frames",
type=int,
default=128,
help="Number of frames.",
)
parser.add_argument(
"--height",
type=int,
default=512,
help="Image height.",
)
parser.add_argument(
"--width",
type=int,
default=512,
help="Image width.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=2,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-5,
help="Learning rate.",
)
parser.add_argument(
"--accumulate_grad_batches",
type=int,
default=1,
help="The number of batches in gradient accumulation.",
)
parser.add_argument(
"--max_epochs",
type=int,
default=1,
help="Number of epochs.",
)
parser.add_argument(
"--contrast_enhance_scale",
type=float,
default=1.01,
help="Avoid generating gray videos.",
)
args = parser.parse_args()
return args
if __name__ == '__main__':
# args
args = parse_args()
# dataset and data loader
dataset = TextVideoDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.json"),
training_shapes=[(args.num_frames, 1, args.num_frames, args.height, args.width)],
steps_per_epoch=args.steps_per_epoch,
)
train_loader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
# We don't support batch_size > 1,
# because sometimes our GPU cannot process even one video.
batch_size=1,
num_workers=args.dataloader_num_workers
)
# model
model = LightningModel(
learning_rate=args.learning_rate,
svd_ckpt_path=args.pretrained_path,
add_positional_conv=args.num_frames,
contrast_enhance_scale=args.contrast_enhance_scale
)
# train
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
strategy="deepspeed_stage_2",
precision="16-mixed",
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
)
trainer.fit(
model=model,
train_dataloaders=train_loader,
ckpt_path=args.resume_from_checkpoint
)