-
Notifications
You must be signed in to change notification settings - Fork 17
/
models_ae.py
475 lines (353 loc) · 14.5 KB
/
models_ae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
from functools import wraps
import numpy as np
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from torch_cluster import fps
from timm.models.layers import DropPath
def exists(val):
    """Return True unless *val* is None (falsy values such as 0 or '' still exist)."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return *val*, falling back to *d* only when *val* is None."""
    if val is None:
        return d
    return val
def cache_fn(f):
    """Wrap *f* so its first non-None result is memoized and reused.

    The cache ignores arguments: once a non-None value is stored, every call
    with the default ``_cache=True`` returns that same object. Passing
    ``_cache=False`` calls *f* directly and neither reads nor writes the cache.
    """
    memo = None

    @wraps(f)
    def cached_fn(*args, _cache = True, **kwargs):
        nonlocal memo
        if _cache:
            # A None result is not treated as cached, so f runs again next time.
            if memo is None:
                memo = f(*args, **kwargs)
            return memo
        return f(*args, **kwargs)

    return cached_fn
class PreNorm(nn.Module):
    """Apply LayerNorm to the input (and optionally to ``context``) before calling ``fn``.

    Args:
        dim: last-dim size of ``x`` normalized by ``self.norm``.
        fn: module called as ``fn(normalized_x, **kwargs)``.
        context_dim: when given, a second LayerNorm is created and the
            ``context`` keyword argument is normalized before ``fn`` sees it.
            In that case ``forward`` requires a ``context`` keyword.
    """

    def __init__(self, dim, fn, context_dim = None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = None if context_dim is None else nn.LayerNorm(context_dim)

    def forward(self, x, **kwargs):
        normed_x = self.norm(x)
        if self.norm_context is not None:
            kwargs['context'] = self.norm_context(kwargs['context'])
        return self.fn(normed_x, **kwargs)
class GEGLU(nn.Module):
    """Gated GELU: split the last dim into (value, gate) halves and return value * gelu(gate).

    The output's last dimension is half the input's.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim = -1)
        gated = F.gelu(gate)
        return value * gated
class FeedForward(nn.Module):
    """GEGLU feed-forward block: Linear(dim -> 2*dim*mult), GEGLU, Linear(dim*mult -> dim).

    Args:
        dim: input/output feature size.
        mult: hidden-width multiplier.
        drop_path_rate: if > 0, the output passes through ``DropPath``; otherwise identity.
    """

    def __init__(self, dim, mult = 4, drop_path_rate = 0.0):
        super().__init__()
        hidden = dim * mult
        # The first projection is doubled because GEGLU halves the last dim.
        self.net = nn.Sequential(
            nn.Linear(dim, hidden * 2),
            GEGLU(),
            nn.Linear(hidden, dim),
        )
        if drop_path_rate > 0.:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        projected = self.net(x)
        return self.drop_path(projected)
class Attention(nn.Module):
    """Multi-head scaled dot-product attention from ``x`` (queries) to ``context`` (keys/values).

    Args:
        query_dim: feature size of ``x``; also the output feature size.
        context_dim: feature size of ``context``; defaults to ``query_dim``.
        heads: number of attention heads.
        dim_head: per-head feature size (logits are scaled by ``dim_head ** -0.5``).
        drop_path_rate: if > 0, the output passes through ``DropPath``; otherwise identity.
    """

    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, drop_path_rate = 0.0):
        super().__init__()
        inner_dim = dim_head * heads
        kv_in_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
        # Keys and values come from a single projection, split in forward().
        self.to_kv = nn.Linear(kv_in_dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, query_dim)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, context = None, mask = None):
        """Attend from ``x`` to ``context`` (self-attention when ``context`` is None).

        ``x`` and ``context`` are 3-D (b, n, features). ``mask``, if given, is a
        boolean tensor whose trailing dims flatten to the key length. Positions
        where it is False receive the most negative finite logit.
        """
        n_heads = self.heads
        kv_source = default(context, x)

        def split_heads(t):
            # (b, n, h*d) -> (b*h, n, d): fold the heads into the batch dim.
            return rearrange(t, 'b n (h d) -> (b h) n d', h = n_heads)

        q = split_heads(self.to_q(x))
        k_all, v_all = self.to_kv(kv_source).chunk(2, dim = -1)
        k = split_heads(k_all)
        v = split_heads(v_all)

        logits = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if mask is not None:
            keep = repeat(rearrange(mask, 'b ... -> b (...)'), 'b j -> (b h) () j', h = n_heads)
            logits.masked_fill_(~keep, -torch.finfo(logits.dtype).max)

        weights = logits.softmax(dim = -1)
        per_head = einsum('b i j, b j d -> b i d', weights, v)
        merged = rearrange(per_head, '(b h) n d -> b n (h d)', h = n_heads)
        return self.drop_path(self.to_out(merged))
class PointEmbed(nn.Module):
    """Fourier-feature embedding of 3-D points followed by a linear projection.

    Each coordinate axis gets ``hidden_dim // 6`` frequencies ``2**k * pi``
    (k = 0, 1, ...). The projections' sin and cos (``hidden_dim`` features in
    total) are concatenated with the raw xyz and mapped to ``dim`` by ``self.mlp``.

    Args:
        hidden_dim: number of sin/cos features; must be divisible by 6.
        dim: output feature size.
    """

    def __init__(self, hidden_dim=48, dim=128):
        super().__init__()
        assert hidden_dim % 6 == 0
        self.embedding_dim = hidden_dim
        n_freq = self.embedding_dim // 6
        freqs = torch.pow(2, torch.arange(n_freq)).float() * np.pi
        # Block-diagonal basis: row `axis` holds the frequencies in its own slot.
        basis = torch.zeros(3, 3 * n_freq)
        for axis in range(3):
            basis[axis, axis * n_freq:(axis + 1) * n_freq] = freqs
        self.register_buffer('basis', basis)  # 3 x (hidden_dim // 2)
        self.mlp = nn.Linear(self.embedding_dim + 3, dim)

    @staticmethod
    def embed(input, basis):
        """Project (B, N, 3) points onto ``basis`` and return [sin, cos] along dim 2."""
        proj = torch.einsum('bnd,de->bne', input, basis)
        return torch.cat((proj.sin(), proj.cos()), dim=2)

    def forward(self, input):
        """Embed (B, N, 3) points into (B, N, dim) features."""
        fourier = self.embed(input, self.basis)
        return self.mlp(torch.cat([fourier, input], dim=2))
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterized by ``mean`` and ``logvar`` (same shape).

    Dim 0 is treated as the batch dimension. By default ``kl`` and ``nll``
    reduce over every remaining dimension, so both 3-D latents (B x M x C) and
    4-D tensors are supported.

    Args:
        mean: distribution mean.
        logvar: log-variance; clamped to [-30, 20] before use.
        deterministic: if True, std and var are zero, ``sample()`` returns
            ``mean``, and ``kl()``/``nll()`` return a one-element zero tensor.
    """

    def __init__(self, mean, logvar, deterministic=False):
        self.mean = mean
        # Clamp to keep exp(logvar) within a numerically safe range.
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean)

    def _reduce_dims(self):
        """Return the indices of all non-batch dimensions of ``mean``."""
        return list(range(1, self.mean.dim()))

    def _zero(self):
        """One-element zero tensor on ``mean``'s device, so it can be added to losses there."""
        return torch.zeros(1, device=self.mean.device)

    def sample(self):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)`` in ``mean``'s dtype and device."""
        eps = torch.randn_like(self.mean)
        return self.mean + self.std * eps

    def kl(self, other=None):
        """KL(self || other), averaged over all non-batch dims.

        When ``other`` is None the reference is the standard normal N(0, I).
        Returns one value per batch element (or a one-element zero if deterministic).
        """
        if self.deterministic:
            return self._zero()
        dims = self._reduce_dims()
        if other is None:
            return 0.5 * torch.mean(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=dims)
        return 0.5 * torch.mean(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=dims)

    def nll(self, sample, dims=None):
        """Gaussian negative log-likelihood of ``sample``, summed over ``dims``.

        ``dims`` defaults to all non-batch dimensions.
        """
        if self.deterministic:
            return self._zero()
        if dims is None:
            dims = self._reduce_dims()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Return the mode, which for a Gaussian is the mean."""
        return self.mean
class AutoEncoder(nn.Module):
    """Deterministic point-cloud autoencoder.

    Encoder: ``num_latents`` points are chosen from the ``num_inputs`` input
    points by farthest point sampling (``fps``). Their embeddings cross-attend
    to the embeddings of all input points, followed by a feed-forward block.

    Decoder: ``depth`` residual self-attention + feed-forward blocks over the
    latents. Embedded query points then cross-attend to them, and the result is
    projected to ``output_dim`` values per query.

    Args:
        depth: number of latent self-attention/feed-forward blocks.
        dim: latent feature width.
        queries_dim: width of the decoder queries. Queries are embedded with
            the same ``PointEmbed(dim=dim)``, so this must equal ``dim``.
        output_dim: per-query output size; ``None`` leaves the features unprojected.
        num_inputs: number of points each input cloud must contain.
        num_latents: number of FPS-selected latent points.
        heads, dim_head: head count and per-head width of the latent self-attention.
        weight_tie_layers: if True, every latent block reuses one shared module.
        decoder_ff: if True, add a residual feed-forward after the decoder cross-attention.
    """

    def __init__(
        self,
        *,
        depth=24,
        dim=512,
        queries_dim=512,
        output_dim = 1,
        num_inputs = 2048,
        num_latents = 512,
        heads = 8,
        dim_head = 64,
        weight_tie_layers = False,
        decoder_ff = False
    ):
        super().__init__()
        self.depth = depth
        self.num_inputs = num_inputs
        self.num_latents = num_latents

        # Encoder: single-head cross-attention from sampled points to all points.
        self.cross_attend_blocks = nn.ModuleList([
            PreNorm(dim, Attention(dim, dim, heads = 1, dim_head = dim), context_dim = dim),
            PreNorm(dim, FeedForward(dim))
        ])

        self.point_embed = PointEmbed(dim=dim)

        def get_latent_attn():
            return PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, drop_path_rate=0.1))

        def get_latent_ff():
            return PreNorm(dim, FeedForward(dim, drop_path_rate=0.1))

        # With _cache=True, cache_fn hands back the same module on every call,
        # which is how weight tying across layers is implemented.
        get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))

        self.layers = nn.ModuleList([])
        cache_args = {'_cache': weight_tie_layers}
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                get_latent_attn(**cache_args),
                get_latent_ff(**cache_args)
            ]))

        # Decoder: single-head cross-attention from embedded queries to latents.
        self.decoder_cross_attn = PreNorm(queries_dim, Attention(queries_dim, dim, heads = 1, dim_head = dim), context_dim = dim)
        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None

        self.to_outputs = nn.Linear(queries_dim, output_dim) if exists(output_dim) else nn.Identity()

    def _sample_latent_points(self, pc):
        """Pick ``num_latents`` points per cloud from ``pc`` (B x N x D) via FPS; returns B x M x D."""
        B, N, D = pc.shape
        # fps works on a flat point list plus a per-point batch index.
        # reshape (not view) also accepts non-contiguous input.
        flat = pc.reshape(B * N, D)
        batch = torch.arange(B, device=pc.device).repeat_interleave(N)
        ratio = 1.0 * self.num_latents / self.num_inputs
        idx = fps(flat, batch, ratio=ratio)
        return flat[idx].view(B, -1, D)

    def encode(self, pc):
        """Encode a point cloud ``pc`` (B x N x 3) into latent features (B x M x dim).

        Raises:
            ValueError: if ``N`` differs from ``num_inputs``. The FPS ratio is
                derived from ``num_inputs``, so other sizes would change M.
        """
        _, N, _ = pc.shape
        if N != self.num_inputs:
            raise ValueError(f"expected {self.num_inputs} points per cloud, got {N}")

        sampled_pc = self._sample_latent_points(pc)

        sampled_pc_embeddings = self.point_embed(sampled_pc)
        pc_embeddings = self.point_embed(pc)

        cross_attn, cross_ff = self.cross_attend_blocks
        x = cross_attn(sampled_pc_embeddings, context = pc_embeddings, mask = None) + sampled_pc_embeddings
        x = cross_ff(x) + x
        return x

    def decode(self, x, queries):
        """Refine latents ``x`` and predict outputs for ``queries`` (B x Q x 3)."""
        for self_attn, self_ff in self.layers:
            x = self_attn(x) + x
            x = self_ff(x) + x

        # Cross attend from decoder queries to latents (no residual connection here).
        queries_embeddings = self.point_embed(queries)
        latents = self.decoder_cross_attn(queries_embeddings, context = x)

        # Optional decoder feedforward.
        if exists(self.decoder_ff):
            latents = latents + self.decoder_ff(latents)

        return self.to_outputs(latents)

    def forward(self, pc, queries):
        """Return ``{'logits': ...}``: decoder outputs with the trailing dim squeezed."""
        x = self.encode(pc)
        o = self.decode(x, queries).squeeze(-1)
        return {'logits': o}
class KLAutoEncoder(nn.Module):
    """Point-cloud autoencoder with a diagonal-Gaussian (KL-regularized) bottleneck.

    Encoder: ``num_latents`` points are chosen from the ``num_inputs`` input
    points by farthest point sampling (``fps``). Their embeddings cross-attend
    to the embeddings of all input points. Per-latent ``mean``/``logvar`` of
    width ``latent_dim`` are then predicted and a sample is drawn from the
    posterior.

    Decoder: the sample is projected back to ``dim`` and passed through
    ``depth`` residual self-attention + feed-forward blocks. Embedded query
    points then cross-attend to the result and are projected to ``output_dim``.

    Args:
        depth: number of latent self-attention/feed-forward blocks.
        dim: latent feature width.
        queries_dim: width of the decoder queries. Queries are embedded with
            the same ``PointEmbed(dim=dim)``, so this must equal ``dim``.
        output_dim: per-query output size; ``None`` leaves the features unprojected.
        num_inputs: number of points each input cloud must contain.
        num_latents: number of FPS-selected latent points.
        latent_dim: bottleneck width of each latent.
        heads, dim_head: head count and per-head width of the latent self-attention.
        weight_tie_layers: if True, every latent block reuses one shared module.
        decoder_ff: if True, add a residual feed-forward after the decoder cross-attention.
    """

    def __init__(
        self,
        *,
        depth=24,
        dim=512,
        queries_dim=512,
        output_dim = 1,
        num_inputs = 2048,
        num_latents = 512,
        latent_dim = 64,
        heads = 8,
        dim_head = 64,
        weight_tie_layers = False,
        decoder_ff = False
    ):
        super().__init__()
        self.depth = depth
        self.num_inputs = num_inputs
        self.num_latents = num_latents

        # Encoder: single-head cross-attention from sampled points to all points.
        self.cross_attend_blocks = nn.ModuleList([
            PreNorm(dim, Attention(dim, dim, heads = 1, dim_head = dim), context_dim = dim),
            PreNorm(dim, FeedForward(dim))
        ])

        self.point_embed = PointEmbed(dim=dim)

        def get_latent_attn():
            return PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, drop_path_rate=0.1))

        def get_latent_ff():
            return PreNorm(dim, FeedForward(dim, drop_path_rate=0.1))

        # With _cache=True, cache_fn hands back the same module on every call,
        # which is how weight tying across layers is implemented.
        get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))

        self.layers = nn.ModuleList([])
        cache_args = {'_cache': weight_tie_layers}
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                get_latent_attn(**cache_args),
                get_latent_ff(**cache_args)
            ]))

        # Decoder: single-head cross-attention from embedded queries to latents.
        self.decoder_cross_attn = PreNorm(queries_dim, Attention(queries_dim, dim, heads = 1, dim_head = dim), context_dim = dim)
        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None

        self.to_outputs = nn.Linear(queries_dim, output_dim) if exists(output_dim) else nn.Identity()

        # Bottleneck: dim -> (mean, logvar) of width latent_dim, and latent_dim -> dim for decoding.
        self.proj = nn.Linear(latent_dim, dim)

        self.mean_fc = nn.Linear(dim, latent_dim)
        self.logvar_fc = nn.Linear(dim, latent_dim)

    def _sample_latent_points(self, pc):
        """Pick ``num_latents`` points per cloud from ``pc`` (B x N x D) via FPS; returns B x M x D."""
        B, N, D = pc.shape
        # fps works on a flat point list plus a per-point batch index.
        # reshape (not view) also accepts non-contiguous input.
        flat = pc.reshape(B * N, D)
        batch = torch.arange(B, device=pc.device).repeat_interleave(N)
        ratio = 1.0 * self.num_latents / self.num_inputs
        idx = fps(flat, batch, ratio=ratio)
        return flat[idx].view(B, -1, D)

    def encode(self, pc):
        """Encode ``pc`` (B x N x 3) and return ``(kl, z)``.

        ``z`` (B x M x latent_dim) is a sample from the predicted diagonal
        Gaussian posterior. ``kl`` is that posterior's ``kl()`` against N(0, I).

        Raises:
            ValueError: if ``N`` differs from ``num_inputs``. The FPS ratio is
                derived from ``num_inputs``, so other sizes would change M.
        """
        _, N, _ = pc.shape
        if N != self.num_inputs:
            raise ValueError(f"expected {self.num_inputs} points per cloud, got {N}")

        sampled_pc = self._sample_latent_points(pc)

        sampled_pc_embeddings = self.point_embed(sampled_pc)
        pc_embeddings = self.point_embed(pc)

        cross_attn, cross_ff = self.cross_attend_blocks
        x = cross_attn(sampled_pc_embeddings, context = pc_embeddings, mask = None) + sampled_pc_embeddings
        x = cross_ff(x) + x

        mean = self.mean_fc(x)
        logvar = self.logvar_fc(x)

        posterior = DiagonalGaussianDistribution(mean, logvar)
        x = posterior.sample()
        kl = posterior.kl()

        return kl, x

    def decode(self, x, queries):
        """Decode bottleneck latents ``x`` (B x M x latent_dim) at ``queries`` (B x Q x 3)."""
        x = self.proj(x)

        for self_attn, self_ff in self.layers:
            x = self_attn(x) + x
            x = self_ff(x) + x

        # Cross attend from decoder queries to latents (no residual connection here).
        queries_embeddings = self.point_embed(queries)
        latents = self.decoder_cross_attn(queries_embeddings, context = x)

        # Optional decoder feedforward.
        if exists(self.decoder_ff):
            latents = latents + self.decoder_ff(latents)

        return self.to_outputs(latents)

    def forward(self, pc, queries):
        """Return ``{'logits': ..., 'kl': ...}``; logits have the trailing dim squeezed."""
        kl, x = self.encode(pc)
        o = self.decode(x, queries).squeeze(-1)
        return {'logits': o, 'kl': kl}
def create_autoencoder(dim=512, M=512, latent_dim=64, N=2048, determinisitc=False,
                       *, deterministic=None, depth=24, heads=8, dim_head=64):
    """Build an ``AutoEncoder`` or a ``KLAutoEncoder`` from a shared configuration.

    Args:
        dim: latent feature width; ``queries_dim`` is set to the same value.
        M: number of latent points (``num_latents``).
        latent_dim: bottleneck width for the KL variant; unused when deterministic.
        N: number of points per input cloud (``num_inputs``).
        determinisitc: legacy (misspelled) flag kept for backward compatibility;
            True builds the deterministic ``AutoEncoder``.
        deterministic: correctly spelled alias; when not None it overrides
            ``determinisitc``.
        depth: number of latent self-attention/feed-forward blocks.
        heads, dim_head: latent self-attention head count and per-head width.

    Returns:
        The constructed model, configured with ``output_dim=1``.
    """
    if deterministic is not None:
        determinisitc = deterministic

    common = dict(
        depth=depth,
        dim=dim,
        queries_dim=dim,
        output_dim=1,
        num_inputs=N,
        num_latents=M,
        heads=heads,
        dim_head=dim_head,
    )
    if determinisitc:
        return AutoEncoder(**common)
    return KLAutoEncoder(latent_dim=latent_dim, **common)
def _kl_variant(latent_dim, N):
    """Build the KL autoencoder with dim=512 and 512 latents of width ``latent_dim``."""
    return create_autoencoder(dim=512, M=512, latent_dim=latent_dim, N=N, determinisitc=False)


def kl_d512_m512_l512(N=2048):
    """KL autoencoder: dim 512, 512 latents, latent_dim 512."""
    return _kl_variant(512, N)


def kl_d512_m512_l64(N=2048):
    """KL autoencoder: dim 512, 512 latents, latent_dim 64."""
    return _kl_variant(64, N)


def kl_d512_m512_l32(N=2048):
    """KL autoencoder: dim 512, 512 latents, latent_dim 32."""
    return _kl_variant(32, N)


def kl_d512_m512_l16(N=2048):
    """KL autoencoder: dim 512, 512 latents, latent_dim 16."""
    return _kl_variant(16, N)


def kl_d512_m512_l8(N=2048):
    """KL autoencoder: dim 512, 512 latents, latent_dim 8."""
    return _kl_variant(8, N)


def kl_d512_m512_l4(N=2048):
    """KL autoencoder: dim 512, 512 latents, latent_dim 4."""
    return _kl_variant(4, N)


def kl_d512_m512_l2(N=2048):
    """KL autoencoder: dim 512, 512 latents, latent_dim 2."""
    return _kl_variant(2, N)


def kl_d512_m512_l1(N=2048):
    """KL autoencoder: dim 512, 512 latents, latent_dim 1."""
    return _kl_variant(1, N)
### Deterministic autoencoders: dim 512, varying number of latents M.
def _ae_variant(dim, M, N):
    """Build the deterministic autoencoder with width ``dim`` and ``M`` latents."""
    return create_autoencoder(dim=dim, M=M, N=N, determinisitc=True)


def ae_d512_m512(N=2048):
    """Deterministic autoencoder: dim 512, 512 latents."""
    return _ae_variant(512, 512, N)


def ae_d512_m256(N=2048):
    """Deterministic autoencoder: dim 512, 256 latents."""
    return _ae_variant(512, 256, N)


def ae_d512_m128(N=2048):
    """Deterministic autoencoder: dim 512, 128 latents."""
    return _ae_variant(512, 128, N)


def ae_d512_m64(N=2048):
    """Deterministic autoencoder: dim 512, 64 latents."""
    return _ae_variant(512, 64, N)


### Deterministic autoencoders: 512 latents, varying width dim.
def ae_d256_m512(N=2048):
    """Deterministic autoencoder: dim 256, 512 latents."""
    return _ae_variant(256, 512, N)


def ae_d128_m512(N=2048):
    """Deterministic autoencoder: dim 128, 512 latents."""
    return _ae_variant(128, 512, N)


def ae_d64_m512(N=2048):
    """Deterministic autoencoder: dim 64, 512 latents."""
    return _ae_variant(64, 512, N)