
Commit d882239: Initial commit

1 parent 76f4670

File tree

12 files changed: +1016 −2 lines


.gitignore

Lines changed: 1 addition & 0 deletions
/photos/output.mp4

README.md

Lines changed: 8 additions & 2 deletions
- # frame-interpolation-pytorch
- PyTorch implementation of FILM: Frame Interpolation for Large Motion, In ECCV 2022.
+ # Frame interpolation in PyTorch
+
+ This is a PyTorch inference implementation of
+ [FILM: Frame Interpolation for Large Motion, In ECCV 2022](https://film-net.github.io/).\
+ [Original repository link](https://github.com/google-research/frame-interpolation)
+
+ The project is focused on creating a simple and TorchScript-compilable inference interface
+ for the original pretrained TF2 model.
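
As a quick illustration of that interface, the following sketch shows how an exported TorchScript module can be used for inference; the file name interpolator.pt and the random inputs are placeholders for the example, not part of this commit. The call signature mirrors test_model() in export.py below.

import torch

# Load the TorchScript module produced by export.py (the path is a placeholder).
model = torch.jit.load('interpolator.pt').eval()

# Two RGB frames as 1 x 3 x H x W float tensors and a 1 x 1 timestep in (0, 1),
# matching the shapes used by test_model() in export.py.
x0 = torch.rand(1, 3, 256, 256)
x1 = torch.rand(1, 3, 256, 256)
t = torch.full((1, 1), 0.5)

with torch.no_grad():
    mid_frame = model(x0, x1, t)  # 1 x 3 x H x W interpolated frame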

export.py

Lines changed: 154 additions & 0 deletions
import warnings

import numpy as np
import tensorflow as tf
import torch

from interpolator import Interpolator


def translate_state_dict(var_dict, state_dict):
    for name, (prev_name, weight) in zip(state_dict, var_dict.items()):
        print('Mapping', prev_name, '->', name)
        weight = torch.from_numpy(weight)
        if 'kernel' in prev_name:
            # Transpose the conv2d kernel weights, since TF uses (H, W, C, K) and PyTorch uses (K, C, H, W)
            weight = weight.permute(3, 2, 0, 1)

        assert state_dict[name].shape == weight.shape, f'Shape mismatch {state_dict[name].shape} != {weight.shape}'

        state_dict[name] = weight


def import_state_dict(interpolator: Interpolator, saved_model):
    variables = saved_model.keras_api.variables

    extract_dict = interpolator.extract.state_dict()
    flow_dict = interpolator.predict_flow.state_dict()
    fuse_dict = interpolator.fuse.state_dict()

    extract_vars = {}
    _flow_vars = {}
    _fuse_vars = {}

    for var in variables:
        name = var.name
        if name.startswith('feat_net'):
            extract_vars[name[9:]] = var.numpy()
        elif name.startswith('predict_flow'):
            _flow_vars[name[13:]] = var.numpy()
        elif name.startswith('fusion'):
            _fuse_vars[name[7:]] = var.numpy()

    # Reverse the order of the modules to allow jit export
    # TODO: improve this hack
    flow_vars = dict(sorted(_flow_vars.items(), key=lambda x: x[0].split('/')[0], reverse=True))
    fuse_vars = dict(sorted(_fuse_vars.items(), key=lambda x: int((x[0].split('/')[0].split('_')[1:] or [0])[0]) // 3, reverse=True))

    assert len(extract_vars) == len(extract_dict), f'{len(extract_vars)} != {len(extract_dict)}'
    assert len(flow_vars) == len(flow_dict), f'{len(flow_vars)} != {len(flow_dict)}'
    assert len(fuse_vars) == len(fuse_dict), f'{len(fuse_vars)} != {len(fuse_dict)}'

    for state_dict, var_dict in ((extract_dict, extract_vars), (flow_dict, flow_vars), (fuse_dict, fuse_vars)):
        translate_state_dict(var_dict, state_dict)

    interpolator.extract.load_state_dict(extract_dict)
    interpolator.predict_flow.load_state_dict(flow_dict)
    interpolator.fuse.load_state_dict(fuse_dict)


def verify_debug_outputs(pt_outputs, tf_outputs):
    max_error = 0
    for name, predicted in pt_outputs.items():
        if name == 'image':
            continue
        pred_frfp = [f.permute(0, 2, 3, 1).detach().cpu().numpy() for f in predicted]
        true_frfp = [f.numpy() for f in tf_outputs[name]]

        for i, (pred, true) in enumerate(zip(pred_frfp, true_frfp)):
            assert pred.shape == true.shape, f'{name} {i} shape mismatch {pred.shape} != {true.shape}'
            error = np.max(np.abs(pred - true))
            max_error = max(max_error, error)
            assert error < 1, f'{name} {i} max error: {error}'
    print('Max intermediate error:', max_error)


def test_model(interpolator, model, half=False, gpu=False):
    torch.manual_seed(0)
    time = torch.full((1, 1), .5)
    x0 = torch.rand(1, 3, 256, 256)
    x1 = torch.rand(1, 3, 256, 256)

    x0_ = tf.convert_to_tensor(x0.permute(0, 2, 3, 1).numpy(), dtype=tf.float32)
    x1_ = tf.convert_to_tensor(x1.permute(0, 2, 3, 1).numpy(), dtype=tf.float32)
    time_ = tf.convert_to_tensor(time.numpy(), dtype=tf.float32)
    tf_outputs = model({'x0': x0_, 'x1': x1_, 'time': time_}, training=False)

    if half:
        x0 = x0.half()
        x1 = x1.half()
        time = time.half()

    if gpu and torch.cuda.is_available():
        x0 = x0.cuda()
        x1 = x1.cuda()
        time = time.cuda()

    with torch.no_grad():
        pt_outputs = interpolator.debug_forward(x0, x1, time)

    verify_debug_outputs(pt_outputs, tf_outputs)

    with torch.no_grad():
        prediction = interpolator(x0, x1, time)
    output_color = prediction.permute(0, 2, 3, 1).detach().cpu().numpy()
    true_color = tf_outputs['image'].numpy()
    error = np.abs(output_color - true_color).max()

    print('Color max error:', error)


def main(model_path, save_path, export_to_torchscript=True, use_gpu=False, fp16=True, skiptest=False):
    print(f'Exporting model to FP{["32", "16"][fp16]} {["state_dict", "torchscript"][export_to_torchscript]} '
          f'using {"CG"[use_gpu]}PU')
    model = tf.compat.v2.saved_model.load(model_path)
    interpolator = Interpolator()
    interpolator.eval()
    import_state_dict(interpolator, model)

    if use_gpu and torch.cuda.is_available():
        interpolator = interpolator.cuda()
    else:
        if fp16 and use_gpu:
            warnings.warn('No GPU is available, using CPU FP32', UserWarning)
            fp16 = False

    if fp16:
        interpolator = interpolator.half()
    if export_to_torchscript:
        interpolator = torch.jit.script(interpolator)

    if not skiptest:
        test_model(interpolator, model, fp16, use_gpu)

    if export_to_torchscript:
        interpolator.save(save_path)
    else:
        torch.save(interpolator.state_dict(), save_path)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Export frame-interpolator model to PyTorch state dict')

    parser.add_argument('model_path', type=str, help='Path to the TF SavedModel')
    parser.add_argument('save_path', type=str, help='Path to save the PyTorch state dict')
    parser.add_argument('--statedict', action='store_true', help='Export to state dict instead of TorchScript')
    parser.add_argument('--fp32', action='store_true', help='Save at full precision')
    parser.add_argument('--skiptest', action='store_true', help='Skip verification against the TF model')
    parser.add_argument('--gpu', action='store_true', help='Use GPU')

    args = parser.parse_args()

    main(args.model_path, args.save_path, not args.statedict, args.gpu, not args.fp32, args.skiptest)
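
For reference, a hypothetical direct invocation of main() with the same semantics as the CLI flags above; the SavedModel directory and output file name are placeholders, not part of this commit.

from export import main

main(
    model_path='saved_model',       # directory containing the pretrained TF2 SavedModel
    save_path='film_net_fp16.pt',   # where the exported module is written
    export_to_torchscript=True,     # same as omitting --statedict
    use_gpu=True,                   # same as passing --gpu
    fp16=True,                      # same as omitting --fp32
    skiptest=False,                 # keep the TF-vs-PyTorch verification
)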

feature_extractor.py

Lines changed: 155 additions & 0 deletions
"""PyTorch layer for extracting image features for the film_net interpolator.

The feature extractor implemented here converts an image pyramid into a pyramid
of deep features. The feature pyramid serves a similar purpose as the U-Net
architecture's encoder, but we use a special cascaded architecture described in
Multi-view Image Fusion [1].

For comprehensiveness, below is a short description of the idea. While the
description is a bit involved, the cascaded feature pyramid can be used just
like any image feature pyramid.

Why cascaded architecture?
==========================
To understand the concept it is worth reviewing a traditional feature pyramid
first: *a traditional feature pyramid*, as in U-Net or in many optical flow
networks, is built by alternating between convolutions and pooling, starting
from the input image.

It is well known that early features of such an architecture correspond to low
level concepts such as edges in the image, whereas later layers extract
semantically higher level concepts such as object classes. In other words,
the meaning of the filters at each resolution level is different. For problems
such as semantic segmentation and many others this is a desirable property.

However, the asymmetric features preclude sharing weights across resolution
levels in the feature extractor itself and in any subsequent neural networks
that follow. This can be a downside, since optical flow prediction, for
instance, is symmetric across resolution levels. The cascaded feature
architecture addresses this shortcoming.

How is it built?
================
The *cascaded* feature pyramid contains feature vectors that have constant
length and meaning on each resolution level, except for a few of the finest
ones. The advantage of this is that the subsequent optical flow layer can learn
synergically from many resolutions. This means that coarse level prediction can
benefit from finer resolution training examples, which can be useful with
moderately sized datasets to avoid overfitting.

The cascaded feature pyramid is built by extracting shallower subtree pyramids,
each one of them similar to the traditional architecture. Each subtree
pyramid S_i is extracted starting from each resolution level:

image resolution 0 -> S_0
image resolution 1 -> S_1
image resolution 2 -> S_2
...

If we denote the features at level j of subtree i as S_i_j, the cascaded pyramid
is constructed by concatenating features as follows (assuming subtree depth=3):

lvl
feat_0 = concat( S_0_0 )
feat_1 = concat( S_1_0 S_0_1 )
feat_2 = concat( S_2_0 S_1_1 S_0_2 )
feat_3 = concat( S_3_0 S_2_1 S_1_2 )
feat_4 = concat( S_4_0 S_3_1 S_2_2 )
feat_5 = concat( S_5_0 S_4_1 S_3_2 )
...

In the above, all levels except feat_0 and feat_1 have the same number of
features with similar semantic meaning. This enables training a single optical
flow predictor module shared by levels 2, 3, 4, 5, ... For more details and
evaluation see [1].

[1] Multi-view Image Fusion, Trinidad et al. 2019
"""
from typing import List

import torch
from torch import nn
from torch.nn import functional as F

from util import conv


class SubTreeExtractor(nn.Module):
    """Extracts a hierarchical set of features from an image.

    This is a conventional, hierarchical image feature extractor that extracts
    [k, k*2, k*4, ...] filters for the image pyramid, where k=channels.
    Each level is followed by average pooling.
    """

    def __init__(self, in_channels=3, channels=64, n_layers=4):
        super().__init__()
        convs = []
        for i in range(n_layers):
            convs.append(nn.Sequential(
                conv(in_channels, (channels << i), 3),
                conv((channels << i), (channels << i), 3)
            ))
            in_channels = channels << i
        self.convs = nn.ModuleList(convs)

    def forward(self, image: torch.Tensor, n: int) -> List[torch.Tensor]:
        """Extracts a pyramid of features from the image.

        Args:
            image: torch.Tensor with shape BATCH_SIZE x CHANNELS x HEIGHT x WIDTH.
            n: number of pyramid levels to extract. This can be less than or
                equal to the n_layers given in __init__.
        Returns:
            The pyramid of features, starting from the finest level. Each element
            contains the output after the last convolution on the corresponding
            pyramid level.
        """
        head = image
        pyramid = []
        for i, layer in enumerate(self.convs):
            head = layer(head)
            pyramid.append(head)
            if i < n - 1:
                head = F.avg_pool2d(head, kernel_size=2, stride=2)
        return pyramid


class FeatureExtractor(nn.Module):
    """Extracts features from an image pyramid using a cascaded architecture.
    """

    def __init__(self, in_channels=3, channels=64, sub_levels=4):
        super().__init__()
        self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels)
        self.sub_levels = sub_levels

    def forward(self, image_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
        """Extracts a cascaded feature pyramid.

        Args:
            image_pyramid: Image pyramid as a list, starting from the finest level.
        Returns:
            A pyramid of cascaded features.
        """
        sub_pyramids: List[List[torch.Tensor]] = []
        for i in range(len(image_pyramid)):
            # At each level of the image pyramid, create a sub_pyramid of features
            # with 'sub_levels' pyramid levels, re-using the same SubTreeExtractor.
            # We use the same instance since we want to share the weights.
            #
            # However, we cap the depth of the sub_pyramid so we don't create features
            # that are beyond the coarsest level of the cascaded feature pyramid we
            # want to generate.
            capped_sub_levels = min(len(image_pyramid) - i, self.sub_levels)
            sub_pyramids.append(self.extract_sublevels(image_pyramid[i], capped_sub_levels))
        # Below we generate the cascades of features on each level of the feature
        # pyramid. Assuming sub_levels=3, the layout of the features will be
        # as shown in the example in the file documentation above.
        feature_pyramid: List[torch.Tensor] = []
        for i in range(len(image_pyramid)):
            features = sub_pyramids[i][0]
            for j in range(1, self.sub_levels):
                if j <= i:
                    features = torch.cat([features, sub_pyramids[i - j][j]], dim=1)
            feature_pyramid.append(features)
        return feature_pyramid
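
A usage sketch (not part of this commit), assuming the conv helper in util.py builds a padded convolution that preserves spatial resolution so features from different sub-pyramids can be concatenated:

import torch
from torch.nn import functional as F

from feature_extractor import FeatureExtractor

extractor = FeatureExtractor(in_channels=3, channels=64, sub_levels=4)

# Build a 5-level image pyramid, finest level first, by repeated 2x average pooling.
image = torch.rand(1, 3, 256, 256)
pyramid = [image]
for _ in range(4):
    pyramid.append(F.avg_pool2d(pyramid[-1], kernel_size=2, stride=2))

features = extractor(pyramid)
for level, feat in enumerate(features):
    # With channels=64 and sub_levels=4 the per-level channel counts grow as
    # 64, 192, 448, 960, 960 once the cascade reaches its full depth.
    print(level, feat.shape)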
