# flow_utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
    """Convolution block: Conv2d + optional BatchNorm2d + LeakyReLU(0.1)."""
    if batchNorm:
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                      padding=(kernel_size - 1) // 2, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.LeakyReLU(0.1, inplace=True)
        )
    else:
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                      padding=(kernel_size - 1) // 2, bias=True),
            nn.LeakyReLU(0.1, inplace=True)
        )
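
# Illustrative usage sketch (not part of the original module): how a conv block
# might appear in a FlowNet-style encoder. The layer name and channel sizes below
# are assumptions for the example only.
#
#   conv1 = conv(batchNorm=True, in_planes=6, out_planes=64, kernel_size=7, stride=2)
#   x = torch.randn(1, 6, 384, 512)   # stacked image pair
#   feat1 = conv1(x)                  # -> [1, 64, 192, 256]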


def predict_flow(in_planes):
    """Predict a 2-channel (u, v) flow field from a feature map."""
    return nn.Conv2d(in_planes, 2, kernel_size=3, stride=1, padding=1, bias=False)


def deconv(in_planes, out_planes):
    """Transposed convolution that doubles the spatial resolution, followed by LeakyReLU(0.1)."""
    return nn.Sequential(
        nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False),
        nn.LeakyReLU(0.1, inplace=True)
    )


def crop_like(input, target):
    """Crop `input` spatially (H, W) so that it matches the size of `target`."""
    if input.size()[2:] == target.size()[2:]:
        return input
    else:
        return input[:, :, :target.size(2), :target.size(3)]
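
# Illustrative usage sketch (assumptions, not code from this repository): how
# deconv, predict_flow and crop_like are typically combined in a FlowNet-style
# decoder stage. The variable names feat5/feat4 and the channel counts are
# hypothetical.
#
#   flow5   = predict_flow(512)(feat5)                    # coarse flow, [B, 2, h, w]
#   up_feat = crop_like(deconv(512, 256)(feat5), feat4)   # upsample, then match feat4's size
#   # up_feat, feat4 and an upsampled flow5 would then be concatenated
#   # as input to the next refinement stage.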


def warp(x, flo):
    """
    Warp an image/tensor (im2) back to im1, according to the optical flow.

    x:   [B, C, H, W] (im2)
    flo: [B, 2, H, W] flow
    """
    B, C, H, W = x.size()

    # Build a mesh grid of pixel coordinates.
    xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
    yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
    xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
    yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
    grid = torch.cat((xx, yy), 1).float()
    if x.is_cuda:
        grid = grid.cuda()

    # Displace the grid by the flow.
    vgrid = grid + flo

    # Scale grid coordinates to [-1, 1], as expected by grid_sample.
    vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
    vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0

    vgrid = vgrid.permute(0, 2, 3, 1)
    output = F.grid_sample(x, vgrid, align_corners=True)

    # Build a validity mask: positions that sampled from outside the image
    # receive interpolated values below 1 and are zeroed out.
    mask = torch.ones(x.size())
    if x.is_cuda:
        mask = mask.cuda()
    mask = F.grid_sample(mask, vgrid, align_corners=True)

    mask[mask < 0.9999] = 0
    mask[mask > 0] = 1

    return output * mask
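

# Minimal smoke test (a hedged sketch, not part of the original repository).
# With a zero flow field, warp() should reproduce the input wherever the
# validity mask is 1; the conv/predict_flow shapes below follow from
# kernel_size=3, padding=1, stride=2.
if __name__ == "__main__":
    im2 = torch.rand(1, 3, 64, 64)
    zero_flow = torch.zeros(1, 2, 64, 64)
    warped = warp(im2, zero_flow)
    print("max abs difference under zero flow:", (warped - im2).abs().max().item())

    block = conv(batchNorm=False, in_planes=3, out_planes=16, stride=2)
    feat = block(im2)
    print("conv output shape:", feat.shape)                  # -> [1, 16, 32, 32]
    print("flow head output shape:", predict_flow(16)(feat).shape)  # -> [1, 2, 32, 32]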