
Commit

Initial commit
dajes committed Jan 24, 2023
1 parent 76f4670 commit d882239
Showing 12 changed files with 1,016 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
/photos/output.mp4
10 changes: 8 additions & 2 deletions README.md
@@ -1,2 +1,8 @@
# frame-interpolation-pytorch
PyTorch implementation of FILM: Frame Interpolation for Large Motion, In ECCV 2022.
# Frame interpolation in PyTorch

This is a PyTorch inference implementation
of [FILM: Frame Interpolation for Large Motion, In ECCV 2022](https://film-net.github.io/).\
[Original repository link](https://github.com/google-research/frame-interpolation)

The project is focused on creating a simple, TorchScript-compilable inference interface for the original pretrained TF2
model.
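
A minimal usage sketch for an exported TorchScript module, assuming the model was exported with `--fp32` and saved as `film_net_fp32.pt` (the file name and input sizes are illustrative; the call signature follows the test in `export.py`):

```python
import torch

# Load the TorchScript module produced by export.py (hypothetical file name).
interpolator = torch.jit.load('film_net_fp32.pt')
interpolator.eval()

# Two frames as B x 3 x H x W float tensors and the interpolation time step in (0, 1).
x0 = torch.rand(1, 3, 256, 256)
x1 = torch.rand(1, 3, 256, 256)
time = torch.full((1, 1), 0.5)

with torch.no_grad():
    middle_frame = interpolator(x0, x1, time)  # B x 3 x H x W

print(middle_frame.shape)
```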
154 changes: 154 additions & 0 deletions export.py
@@ -0,0 +1,154 @@
import warnings

import numpy as np
import tensorflow as tf
import torch

from interpolator import Interpolator


def translate_state_dict(var_dict, state_dict):
for name, (prev_name, weight) in zip(state_dict, var_dict.items()):
print('Mapping', prev_name, '->', name)
weight = torch.from_numpy(weight)
if 'kernel' in prev_name:
            # Transpose conv2d kernel weights: TF stores them as (H, W, C_in, C_out), PyTorch expects (C_out, C_in, H, W)
weight = weight.permute(3, 2, 0, 1)

assert state_dict[name].shape == weight.shape, f'Shape mismatch {state_dict[name].shape} != {weight.shape}'

state_dict[name] = weight


def import_state_dict(interpolator: Interpolator, saved_model):
variables = saved_model.keras_api.variables

extract_dict = interpolator.extract.state_dict()
flow_dict = interpolator.predict_flow.state_dict()
fuse_dict = interpolator.fuse.state_dict()

extract_vars = {}
_flow_vars = {}
_fuse_vars = {}

for var in variables:
name = var.name
if name.startswith('feat_net'):
extract_vars[name[9:]] = var.numpy()
elif name.startswith('predict_flow'):
_flow_vars[name[13:]] = var.numpy()
elif name.startswith('fusion'):
_fuse_vars[name[7:]] = var.numpy()

# reverse order of modules to allow jit export
# TODO: improve this hack
flow_vars = dict(sorted(_flow_vars.items(), key=lambda x: x[0].split('/')[0], reverse=True))
fuse_vars = dict(sorted(_fuse_vars.items(), key=lambda x: int((x[0].split('/')[0].split('_')[1:] or [0])[0]) // 3, reverse=True))

assert len(extract_vars) == len(extract_dict), f'{len(extract_vars)} != {len(extract_dict)}'
assert len(flow_vars) == len(flow_dict), f'{len(flow_vars)} != {len(flow_dict)}'
assert len(fuse_vars) == len(fuse_dict), f'{len(fuse_vars)} != {len(fuse_dict)}'

for state_dict, var_dict in ((extract_dict, extract_vars), (flow_dict, flow_vars), (fuse_dict, fuse_vars)):
translate_state_dict(var_dict, state_dict)

interpolator.extract.load_state_dict(extract_dict)
interpolator.predict_flow.load_state_dict(flow_dict)
interpolator.fuse.load_state_dict(fuse_dict)


def verify_debug_outputs(pt_outputs, tf_outputs):
max_error = 0
for name, predicted in pt_outputs.items():
if name == 'image':
continue
pred_frfp = [f.permute(0, 2, 3, 1).detach().cpu().numpy() for f in predicted]
true_frfp = [f.numpy() for f in tf_outputs[name]]

for i, (pred, true) in enumerate(zip(pred_frfp, true_frfp)):
assert pred.shape == true.shape, f'{name} {i} shape mismatch {pred.shape} != {true.shape}'
error = np.max(np.abs(pred - true))
max_error = max(max_error, error)
assert error < 1, f'{name} {i} max error: {error}'
print('Max intermediate error:', max_error)


def test_model(interpolator, model, half=False, gpu=False):
torch.manual_seed(0)
time = torch.full((1, 1), .5)
x0 = torch.rand(1, 3, 256, 256)
x1 = torch.rand(1, 3, 256, 256)

x0_ = tf.convert_to_tensor(x0.permute(0, 2, 3, 1).numpy(), dtype=tf.float32)
x1_ = tf.convert_to_tensor(x1.permute(0, 2, 3, 1).numpy(), dtype=tf.float32)
time_ = tf.convert_to_tensor(time.numpy(), dtype=tf.float32)
tf_outputs = model({'x0': x0_, 'x1': x1_, 'time': time_}, training=False)

if half:
x0 = x0.half()
x1 = x1.half()
time = time.half()

if gpu and torch.cuda.is_available():
x0 = x0.cuda()
x1 = x1.cuda()
time = time.cuda()

with torch.no_grad():
pt_outputs = interpolator.debug_forward(x0, x1, time)

verify_debug_outputs(pt_outputs, tf_outputs)

with torch.no_grad():
prediction = interpolator(x0, x1, time)
output_color = prediction.permute(0, 2, 3, 1).detach().cpu().numpy()
true_color = tf_outputs['image'].numpy()
error = np.abs(output_color - true_color).max()

print('Color max error:', error)


def main(model_path, save_path, export_to_torchscript=True, use_gpu=False, fp16=True, skiptest=False):
print(f'Exporting model to FP{["32", "16"][fp16]} {["state_dict", "torchscript"][export_to_torchscript]} '
f'using {"CG"[use_gpu]}PU')
model = tf.compat.v2.saved_model.load(model_path)
interpolator = Interpolator()
interpolator.eval()
import_state_dict(interpolator, model)

if use_gpu and torch.cuda.is_available():
interpolator = interpolator.cuda()
else:
if fp16 and use_gpu:
warnings.warn('No GPU is available, using CPU FP32', UserWarning)
fp16 = False

if fp16:
interpolator = interpolator.half()
if export_to_torchscript:
interpolator = torch.jit.script(interpolator)

if not skiptest:
test_model(interpolator, model, fp16, use_gpu)

if export_to_torchscript:
interpolator.save(save_path)
else:
torch.save(interpolator.state_dict(), save_path)


if __name__ == '__main__':
import argparse

    parser = argparse.ArgumentParser(description='Export the FILM frame interpolation model to PyTorch (TorchScript module or state dict)')

parser.add_argument('model_path', type=str, help='Path to the TF SavedModel')
parser.add_argument('save_path', type=str, help='Path to save the PyTorch state dict')
parser.add_argument('--statedict', action='store_true', help='Export to state dict instead of TorchScript')
parser.add_argument('--fp32', action='store_true', help='Save at full precision')
    parser.add_argument('--skiptest', action='store_true', help='Skip verifying the exported model against the TF model')
parser.add_argument('--gpu', action='store_true', help='Use GPU')

args = parser.parse_args()

main(args.model_path, args.save_path, not args.statedict, args.gpu, not args.fp32, args.skiptest)
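
For reference, a minimal sketch of driving the exporter from Python rather than the CLI; the SavedModel path and output name are illustrative, and the keyword values mirror the script's defaults (TorchScript output, FP16 only when a GPU is actually used):

```python
from export import main

# Placeholder paths: the SavedModel comes from the original TF2 repository.
main('path/to/saved_model', 'film_net_fp16.pt',
     export_to_torchscript=True, use_gpu=True, fp16=True, skiptest=False)
```

Given the argparse definitions above, the equivalent command line would be `python export.py path/to/saved_model film_net_fp16.pt --gpu`.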
155 changes: 155 additions & 0 deletions feature_extractor.py
@@ -0,0 +1,155 @@
"""PyTorch layer for extracting image features for the film_net interpolator.
The feature extractor implemented here converts an image pyramid into a pyramid
of deep features. The feature pyramid serves a similar purpose as U-Net
architecture's encoder, but we use a special cascaded architecture described in
Multi-view Image Fusion [1].
For comprehensiveness, below is a short description of the idea. While the
description is a bit involved, the cascaded feature pyramid can be used just
like any image feature pyramid.
Why cascaded architecture?
=========================
To understand the concept it is worth reviewing a traditional feature pyramid
first: *A traditional feature pyramid* as in U-net or in many optical flow
networks is built by alternating between convolutions and pooling, starting
from the input image.
It is well known that early features of such architecture correspond to low
level concepts such as edges in the image whereas later layers extract
semantically higher level concepts such as object classes etc. In other words,
the meaning of the filters in each resolution level is different. For problems
such as semantic segmentation and many others this is a desirable property.
However, the asymmetric features preclude sharing weights across resolution
levels in the feature extractor itself and in any subsequent neural networks
that follow. This can be a downside, since optical flow prediction, for
instance, is symmetric across resolution levels. The cascaded feature
architecture addresses this shortcoming.
How is it built?
================
The *cascaded* feature pyramid contains feature vectors that have constant
length and meaning on each resolution level, except a few of the finest ones. The
advantage of this is that the subsequent optical flow layer can learn
synergistically from many resolutions. This means that coarse-level prediction can
benefit from finer resolution training examples, which can be useful with
moderately sized datasets to avoid overfitting.
The cascaded feature pyramid is built by extracting shallower subtree pyramids,
each one of them similar to the traditional architecture. Each subtree
pyramid S_i is extracted starting from each resolution level:
image resolution 0 -> S_0
image resolution 1 -> S_1
image resolution 2 -> S_2
...
If we denote the features at level j of subtree i as S_i_j, the cascaded pyramid
is constructed by concatenating features as follows (assuming subtree depth=3):
lvl
feat_0 = concat( S_0_0 )
feat_1 = concat( S_1_0 S_0_1 )
feat_2 = concat( S_2_0 S_1_1 S_0_2 )
feat_3 = concat( S_3_0 S_2_1 S_1_2 )
feat_4 = concat( S_4_0 S_3_1 S_2_2 )
feat_5 = concat( S_5_0 S_4_1 S_3_2 )
....
In the above, all levels except feat_0 and feat_1 have the same number of features
with similar semantic meaning. This enables training a single optical flow
predictor module shared by levels 2,3,4,5... . For more details and evaluation
see [1].
[1] Multi-view Image Fusion, Trinidad et al. 2019
"""
from typing import List

import torch
from torch import nn
from torch.nn import functional as F
from util import conv


class SubTreeExtractor(nn.Module):
"""Extracts a hierarchical set of features from an image.
    This is a conventional, hierarchical image feature extractor that extracts
    [k, k*2, k*4, ...] filters for the image pyramid, where k=channels.
Each level is followed by average pooling.
"""

def __init__(self, in_channels=3, channels=64, n_layers=4):
super().__init__()
convs = []
for i in range(n_layers):
convs.append(nn.Sequential(
conv(in_channels, (channels << i), 3),
conv((channels << i), (channels << i), 3)
))
in_channels = channels << i
self.convs = nn.ModuleList(convs)

def forward(self, image: torch.Tensor, n: int) -> List[torch.Tensor]:
"""Extracts a pyramid of features from the image.
Args:
          image: torch.Tensor with shape BATCH_SIZE x CHANNELS x HEIGHT x WIDTH.
          n: number of pyramid levels to extract. This can be less than or equal to
            n_layers given in __init__.
Returns:
The pyramid of features, starting from the finest level. Each element
contains the output after the last convolution on the corresponding
pyramid level.
"""
head = image
pyramid = []
for i, layer in enumerate(self.convs):
head = layer(head)
pyramid.append(head)
if i < n - 1:
head = F.avg_pool2d(head, kernel_size=2, stride=2)
return pyramid


class FeatureExtractor(nn.Module):
"""Extracts features from an image pyramid using a cascaded architecture.
"""

def __init__(self, in_channels=3, channels=64, sub_levels=4):
super().__init__()
self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels)
self.sub_levels = sub_levels

def forward(self, image_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
"""Extracts a cascaded feature pyramid.
Args:
image_pyramid: Image pyramid as a list, starting from the finest level.
Returns:
A pyramid of cascaded features.
"""
sub_pyramids: List[List[torch.Tensor]] = []
for i in range(len(image_pyramid)):
# At each level of the image pyramid, creates a sub_pyramid of features
# with 'sub_levels' pyramid levels, re-using the same SubTreeExtractor.
# We use the same instance since we want to share the weights.
#
# However, we cap the depth of the sub_pyramid so we don't create features
# that are beyond the coarsest level of the cascaded feature pyramid we
# want to generate.
capped_sub_levels = min(len(image_pyramid) - i, self.sub_levels)
sub_pyramids.append(self.extract_sublevels(image_pyramid[i], capped_sub_levels))
# Below we generate the cascades of features on each level of the feature
        # pyramid. Assuming sub_levels=3, the layout of the features will be
        # as shown in the example in the file docstring above.
feature_pyramid: List[torch.Tensor] = []
for i in range(len(image_pyramid)):
features = sub_pyramids[i][0]
for j in range(1, self.sub_levels):
if j <= i:
features = torch.cat([features, sub_pyramids[i - j][j]], dim=1)
feature_pyramid.append(features)
return feature_pyramid
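
A small shape-check sketch for the cascaded extractor, assuming the commit's `util.py` (which provides `conv`) is importable; the pyramid depth and image size are illustrative:

```python
import torch
from torch.nn import functional as F

from feature_extractor import FeatureExtractor

# Build a 4-level image pyramid by repeated 2x average pooling, finest level first.
image = torch.rand(1, 3, 256, 256)
pyramid = [image]
for _ in range(3):
    pyramid.append(F.avg_pool2d(pyramid[-1], kernel_size=2, stride=2))

extractor = FeatureExtractor(in_channels=3, channels=64, sub_levels=4)
with torch.no_grad():
    features = extractor(pyramid)

# Level i concatenates min(i + 1, sub_levels) sub-tree feature maps, so channel
# counts grow towards the coarse end until sub_levels maps are reached.
for level, feat in enumerate(features):
    print(level, feat.shape)
```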