-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
490 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
model_path: 'pretrained/aslfeat/model.ckpt-380000' | ||
hseq: | ||
root: "/local/hpatches-sequences-release" | ||
seq: ['v', 'i'] | ||
start_idx: 1 | ||
max_dim: 2000 | ||
ignored_v: ['v_artisans', 'v_astronautis', 'v_talent'] | ||
ignored_i: ['i_contruction', 'i_crownnight', 'i_dc', 'i_pencils', 'i_whitebuilding'] | ||
eval: | ||
err_thld: 3 | ||
net: | ||
max_dim: 2000 | ||
config: | ||
kpt_n: 5000 | ||
kpt_refinement: true | ||
deform_desc: 1 | ||
score_thld: 0.5 | ||
edge_thld: 10 | ||
multi_scale: false | ||
multi_level: true | ||
nms_size: 3 | ||
eof_mask: 5 | ||
need_norm: true | ||
use_peakiness: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
data_name: 'tat' | ||
data_list: ['Ignatius', 'Barn', 'Caterpillar', 'Church', 'Courthouse', 'Meetingroom', 'Truck'] | ||
data_root: '/local/Tanks_and_Temples' | ||
dump_root: '/local/Tanks_and_Temples' | ||
truncate: [0, null] | ||
model_path: 'pretrained/aslfeat/model.ckpt-380000' | ||
overwrite: true | ||
net: | ||
max_dim: 2048 | ||
config: | ||
kpt_n: 10000 | ||
kpt_refinement: true | ||
deform_desc: 1 | ||
score_thld: 0.5 | ||
edge_thld: 10 | ||
multi_scale: true | ||
multi_level: true | ||
nms_size: 3 | ||
eof_mask: 5 | ||
need_norm: true | ||
use_peakiness: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import os | ||
import tensorflow as tf | ||
import numpy as np | ||
from struct import pack | ||
|
||
from utils.common import Notify | ||
from .base_dataset import BaseDataset | ||
|
||
|
||
def write_feature_repo(kpt_coord, desc, output_path): | ||
assert desc.dtype == np.uint8 | ||
if kpt_coord.size > 1: | ||
num_features = kpt_coord.shape[0] | ||
else: | ||
num_features = 0 | ||
|
||
feature_name = 1413892435 | ||
loc_dim = 5 | ||
des_dim = 128 | ||
|
||
head = np.stack((feature_name, 0, num_features, loc_dim, des_dim)) | ||
head = pack('5i', *head) | ||
|
||
if num_features > 0: | ||
zero_pad = np.ones((num_features, 3)) * 2 | ||
kpt_coord = np.concatenate((kpt_coord, zero_pad), axis=-1).astype(np.float32) | ||
kpt_coord = pack('f' * loc_dim * num_features, *(kpt_coord.flatten())) | ||
|
||
desc = pack('B' * des_dim * num_features, *(desc.flatten())) | ||
|
||
level_num = 1 | ||
per_level_num = num_features | ||
|
||
level_pack = np.stack((level_num, per_level_num)) | ||
level_num = pack('2i', *level_pack) | ||
|
||
with open(output_path, 'wb') as fout: | ||
fout.write(head) | ||
if num_features > 0: | ||
fout.write(kpt_coord) | ||
fout.write(desc) | ||
fout.write(level_num) | ||
|
||
|
||
class Tat(BaseDataset): | ||
default_config = { | ||
'num_parallel_calls': 10, 'truncate': None | ||
} | ||
|
||
def _init_dataset(self, **config): | ||
print(Notify.INFO, "Initializing dataset:", config['data_name'], Notify.ENDC) | ||
base_path = config['data_root'] | ||
|
||
img_paths = [] | ||
dump_paths = [] | ||
|
||
for d in config['data_list']: | ||
image_list = os.path.join(base_path, d, 'output', 'preprocess', 'image_list.txt') | ||
dfeat_folder = os.path.join(base_path, d, 'output', 'preprocess', 'dfeat') | ||
if not os.path.exists(dfeat_folder): | ||
os.mkdir(dfeat_folder) | ||
if not os.path.exists(image_list): | ||
exit(-1) | ||
tmp_img_paths = open(image_list).read().splitlines() | ||
img_paths.extend(tmp_img_paths) | ||
for idx, _ in enumerate(tmp_img_paths): | ||
basename = str(idx).strip().zfill(8) + '.dfeat' | ||
dump_paths.append(os.path.join(dfeat_folder, basename)) | ||
|
||
tf.data.Dataset.map_parallel = lambda self, fn: self.map( | ||
fn, num_parallel_calls=config['num_parallel_calls']) | ||
|
||
if config['truncate'] is not None: | ||
print(Notify.WARNING, "Truncate from", | ||
config['truncate'][0], "to", config['truncate'][1], Notify.ENDC) | ||
img_paths = img_paths[config['truncate'][0]:config['truncate'][1]] | ||
dump_paths = dump_paths[config['truncate'][0]:config['truncate'][1]] | ||
|
||
self.data_length = len(img_paths) | ||
|
||
files = {'image_paths': img_paths, 'dump_paths': dump_paths} | ||
return files | ||
|
||
def _format_data(self, data): | ||
dump_path = data['dump_path'].decode('utf-8') | ||
desc = data['dump_data'][0] | ||
desc = (desc * 128 + 128).astype(np.uint8) | ||
kpt = data['dump_data'][1] | ||
write_feature_repo(kpt, desc, dump_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
#!/usr/bin/env python3 | ||
""" | ||
Copyright 2017, Zixin Luo, HKUST. | ||
Inference script. | ||
""" | ||
|
||
import os | ||
|
||
from queue import Queue | ||
from threading import Thread | ||
|
||
import math | ||
import yaml | ||
|
||
import cv2 | ||
import numpy as np | ||
|
||
import tensorflow as tf | ||
|
||
from models import get_model | ||
from utils.hseq_utils import HSeqUtils | ||
from utils.evaluator import Evaluator | ||
|
||
FLAGS = tf.compat.v1.app.flags.FLAGS | ||
|
||
# general config. | ||
tf.compat.v1.app.flags.DEFINE_string('config', None, """Path to the configuration file.""") | ||
|
||
|
||
def loader(hseq_utils, producer_queue): | ||
for seq_idx in range(hseq_utils.seq_num): | ||
seq_name, hseq_data = hseq_utils.get_data(seq_idx) | ||
|
||
for i in range(6): | ||
gt_homo = [seq_idx, seq_name, hseq_data.scaling] if i == 0 else hseq_data.homo[i] | ||
producer_queue.put([hseq_data.img[i], gt_homo]) | ||
producer_queue.put(None) | ||
|
||
def extractor(patch_queue, model, consumer_queue): | ||
while True: | ||
queue_data = patch_queue.get() | ||
if queue_data is None: | ||
consumer_queue.put(None) | ||
return | ||
img, gt_homo = queue_data | ||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) | ||
H, W = gray.shape | ||
descs, kpts, _ = model.run_test_data(np.expand_dims(gray, axis=-1)) | ||
consumer_queue.put([img, kpts, descs, gt_homo]) | ||
patch_queue.task_done() | ||
|
||
def matcher(consumer_queue, sess, evaluator, config): | ||
record = [] | ||
while True: | ||
queue_data = consumer_queue.get() | ||
if queue_data is None: | ||
return | ||
record.append(queue_data) | ||
if len(record) < 6: | ||
continue | ||
ref_img, ref_kpts, ref_descs, seq_info = record[0] | ||
|
||
eval_stats = np.array((0, 0, 0, 0, 0, 0, 0), np.float32) | ||
|
||
seq_idx = seq_info[0] | ||
seq_name = seq_info[1] | ||
scaling = seq_info[2] | ||
print(seq_idx, seq_name) | ||
|
||
for i in range(1, 6): | ||
test_img, test_kpts, test_descs, gt_homo = record[i] | ||
# get MMA | ||
num_feat = min(ref_kpts.shape[0], test_kpts.shape[0]) | ||
if num_feat > 0: | ||
mma_putative_matches = evaluator.feature_matcher( | ||
sess, ref_descs, test_descs, test_kpts) | ||
else: | ||
mma_putative_matches = [] | ||
mma_inlier_matches = evaluator.get_inlier_matches( | ||
ref_kpts, test_kpts, mma_putative_matches, gt_homo, scaling) | ||
num_mma_putative = len(mma_putative_matches) | ||
num_mma_inlier = len(mma_inlier_matches) | ||
# get covisible keypoints | ||
ref_mask, test_mask = evaluator.get_covisible_mask(ref_kpts, test_kpts, | ||
ref_img.shape, test_img.shape, | ||
gt_homo) | ||
cov_ref_coord, cov_test_coord = ref_kpts[ref_mask], test_kpts[test_mask] | ||
cov_ref_feat, cov_test_feat = ref_descs[ref_mask], test_descs[test_mask] | ||
num_cov_feat = min(cov_ref_coord.shape[0], cov_test_coord.shape[0]) | ||
# get gt matches | ||
gt_num = evaluator.get_gt_matches(cov_ref_coord, cov_test_coord, gt_homo, scaling) | ||
# establish putative matches | ||
if num_cov_feat > 0: | ||
putative_matches = evaluator.feature_matcher( | ||
sess, cov_ref_feat, cov_test_feat, cov_test_coord) | ||
else: | ||
putative_matches = [] | ||
num_putative = max(len(putative_matches), 1) | ||
# get inlier matches | ||
inlier_matches = evaluator.get_inlier_matches( | ||
cov_ref_coord, cov_test_coord, putative_matches, gt_homo, scaling) | ||
num_inlier = len(inlier_matches) | ||
|
||
eval_stats += np.array((1, # counter | ||
num_feat, # feature number | ||
gt_num / max(num_cov_feat, 1), # repeatability | ||
num_inlier / max(num_putative, 1), # precision | ||
num_inlier / max(num_cov_feat, 1), # matching score | ||
num_inlier / max(gt_num, 1), # recall | ||
num_mma_inlier / max(num_mma_putative, 1))) / 5 # MMA | ||
|
||
print(int(eval_stats[1]), eval_stats[2:]) | ||
evaluator.stats['all_eval_stats'] += eval_stats | ||
if os.path.basename(seq_name)[0] == 'i': | ||
evaluator.stats['i_eval_stats'] += eval_stats | ||
if os.path.basename(seq_name)[0] == 'v': | ||
evaluator.stats['v_eval_stats'] += eval_stats | ||
|
||
exit() | ||
record = [] | ||
|
||
def hseq_eval(): | ||
with open(FLAGS.config, 'r') as f: | ||
test_config = yaml.load(f, Loader=yaml.FullLoader) | ||
# Configure dataset | ||
hseq_utils = HSeqUtils(test_config['hseq']) | ||
# Configure evaluation | ||
evaluator = Evaluator(test_config['eval']) | ||
# Construct inference networks. | ||
model = get_model('feat_model')(test_config['model_path'], **(test_config['net'])) | ||
# Create the initializier. | ||
config = tf.compat.v1.ConfigProto() | ||
config.gpu_options.allow_growth = True | ||
|
||
producer_queue = Queue(maxsize=18) | ||
consumer_queue = Queue() | ||
|
||
producer0 = Thread(target=loader, args=(hseq_utils, producer_queue)) | ||
producer0.daemon = True | ||
producer0.start() | ||
|
||
producer1 = Thread(target=extractor, args=(producer_queue, model, consumer_queue)) | ||
producer1.daemon = True | ||
producer1.start() | ||
|
||
consumer = Thread(target=matcher, args=(consumer_queue, model.sess, evaluator, test_config['eval'])) | ||
consumer.daemon = True | ||
consumer.start() | ||
|
||
producer0.join() | ||
producer1.join() | ||
consumer.join() | ||
|
||
evaluator.print_stats('i_eval_stats') | ||
evaluator.print_stats('v_eval_stats') | ||
evaluator.print_stats('all_eval_stats') | ||
|
||
if __name__ == '__main__': | ||
tf.compat.v1.flags.mark_flags_as_required(['config']) | ||
hseq_eval() |
Oops, something went wrong.