Skip to content

Commit

Permalink
Merge pull request #305 from sarlinpe/convert-tf
Browse files Browse the repository at this point in the history
Convert TensorFlow checkpoint to PyTorch
  • Loading branch information
rpautrat authored Sep 26, 2023
2 parents 1742343 + 8b8e9e7 commit e5a006e
Show file tree
Hide file tree
Showing 5 changed files with 678 additions and 6 deletions.
503 changes: 503 additions & 0 deletions convert_to_pytorch.ipynb

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions superpoint/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(self, data={}, n_gpus=1, data_shape=None, **config):
assert r in self.config, 'Required configuration entry: \'{}\''.format(r)
assert set(self.datasets) <= self.dataset_names, \
'Unknown dataset name: {}'.format(set(self.datasets)-self.dataset_names)
assert n_gpus > 0, 'TODO: CPU-only training is currently not supported.'
assert n_gpus >= 0

with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
self._build_graph()
Expand All @@ -137,19 +137,21 @@ def _shard_nested_dict(self, d, num):

def _gpu_tower(self, data, mode, batch_size):
# Split the batch between the GPUs (data parallelism)
n_shards = max(1, self.n_gpus)
device = 'cpu' if self.n_gpus == 0 else 'gpu'
with tf.device('/cpu:0'):
with tf.name_scope('{}_data_sharding'.format(mode)):
shards = self._unstack_nested_dict(data, batch_size*self.n_gpus)
shards = self._shard_nested_dict(shards, self.n_gpus)
shards = self._unstack_nested_dict(data, batch_size*n_shards)
shards = self._shard_nested_dict(shards, n_shards)

# Create towers, i.e. copies of the model for each GPU,
# with their own loss and gradients.
tower_losses = []
tower_gradvars = []
tower_preds = []
tower_metrics = []
for i in range(self.n_gpus):
worker = '/gpu:{}'.format(i)
for i in range(n_shards):
worker = '/{}:{}'.format(device, i)
device_setter = tf.train.replica_device_setter(
worker_device=worker, ps_device='/cpu:0', ps_tasks=1)
with tf.name_scope('{}_tower{}'.format(mode, i)) as scope:
Expand Down
2 changes: 1 addition & 1 deletion superpoint/utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def dict_update(d, u):
The updated dictionary.
"""
for k, v in u.items():
if isinstance(v, collections.Mapping):
if isinstance(v, collections.abc.Mapping):
d[k] = dict_update(d.get(k, {}), v)
else:
d[k] = v
Expand Down
167 changes: 167 additions & 0 deletions superpoint_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""PyTorch implementation of the SuperPoint model,
derived from the TensorFlow re-implementation (2018).
Authors: Rémi Pautrat, Paul-Edouard Sarlin
"""
import torch.nn as nn
import torch
from collections import OrderedDict
from types import SimpleNamespace


def sample_descriptors(keypoints, descriptors, s: int = 8):
    """Bilinearly interpolate dense descriptors at keypoint locations.

    Args:
        keypoints: (b, n, 2) tensor of (x, y) pixel coordinates in the image.
        descriptors: (b, c, h, w) dense descriptor map at stride ``s``.
        s: downsampling stride of the descriptor map w.r.t. the image.

    Returns:
        (b, c, n) tensor of L2-normalized sampled descriptors.
    """
    b, c, h, w = descriptors.shape
    # Convert pixel coordinates to the [-1, 1] range expected by grid_sample
    # (the +0.5 shift samples at pixel centers).
    extent = keypoints.new_tensor([w, h]) * s
    grid = 2 * (keypoints + 0.5) / extent - 1
    sampled = torch.nn.functional.grid_sample(
        descriptors, grid.view(b, 1, -1, 2), mode="bilinear", align_corners=False
    )
    # L2-normalize each descriptor along the channel dimension.
    return torch.nn.functional.normalize(sampled.reshape(b, c, -1), p=2, dim=1)


def batched_nms(scores, nms_radius: int):
    """Fast approximate non-maximum suppression on a batch of score maps.

    A score survives only if it is the maximum within its
    (2*nms_radius + 1)^2 neighborhood; two extra rounds recover maxima that
    were shadowed by an already-suppressed neighbor. Suppressed positions
    are zeroed out.
    """
    assert nms_radius >= 0
    window = nms_radius * 2 + 1

    def local_max(t):
        # Same-size max filter over each score map.
        return torch.nn.functional.max_pool2d(
            t, kernel_size=window, stride=1, padding=nms_radius
        )

    zero = torch.zeros_like(scores)
    keep = scores == local_max(scores)
    for _ in range(2):
        suppressed = local_max(keep.float()) > 0
        remaining = torch.where(suppressed, zero, scores)
        keep = keep | ((remaining == local_max(remaining)) & ~suppressed)
    return torch.where(keep, scores, zero)


def select_top_k_keypoints(keypoints, scores, k):
    """Keep the k highest-scoring keypoints (all of them if fewer than k).

    Returns the selected keypoints and their scores, sorted by
    decreasing score when a selection happens.
    """
    if len(keypoints) <= k:
        return keypoints, scores
    top_scores, top_idx = torch.topk(scores, k, dim=0, sorted=True)
    return keypoints[top_idx], top_scores


class VGGBlock(nn.Sequential):
    """Conv2d -> activation -> BatchNorm block with 'same' padding.

    The activation precedes batch normalization, mirroring the layer order
    of the TensorFlow implementation this model was converted from.
    """

    def __init__(self, c_in, c_out, kernel_size, relu=True):
        modules = OrderedDict()
        modules["conv"] = nn.Conv2d(
            c_in,
            c_out,
            kernel_size=kernel_size,
            stride=1,
            # 'same' padding for odd kernel sizes.
            padding=(kernel_size - 1) // 2,
        )
        modules["activation"] = nn.ReLU(inplace=True) if relu else nn.Identity()
        modules["bn"] = nn.BatchNorm2d(c_out, eps=0.001)
        super().__init__(modules)


class SuperPoint(nn.Module):
    """SuperPoint keypoint detector and descriptor network.

    The shared VGG-style backbone feeds two heads: a detector head that
    predicts, per spatial cell, stride**2 + 1 logits (one per pixel in the
    cell plus a "no keypoint" dustbin), and a descriptor head that predicts
    a dense descriptor map at 1/stride resolution.

    Configuration (override via keyword arguments):
        descriptor_dim: dimensionality of the output descriptors.
        nms_radius: radius of the non-maximum-suppression window.
        max_num_keypoints: keep only the top-k keypoints (None keeps all).
        detection_threshold: minimum score for a keypoint to be kept.
        remove_borders: width (px) of the image border to discard.
        channels: backbone channel widths; the last entry is the head width.
    """

    # Fix: the original dict literal listed "descriptor_dim" twice; a
    # duplicate key in a dict literal silently overwrites the first.
    default_conf = {
        "descriptor_dim": 256,
        "nms_radius": 4,
        "max_num_keypoints": None,
        "detection_threshold": 0.005,
        "remove_borders": 4,
        "channels": [64, 64, 128, 128, 256],
    }

    def __init__(self, **conf):
        super().__init__()
        conf = {**self.default_conf, **conf}
        self.conf = SimpleNamespace(**conf)
        # One 2x max-pool between each pair of backbone stages except the last.
        self.stride = 2 ** (len(self.conf.channels) - 2)
        channels = [1, *self.conf.channels[:-1]]

        backbone = []
        for i, c in enumerate(channels[1:], 1):
            layers = [VGGBlock(channels[i - 1], c, 3), VGGBlock(c, c, 3)]
            if i < len(channels) - 1:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            backbone.append(nn.Sequential(*layers))
        self.backbone = nn.Sequential(*backbone)

        c = self.conf.channels[-1]
        # Detector head: stride**2 pixel bins per cell + 1 dustbin channel.
        self.detector = nn.Sequential(
            VGGBlock(channels[-1], c, 3),
            VGGBlock(c, self.stride**2 + 1, 1, relu=False),
        )
        self.descriptor = nn.Sequential(
            VGGBlock(channels[-1], c, 3),
            VGGBlock(c, self.conf.descriptor_dim, 1, relu=False),
        )

    def forward(self, data):
        """Detect keypoints and compute their descriptors.

        Args:
            data: dict with key "image", a (b, 1, H, W) or (b, 3, H, W)
                float tensor; RGB input is converted to grayscale.

        Returns:
            dict with per-image lists (length b):
                "keypoints": (n_i, 2) float tensors of (x, y) coordinates,
                "keypoint_scores": (n_i,) detection scores,
                "descriptors": (n_i, descriptor_dim) L2-normalized descriptors.
        """
        image = data["image"]
        if image.shape[1] == 3:  # RGB to gray using ITU-R BT.601 weights
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True)

        features = self.backbone(image)
        descriptors_dense = torch.nn.functional.normalize(
            self.descriptor(features), p=2, dim=1
        )

        # Decode the detection scores: softmax over the stride**2 + 1
        # channels, drop the dustbin, and unfold cells back to pixels.
        scores = self.detector(features)
        scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
        b, _, h, w = scores.shape
        scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, self.stride, self.stride)
        scores = scores.permute(0, 1, 3, 2, 4).reshape(
            b, h * self.stride, w * self.stride
        )
        scores = batched_nms(scores, self.conf.nms_radius)

        # Discard keypoints near the image borders
        if self.conf.remove_borders:
            pad = self.conf.remove_borders
            scores[:, :pad] = -1
            scores[:, :, :pad] = -1
            scores[:, -pad:] = -1
            scores[:, :, -pad:] = -1

        # Extract keypoints above the detection threshold.
        if b > 1:
            idxs = torch.where(scores > self.conf.detection_threshold)
            # mask[i] selects the detections belonging to batch element i.
            mask = idxs[0] == torch.arange(b, device=scores.device)[:, None]
        else:  # Faster shortcut
            scores = scores.squeeze(0)
            idxs = torch.where(scores > self.conf.detection_threshold)

        # Convert (i, j) row/column indices to (x, y) coordinates.
        keypoints_all = torch.stack(idxs[-2:], dim=-1).flip(1).float()
        scores_all = scores[idxs]

        keypoints = []
        keypoint_scores = []
        descriptors = []
        for i in range(b):
            if b > 1:
                k = keypoints_all[mask[i]]
                s = scores_all[mask[i]]
            else:
                k = keypoints_all
                s = scores_all
            if self.conf.max_num_keypoints is not None:
                k, s = select_top_k_keypoints(k, s, self.conf.max_num_keypoints)
            d = sample_descriptors(k[None], descriptors_dense[i, None], self.stride)
            keypoints.append(k)
            keypoint_scores.append(s)
            descriptors.append(d.squeeze(0).transpose(0, 1))

        return {
            "keypoints": keypoints,
            "keypoint_scores": keypoint_scores,
            "descriptors": descriptors,
        }
Binary file added weights/superpoint_v6_from_tf.pth
Binary file not shown.

0 comments on commit e5a006e

Please sign in to comment.