-
Notifications
You must be signed in to change notification settings - Fork 17
/
models_class_cond.py
557 lines (439 loc) · 20.3 KB
/
models_class_cond.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import DropPath
import numpy as np
def zero_module(module):
    """Set every parameter of ``module`` to zero in place.

    Returns the very same module object so the call can wrap a constructor,
    e.g. ``zero_module(nn.Linear(...))``.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal embedding of a 1-D tensor of scalar values.

    A tensor of ``B`` values becomes a ``(B, 2 * (num_channels // 2))`` tensor.
    The first half holds cosines and the second half sines of the values times
    geometrically spaced frequencies from ``1`` down towards
    ``1 / max_positions``.

    Args:
        num_channels: requested embedding width.
        max_positions: sets the lowest frequency.
        endpoint: if True, the last frequency is exactly ``1 / max_positions``.
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        half = self.num_channels // 2
        denominator = half - (1 if self.endpoint else 0)
        exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=x.device) / denominator
        frequencies = (1 / self.max_positions) ** exponents
        # Outer product: one row per input value, one column per frequency.
        angles = torch.outer(x, frequencies.to(x.dtype))
        return torch.cat([angles.cos(), angles.sin()], dim=1)
class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention from ``x`` to ``context``.

    When ``context`` is omitted the layer attends over ``x`` itself, i.e. it
    acts as self-attention.

    Args:
        query_dim: feature size of the query sequence ``x``.
        context_dim: feature size of ``context``; defaults to ``query_dim``.
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        if context_dim is None:
            context_dim = query_dim
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout),
        )

    def _split_heads(self, t):
        """Reshape ``(b, n, heads * d)`` to ``(b * heads, n, d)`` (batch-major, head-minor)."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.heads, -1).permute(0, 2, 1, 3).reshape(b * self.heads, n, -1)

    def _merge_heads(self, t, batch):
        """Inverse of :meth:`_split_heads`: ``(b * heads, n, d)`` to ``(b, n, heads * d)``."""
        _, n, d = t.shape
        return t.reshape(batch, self.heads, n, d).permute(0, 2, 1, 3).reshape(batch, n, self.heads * d)

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` (shape ``(b, n, query_dim)``) to ``context``.

        Args:
            x: query tokens, ``(b, n, query_dim)``.
            context: key/value tokens, ``(b, m, context_dim)``; ``None`` means ``x``.
            mask: optional tensor flattening to ``(b, m)``; truthy entries mark
                context positions that may be attended to. Masked keys get
                (numerically) zero attention weight. The mask used to be ignored.

        Returns:
            Tensor of shape ``(b, n, query_dim)``.
        """
        if context is None:
            context = x
        q = self._split_heads(self.to_q(x))
        k = self._split_heads(self.to_k(context))
        v = self._split_heads(self.to_v(context))

        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        if mask is not None:
            keep = mask.reshape(mask.shape[0], -1).to(torch.bool)
            # Broadcast the per-sample key mask over heads ((b h) ordering) and queries.
            keep = keep.repeat_interleave(self.heads, dim=0)[:, None, :]
            sim = sim.masked_fill(~keep, -torch.finfo(sim.dtype).max)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        return self.to_out(self._merge_heads(out, x.shape[0]))
class LayerScale(nn.Module):
    """Scale the last dimension of the input by a learnable per-channel gain.

    Args:
        dim: number of channels (size of the last dimension).
        init_values: initial value of every gain entry.
        inplace: if True, multiply the input tensor in place and return it.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class GEGLU(nn.Module):
    """Gated GELU projection: ``value * gelu(gate)``.

    A single linear layer produces ``2 * dim_out`` features. The first half is
    the value and the second half is the gate.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)
class FeedForward(nn.Module):
    """Two-layer MLP: input projection, dropout, output projection.

    Args:
        dim: input feature size.
        dim_out: output feature size; defaults to ``dim``.
        mult: hidden width is ``int(dim * mult)``.
        glu: use a GEGLU input projection instead of ``Linear`` followed by ``GELU``.
        dropout: dropout probability between the two projections.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        out_features = dim if dim_out is None else dim_out
        if glu:
            project_in = GEGLU(dim, hidden)
        else:
            project_in = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())
        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)
class AdaLayerNorm(nn.Module):
    """LayerNorm without learned affine parameters, modulated by a conditioning vector.

    A linear layer maps ``timestep`` to ``2 * n_embd`` features, which are split
    along dimension 2 into ``scale`` and ``shift``. The output is
    ``norm(x) * (1 + scale) + shift``.
    """

    def __init__(self, n_embd):
        super().__init__()
        # Registered but not applied in forward.
        self.silu = nn.SiLU()
        self.linear = nn.Linear(n_embd, n_embd * 2)
        self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False)

    def forward(self, x, timestep):
        # Chunking along dim 2 requires a (at least) 3-D conditioning tensor.
        scale, shift = self.linear(timestep).chunk(2, dim=2)
        normed = self.layernorm(x)
        return normed * (1 + scale) + shift
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention, feed-forward.

    Each sub-layer's input goes through an AdaLayerNorm conditioned on ``t``.
    Its output passes through an optional LayerScale and DropPath and is added
    back residually.

    Args:
        dim: token feature size.
        n_heads, d_head: number of heads and per-head size of both attention layers.
        dropout: dropout used in the attention output projections and the feed-forward.
        context_dim: feature size of ``context`` for the second attention layer
            (``None`` means ``dim``).
        gated_ff: use a GEGLU input projection in the feed-forward.
        checkpoint: stored on the instance; ``forward`` does not use it.
        init_values: initial LayerScale gain; a falsy value (default ``0``) uses
            ``nn.Identity`` instead of LayerScale.
        drop_path: stochastic-depth rate; ``0.0`` (default) uses ``nn.Identity``.
    """

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                 init_values=0, drop_path=0.0):
        super().__init__()
        # Self-attention over the tokens.
        self.attn1 = CrossAttention(
            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # Cross-attention to `context`; self-attention when context is None.
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)
        self.norm1 = AdaLayerNorm(dim)
        self.norm2 = AdaLayerNorm(dim)
        self.norm3 = AdaLayerNorm(dim)
        self.checkpoint = checkpoint
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.ls3 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path3 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, t, context=None):
        x = self.drop_path1(self.ls1(self.attn1(self.norm1(x, t)))) + x
        x = self.drop_path2(self.ls2(self.attn2(self.norm2(x, t), context=context))) + x
        x = self.drop_path3(self.ls3(self.ff(self.norm3(x, t)))) + x
        return x
class LatentArrayTransformer(nn.Module):
    """Timestep-conditioned transformer over a sequence of latent tokens.

    Pipeline:
      1. Linearly project tokens of width ``in_channels`` to ``n_heads * d_head``.
      2. Run ``depth`` BasicTransformerBlocks, conditioned on an embedding of ``t``
         and optionally cross-attending to ``cond``.
      3. Apply LayerNorm.
      4. Project to ``out_channels`` (``in_channels`` when not given) with a
         zero-initialised, bias-free linear layer, so the initial output is zero.

    No reshaping to or from an image layout happens here.
    """

    def __init__(self, in_channels, t_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None, out_channels=None):
        super().__init__()
        inner_dim = n_heads * d_head
        self.in_channels = in_channels
        self.t_channels = t_channels
        self.proj_in = nn.Linear(in_channels, inner_dim, bias=False)
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(inner_dim)
        if out_channels is not None:
            self.num_cls = out_channels
        proj_width = in_channels if out_channels is None else out_channels
        self.proj_out = zero_module(nn.Linear(inner_dim, proj_width, bias=False))
        self.context_dim = context_dim
        self.map_noise = PositionalEmbedding(t_channels)
        self.map_layer0 = nn.Linear(t_channels, inner_dim)
        self.map_layer1 = nn.Linear(inner_dim, inner_dim)

    def forward(self, x, t, cond=None):
        # `t` goes through PositionalEmbedding (which needs a 1-D tensor), gains a
        # token axis, then passes through two Linear+SiLU layers.
        temb = self.map_noise(t)[:, None]
        for layer in (self.map_layer0, self.map_layer1):
            temb = F.silu(layer(temb))
        hidden = self.proj_in(x)
        for block in self.transformer_blocks:
            hidden = block(hidden, temb, context=cond)
        return self.proj_out(self.norm(hidden))
def edm_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    # S_churn=40, S_min=0.05, S_max=50, S_noise=1.003,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    """Sample with the EDM 2nd-order (Heun) sampler.

    Noise levels follow the rho-spaced schedule between ``sigma_max`` and
    ``sigma_min`` (both first clamped to ``net.sigma_min`` / ``net.sigma_max``),
    with a final level of 0. Each step may add churn noise, takes an Euler step
    and, except on the last step, applies a Heun correction.

    Args:
        net: denoiser exposing ``sigma_min``, ``sigma_max``, ``round_sigma`` and
            called as ``net(x, sigma, class_labels)``.
        latents: standard-normal starting noise, scaled by the first noise level.
        class_labels: conditioning forwarded to every ``net`` call.
        randn_like: noise source for the churn step.

    Returns:
        The final sample as a float64 tensor.
    """
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    inv_rho = 1 / rho
    idx = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    hi, lo = sigma_max ** inv_rho, sigma_min ** inv_rho
    sigmas = (hi + idx / (num_steps - 1) * (lo - hi)) ** rho
    sigmas = torch.cat([net.round_sigma(sigmas), torch.zeros_like(sigmas[:1])])  # t_N = 0

    churn = min(S_churn / num_steps, np.sqrt(2) - 1)
    last_step = num_steps - 1
    x = latents.to(torch.float64) * sigmas[0]
    for step, (t_cur, t_next) in enumerate(zip(sigmas[:-1], sigmas[1:])):
        # Temporarily raise the noise level (no-op when gamma == 0).
        gamma = churn if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x)

        # Euler step.
        d_cur = (x_hat - net(x_hat, t_hat, class_labels).to(torch.float64)) / t_hat
        x_euler = x_hat + (t_next - t_hat) * d_cur

        if step < last_step:
            # Heun correction: average the slopes at both ends of the step.
            d_prime = (x_euler - net(x_euler, t_next, class_labels).to(torch.float64)) / t_next
            x = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
        else:
            x = x_euler
    return x
def ablation_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=512, sigma_min=None, sigma_max=None, rho=7,
    solver='euler', discretization='vp', schedule='linear', scaling='none',
    epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
    replace_noise=None,
):
    """Generalised EDM-style sampler with selectable solver, time discretization,
    noise schedule and signal scaling.

    Args:
        net: denoiser exposing ``sigma_min``, ``sigma_max``, ``round_sigma(sigma)``
            and called as ``net(x, sigma, class_labels)``.
        latents: initial noise; multiplied by ``sigma(t_0) * s(t_0)`` to start.
        class_labels: conditioning forwarded unchanged to every ``net`` call.
        randn_like: noise source used by the churn step.
        num_steps: number of sampling steps (a final ``t = 0`` is appended).
        sigma_min, sigma_max: noise range. ``None`` picks a per-``discretization``
            default; both are then clamped to the network's supported range.
        rho: exponent of the ``'edm'`` discretization.
        solver: ``'euler'`` or ``'heun'``; the final step is always Euler.
        discretization: ``'vp'``, ``'ve'``, ``'iddpm'`` or ``'edm'`` spacing of noise levels.
        schedule: ``'vp'``, ``'ve'`` or ``'linear'`` mapping ``t -> sigma(t)``.
        scaling: ``'vp'`` or ``'none'`` signal scaling ``s(t)``.
        epsilon_s: end time of the VP discretization; also used to derive the VP betas.
        C_1, C_2, M: iDDPM discretization constants.
        alpha: position of the second (Heun) evaluation at ``t_hat + alpha * h``.
        S_churn, S_min, S_max, S_noise: stochastic churn parameters.
        replace_noise: optional per-step sequence; ``replace_noise[i]`` is used instead
            of the first ``net`` evaluation at step ``i``. The Heun correction still
            calls ``net``.

    Returns:
        ``(x, noise_list)`` when ``replace_noise`` is None. ``noise_list`` holds the
        first-evaluation denoised output of every step. Otherwise just ``x``.
    """
    assert solver in ['euler', 'heun']
    assert discretization in ['vp', 've', 'iddpm', 'edm']
    assert schedule in ['vp', 've', 'linear']
    assert scaling in ['vp', 'none']
    # Helper functions for VP & VE noise level schedules.
    vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
    # NOTE: `sigma` below is a late-bound closure over the `sigma` variable assigned
    # in the "Define noise level schedule" section. It is only correct when that
    # section binds sigma to the VP schedule, i.e. when schedule == 'vp'.
    vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
    vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
    ve_sigma = lambda t: t.sqrt()
    ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
    ve_sigma_inv = lambda sigma: sigma ** 2
    # Select default noise level range based on the specified time step discretization.
    if sigma_min is None:
        vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s)
        sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
    if sigma_max is None:
        vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1)
        sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)
    # Compute corresponding betas for VP (computed unconditionally, used only by VP paths).
    vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
    vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d
    # Define time steps in terms of noise level.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    if discretization == 'vp':
        # t runs linearly from 1 down to epsilon_s.
        orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
        sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
    elif discretization == 've':
        # Geometric spacing of sigma^2 from sigma_max^2 to sigma_min^2.
        orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
        sigma_steps = ve_sigma(orig_t_steps)
    elif discretization == 'iddpm':
        # Build the M+1 candidate levels backwards, keep those inside
        # [sigma_min, sigma_max], then pick num_steps of them evenly by index.
        u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
        alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
        for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
        u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
        sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
    else:
        assert discretization == 'edm'
        sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    # Define noise level schedule.
    if schedule == 'vp':
        sigma = vp_sigma(vp_beta_d, vp_beta_min)
        sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
        sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
    elif schedule == 've':
        sigma = ve_sigma
        sigma_deriv = ve_sigma_deriv
        sigma_inv = ve_sigma_inv
    else:
        assert schedule == 'linear'
        sigma = lambda t: t
        sigma_deriv = lambda t: 1
        sigma_inv = lambda sigma: sigma
    # Define scaling schedule.
    if scaling == 'vp':
        s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
        s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
    else:
        assert scaling == 'none'
        s = lambda t: 1
        s_deriv = lambda t: 0
    # Compute final time steps based on the corresponding noise levels.
    t_steps = sigma_inv(net.round_sigma(sigma_steps))
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
    if replace_noise is None:
        # Despite its name, this collects the denoised estimates, not noise.
        noise_list = []
    # Main sampling loop.
    t_next = t_steps[0]
    x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next
        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
        t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
        x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur)
        # Euler step.
        h = t_next - t_hat
        if replace_noise is None:
            denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64)
            noise_list.append(denoised)
        else:
            # NOTE(review): leftover debug output; prints for the first five steps only.
            # NOTE(review): indentation was lost in this copy of the file. This layout
            # assumes the replacement applies at every step and only the print is
            # limited to i < 5 — confirm against the original source.
            if i >= 0 and i < 5:
                print('replace', i)
            denoised = replace_noise[i]
        # ODE slope at t_hat, expressed through sigma(t), s(t) and their derivatives.
        d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
        x_prime = x_hat + alpha * h * d_cur
        t_prime = t_hat + alpha * h
        # Apply 2nd order correction.
        if solver == 'euler' or i == num_steps - 1:
            x_next = x_hat + h * d_cur
        else:
            assert solver == 'heun'
            # This second evaluation calls net even when replace_noise is given.
            denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64)
            d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised
            x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)
    if replace_noise is None:
        # noise_list#.reverse()
        return x_next, noise_list
    else:
        return x_next
def random_masking(self, x, mask_ratio):
    """Randomly keep a ``1 - mask_ratio`` fraction of tokens in each sample.

    The kept positions come from argsorting uniform noise per sample.
    ``self`` is not used; it remains in the signature for compatibility with
    callers that pass it.

    Args:
        x: tensor of shape ``(N, L, D)``.
        mask_ratio: fraction of the ``L`` tokens to drop.

    Returns:
        ``(x_masked, mask, ids_restore)``:
        - ``x_masked`` is ``(N, len_keep, D)``.
        - ``mask`` is ``(N, L)`` float, 0 for kept and 1 for removed, in the
          original token order.
        - ``ids_restore`` is the inverse permutation of the shuffle.
    """
    batch, length, dim = x.shape
    keep = int(length * (1 - mask_ratio))

    scores = torch.rand(batch, length, device=x.device)  # uniform in [0, 1)
    order = scores.argsort(dim=1)  # ascending: low scores are kept
    restore = order.argsort(dim=1)

    kept_idx = order[:, :keep]
    x_masked = x.gather(1, kept_idx.unsqueeze(-1).expand(-1, -1, dim))

    # Build the mask in shuffled order, then map it back to the original order.
    mask = torch.ones(batch, length, device=x.device)
    mask[:, :keep] = 0
    mask = mask.gather(1, restore)
    return x_masked, mask, restore
class EDMLoss:
    """Weighted denoising loss with log-normal noise levels (EDM training objective).

    For each sample, ``sigma = exp(P_mean + P_std * z)`` with ``z ~ N(0, 1)``.
    The network denoises ``y + sigma * eps``, and the squared error against ``y``
    is weighted by ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=1):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data

    def __call__(self, net, inputs, labels=None, augment_pipe=None):
        """Return the mean loss over the batch.

        Args:
            net: callable ``net(noisy, sigma, labels)`` returning a tensor shaped like ``inputs``.
            inputs: clean batch. The first dimension is the batch; any number of
                trailing dimensions is allowed.
            labels: conditioning passed through to ``net``.
            augment_pipe: optional callable returning ``(augmented_inputs, augment_labels)``.
                The augment labels are not forwarded to ``net``.
        """
        # One noise level per sample, shaped (B, 1, ..., 1) so it broadcasts over
        # every non-batch dimension. For 3-D inputs this is (B, 1, 1).
        sigma_shape = [inputs.shape[0]] + [1] * (inputs.ndim - 1)
        rnd_normal = torch.randn(sigma_shape, device=inputs.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        y, _augment_labels = augment_pipe(inputs) if augment_pipe is not None else (inputs, None)
        n = torch.randn_like(y) * sigma
        D_yn = net(y + n, sigma, labels)
        loss = weight * ((D_yn - y) ** 2)
        return loss.mean()
class StackedRandomGenerator:
    """One independently seeded ``torch.Generator`` per batch element.

    Sample ``b`` of every draw comes from the generator seeded with
    ``seeds[b] % 2**32``, so a sample's noise does not depend on batch
    composition.
    """

    def __init__(self, device, seeds):
        super().__init__()
        self.generators = []
        for seed in seeds:
            gen = torch.Generator(device)
            gen.manual_seed(int(seed) % (1 << 32))
            self.generators.append(gen)

    def randn(self, size, **kwargs):
        """Standard-normal tensor of shape ``size``; ``size[0]`` must equal the number of seeds."""
        assert size[0] == len(self.generators)
        per_sample = [torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]
        return torch.stack(per_sample)

    def randn_like(self, input):
        """Like :meth:`randn`, matching ``input``'s shape, dtype, layout and device."""
        return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)

    def randint(self, *args, size, **kwargs):
        """Per-sample ``torch.randint``; ``size[0]`` must equal the number of seeds."""
        assert size[0] == len(self.generators)
        per_sample = [torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators]
        return torch.stack(per_sample)
class EDMPrecond(torch.nn.Module):
    """EDM preconditioning wrapper around a LatentArrayTransformer.

    The denoised estimate is ``D(x) = c_skip * x + c_out * F(c_in * x, c_noise)``,
    with the coefficients computed from ``sigma`` and ``sigma_data`` in
    :meth:`forward`. Conditioning is either integer class labels (looked up in
    ``category_emb``) or a precomputed floating-point embedding.

    Args:
        n_latents, channels: shape ``(n_latents, channels)`` of one latent sample,
            used by :meth:`sample`.
        use_fp16: run the inner model in float16 when on CUDA (see NOTE in forward).
        sigma_min, sigma_max: noise range advertised to samplers.
        sigma_data: data scale used by the preconditioning coefficients.
        n_heads, d_head, depth: transformer hyper-parameters.
        num_classes: number of rows in the class-embedding table (default 55).
    """

    def __init__(self,
                 n_latents=512,
                 channels=8,
                 use_fp16=False,
                 sigma_min=0,
                 sigma_max=float('inf'),
                 sigma_data=1,
                 n_heads=8,
                 d_head=64,
                 depth=12,
                 num_classes=55,
                 ):
        super().__init__()
        self.n_latents = n_latents
        self.channels = channels
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.model = LatentArrayTransformer(in_channels=channels, t_channels=256, n_heads=n_heads, d_head=d_head, depth=depth)
        self.category_emb = nn.Embedding(num_classes, n_heads * d_head)

    def emb_category(self, class_labels):
        """Embed integer labels of shape ``(B,)`` as a one-token context ``(B, 1, n_heads * d_head)``."""
        return self.category_emb(class_labels).unsqueeze(1)

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        """Return the denoised estimate of ``x`` at noise level ``sigma``.

        Args:
            x: noisy latents. ``sigma`` is reshaped to ``(-1, 1, 1)`` to broadcast
                over them, so ``x`` is expected to be 3-D ``(B, N, C)``.
            sigma: per-sample noise levels (tensor) or a single float.
            class_labels: ``None`` (unconditional; the model's cross-attention then
                attends over the latents themselves), integer labels for
                ``category_emb``, or a floating-point embedding passed straight to the
                model as cross-attention context (presumably ``(B, M, n_heads * d_head)``).
            force_fp32: disable the float16 path.
            **model_kwargs: forwarded to the inner model.
        """
        if class_labels is None:
            cond_emb = None
        elif torch.is_floating_point(class_labels):
            cond_emb = class_labels.to(torch.float32)
        else:
            cond_emb = self.emb_category(class_labels)
        x = x.to(torch.float32)
        sigma = torch.as_tensor(sigma, dtype=torch.float32, device=x.device).reshape(-1, 1, 1)
        # NOTE(review): this class never casts the model weights to float16, so the
        # fp16 path presumably relies on external casting/autocast — confirm.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
        c_noise = sigma.log() / 4
        F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), cond=cond_emb, **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        return torch.as_tensor(sigma)

    @torch.no_grad()
    def sample(self, cond, batch_seeds=None):
        """Draw samples with ``edm_sampler``.

        Args:
            cond: conditioning passed to :meth:`forward` (labels or embedding), or
                ``None``. When given, its first dimension sets the batch size and its
                device is used.
            batch_seeds: one seed per sample. Defaults to ``0..B-1`` when ``cond`` is
                given, and is required when ``cond`` is ``None``.

        Returns:
            A float64 tensor of shape ``(B, n_latents, channels)``.

        Raises:
            ValueError: if both ``cond`` and ``batch_seeds`` are ``None``.
        """
        if cond is not None:
            batch_size, device = cond.shape[0], cond.device
            if batch_seeds is None:
                batch_seeds = torch.arange(batch_size)
        elif batch_seeds is not None:
            device = batch_seeds.device
            batch_size = batch_seeds.shape[0]
        else:
            raise ValueError("sample() needs `cond` or `batch_seeds` to determine the batch size")
        rnd = StackedRandomGenerator(device, batch_seeds)
        latents = rnd.randn([batch_size, self.n_latents, self.channels], device=device)
        return edm_sampler(self, latents, cond, randn_like=rnd.randn_like)
def kl_d512_m512_l8_edm():
    """Build an EDMPrecond for 512 latents with 8 channels (default depth)."""
    return EDMPrecond(n_latents=512, channels=8)
def kl_d512_m512_l16_edm():
    """Build an EDMPrecond for 512 latents with 16 channels (default depth)."""
    return EDMPrecond(n_latents=512, channels=16)
def kl_d512_m512_l32_edm():
    """Build an EDMPrecond for 512 latents with 32 channels (default depth)."""
    return EDMPrecond(n_latents=512, channels=32)
def kl_d512_m512_l4_d24_edm():
    """Build a 24-block EDMPrecond for 512 latents with 4 channels."""
    return EDMPrecond(n_latents=512, channels=4, depth=24)
def kl_d512_m512_l8_d24_edm():
    """Build a 24-block EDMPrecond for 512 latents with 8 channels."""
    return EDMPrecond(n_latents=512, channels=8, depth=24)
def kl_d512_m512_l32_d24_edm():
    """Build a 24-block EDMPrecond for 512 latents with 32 channels."""
    return EDMPrecond(n_latents=512, channels=32, depth=24)