diff --git a/code/__pycache__/globalvar.cpython-37.pyc b/code/__pycache__/globalvar.cpython-37.pyc
new file mode 100644
index 0000000..e33dd9f
Binary files /dev/null and b/code/__pycache__/globalvar.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/__init__.cpython-37.pyc b/code/api_controlers/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..4db31ff
Binary files /dev/null and b/code/api_controlers/__pycache__/__init__.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/apis.cpython-37.pyc b/code/api_controlers/__pycache__/apis.cpython-37.pyc
new file mode 100644
index 0000000..9fdf7e3
Binary files /dev/null and b/code/api_controlers/__pycache__/apis.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/base_function.cpython-37.pyc b/code/api_controlers/__pycache__/base_function.cpython-37.pyc
new file mode 100644
index 0000000..af56423
Binary files /dev/null and b/code/api_controlers/__pycache__/base_function.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/delete_encode_ctl.cpython-37.pyc b/code/api_controlers/__pycache__/delete_encode_ctl.cpython-37.pyc
new file mode 100644
index 0000000..bf04c05
Binary files /dev/null and b/code/api_controlers/__pycache__/delete_encode_ctl.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/image_encode_ctl.cpython-37.pyc b/code/api_controlers/__pycache__/image_encode_ctl.cpython-37.pyc
new file mode 100644
index 0000000..b00822c
Binary files /dev/null and b/code/api_controlers/__pycache__/image_encode_ctl.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/image_search_ctl.cpython-37.pyc b/code/api_controlers/__pycache__/image_search_ctl.cpython-37.pyc
new file mode 100644
index 0000000..074d5c8
Binary files /dev/null and b/code/api_controlers/__pycache__/image_search_ctl.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/model_init_ctl.cpython-37.pyc b/code/api_controlers/__pycache__/model_init_ctl.cpython-37.pyc
new file mode 100644
index 0000000..b7e257c
Binary files /dev/null and b/code/api_controlers/__pycache__/model_init_ctl.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/semantic_localization_ctl.cpython-37.pyc b/code/api_controlers/__pycache__/semantic_localization_ctl.cpython-37.pyc
new file mode 100644
index 0000000..a4b72b2
Binary files /dev/null and b/code/api_controlers/__pycache__/semantic_localization_ctl.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/text_search_ctl.cpython-37.pyc b/code/api_controlers/__pycache__/text_search_ctl.cpython-37.pyc
new file mode 100644
index 0000000..b50f6ed
Binary files /dev/null and b/code/api_controlers/__pycache__/text_search_ctl.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/utils.cpython-37.pyc b/code/api_controlers/__pycache__/utils.cpython-37.pyc
new file mode 100644
index 0000000..ddcca41
Binary files /dev/null and b/code/api_controlers/__pycache__/utils.cpython-37.pyc differ
diff --git a/code/api_controlers/apis.py b/code/api_controlers/apis.py
index 4343c77..b361819 100644
--- a/code/api_controlers/apis.py
+++ b/code/api_controlers/apis.py
@@ -1,10 +1,9 @@
from flask import Flask, request, jsonify
import json, os
-import _thread
from threading import Timer
from api_controlers import base_function, image_encode_ctl, delete_encode_ctl,\
- crossmodal_search_ctl, image_search_ctl, semantic_localization_ctl, utils
+ text_search_ctl, image_search_ctl, semantic_localization_ctl, utils
def api_run(cfg):
    app = Flask(__name__) # Flask initialization
@@ -28,7 +27,7 @@ def delete_encode():
@app.route(cfg['apis']['text_search']['route'], methods=['post'])
def text_search():
request_data = json.loads(request.data.decode('utf-8'))
- return_json = crossmodal_search_ctl.text_search(request_data)
+ return_json = text_search_ctl.text_search(request_data)
return return_json
    # image search
diff --git a/code/api_controlers/delete_encode_ctl.py b/code/api_controlers/delete_encode_ctl.py
index e1ed738..bfadf50 100644
--- a/code/api_controlers/delete_encode_ctl.py
+++ b/code/api_controlers/delete_encode_ctl.py
@@ -15,6 +15,8 @@
def delete_encode(request_data):
logger.info("\nRequest json: {}".format(request_data))
+
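+    # "deleteID" is a comma-separated string of image ids, e.g. {"deleteID": "3,4"}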
+ request_data = [int(i) for i in request_data['deleteID'].split(",")]
if request_data != []:
        # delete the given entries
for k in request_data:
@@ -22,7 +24,7 @@ def delete_encode(request_data):
if k not in rsd.keys():
return utils.get_stand_return(False, "Key {} not found in encode pool.".format(k))
else:
- rsd = utils.dict_delete(k, rsd)
+ rsd = utils.dict_delete(int(k), rsd)
globalvar.set_value("rsd", rsd)
utils.dict_save(rsd, cfg['data_paths']['rsd_path'])
diff --git a/code/api_controlers/image_encode_ctl.py b/code/api_controlers/image_encode_ctl.py
index db841b8..593e76f 100644
--- a/code/api_controlers/image_encode_ctl.py
+++ b/code/api_controlers/image_encode_ctl.py
@@ -17,9 +17,9 @@ def image_encode_append(request_data):
logger.info("Request json: {}".format(request_data))
    # add items to the un-encoded pool
- for k,v in request_data.items():
+ for item in request_data:
unembeded_images = globalvar.get_value("unembeded_images")
- unembeded_images = utils.dict_insert(k, v, unembeded_images)
+ unembeded_images = utils.dict_insert(int(item["image_id"]), item, unembeded_images)
globalvar.set_value("unembeded_images", value=unembeded_images)
logger.info("Request append successfully for above request.\n")
@@ -33,13 +33,15 @@ def image_encode_runner():
if unembeded_images != {}:
logger.info("{} images in unembeded image pool have been detected ...".format(len(unembeded_images.keys())))
for img_id in list(unembeded_images.keys()):
- img_path = unembeded_images[img_id]
+ img_path = unembeded_images[img_id]["image_path"]
image_vector = base_function.image_encoder_api(model, img_path)
            # update the rsd data
rsd = globalvar.get_value("rsd")
- rsd = utils.dict_insert(img_id, image_vector, rsd)
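+            # keep the full request item (image_id, image_path, user_id, privilege) together
+            # with its vector so that search can filter on user_id / privilege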
+ unembeded_images[img_id]["image_vector"] = image_vector
+ rsd = utils.dict_insert(img_id, unembeded_images[img_id], rsd)
+
globalvar.set_value("rsd", value=rsd)
            # remove it from the un-encoded pool
diff --git a/code/api_controlers/image_search_ctl.py b/code/api_controlers/image_search_ctl.py
index 026833b..28ba91d 100644
--- a/code/api_controlers/image_search_ctl.py
+++ b/code/api_controlers/image_search_ctl.py
@@ -15,35 +15,44 @@ def image_search(request_data):
    # check that the request is complete
if not isinstance(request_data, dict):
- return utils.get_stand_return(False, "Request must be dicts, and have keys: image_path, retrieved_ids, start, end.")
+ return utils.get_stand_return(False, "Request must be dicts, and have keys: image_path, user_id, page_no, page_size.")
if 'image_path' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: str image_path.")
- if 'retrieved_ids' not in request_data.keys():
- return utils.get_stand_return(False, "Request must have keys: list retrieved_ids, default = *.")
- if 'start' not in request_data.keys():
- return utils.get_stand_return(False, "Request must have keys: int start.")
- if 'end' not in request_data.keys():
- return utils.get_stand_return(False, "Request must have keys: int end.")
+ if 'user_id' not in request_data.keys():
+ return utils.get_stand_return(False, "Request must have keys: int user_id")
+ if 'page_no' not in request_data.keys():
+ return utils.get_stand_return(False, "Request must have keys: int page_no.")
+ if 'page_size' not in request_data.keys():
+ return utils.get_stand_return(False, "Request must have keys: int page_size.")
    # parse
- image_path, retrieved_ids, start, end = request_data['image_path'], request_data['retrieved_ids'], request_data['start'], request_data['end']
+ image_path, user_id, page_no, page_size = request_data['image_path'], int(request_data['user_id']), int(request_data['page_no']), int(request_data['page_size'])
- #编码文本
+    # error handling
+    if page_no <= 0 or page_size <= 0:
+        return utils.get_stand_return(False, "Request page_no and page_size must be >= 1.")
+
+    # encode the image
image_vector = base_function.image_encoder_api(model, image_path)
    # vector comparison
logger.info("Parse request correct, start image retrieval ...")
time_start = time.time()
- retrieval_results = {}
+    # collect the retrievable entries
rsd = globalvar.get_value("rsd")
- if retrieved_ids == "*": # 检索所有影像
- for k in rsd.keys():
- retrieval_results[k] = base_function.cosine_sim_api(image_vector, rsd[k])
- else:
- for k in retrieved_ids: # 检索指定影像
- retrieval_results[k] = base_function.cosine_sim_api(image_vector, rsd[k])
- sorted_keys = utils.sort_based_values(retrieval_results)[start:end] # 排序
+ rsd_retrieved, retrieval_results = {}, {}
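+    # keep only entries that are public (privilege == 1) or owned by the requesting user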
+ for k,v in rsd.items():
+ if (rsd[k]["privilege"] == 1) or (rsd[k]["user_id"] == user_id):
+ rsd_retrieved[k] = v
+
+    # compute similarities
+ for k in rsd_retrieved.keys():
+ retrieval_results[k] = base_function.cosine_sim_api(image_vector, rsd[k]["image_vector"])
+
+    # sort
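+    # page_no is 1-based: page 1 maps to slice [0, page_size) of the sorted keys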
+ start, end = page_size * (page_no-1), page_size * page_no
+ sorted_keys = utils.sort_based_values(retrieval_results)[start:end]
time_end = time.time()
logger.info("Retrieval finished in {:.4f}s.".format(time_end - time_start))
diff --git a/code/api_controlers/semantic_localization_ctl.py b/code/api_controlers/semantic_localization_ctl.py
index c934c08..17b8a60 100644
--- a/code/api_controlers/semantic_localization_ctl.py
+++ b/code/api_controlers/semantic_localization_ctl.py
@@ -57,17 +57,18 @@ def split_image(img_path, steps):
logger.info("Image {} has been split successfully.".format(img_path))
-def generate_heatmap(img_path, text):
+def generate_heatmap(img_path, text, output_file_h, output_file_a):
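+    # output_file_h / output_file_a: destination paths for the heatmap and for the
+    # heatmap overlaid on the original image (the two entries of "output_file" in the request)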
subimages_dir = os.path.join(cfg['data_paths']['temp_path'], os.path.basename(img_path).split(".")[0]) +'_subimages'
- heatmap_subdir = utils.create_random_dirs_name(cfg['data_paths']['temp_path'])
- heatmap_dir = os.path.join(cfg['data_paths']['semantic_localization_path'], heatmap_subdir)
+# heatmap_subdir = utils.create_random_dirs_name(cfg['data_paths']['temp_path'])
+# heatmap_dir = os.path.join(cfg['data_paths']['semantic_localization_path'], heatmap_subdir)
+# heatmap_dir = output_path
    # clear cache
- if os.path.exists(heatmap_dir):
- utils.delete_dire(heatmap_dir)
- else:
- os.makedirs(heatmap_dir)
+# if os.path.exists(heatmap_dir):
+# utils.delete_dire(heatmap_dir)
+# else:
+# os.makedirs(heatmap_dir)
logger.info("Start calculate similarities ...")
cal_start = time.time()
@@ -121,16 +122,21 @@ def generate_heatmap(img_path, text):
logger.info("Generate heatmap in {}s".format(generate_end-generate_start))
# save
- logger.info("Saving heatmap in {} ...".format(heatmap_dir))
- cv2.imwrite(os.path.join(heatmap_dir, "heatmap.png"),heatmap)
- cv2.imwrite(os.path.join(heatmap_dir, "heatmap_add.png"),img_add)
+# logger.info("Saving heatmap in {} ...".format(heatmap_dir))
+# cv2.imwrite(os.path.join(heatmap_dir, "heatmap.png"),heatmap)
+# cv2.imwrite(os.path.join(heatmap_dir, "heatmap_add.png"),img_add)
+
+ logger.info("Saving heatmap in {} ...".format(output_file_h))
+ logger.info("Saving heatmap in {} ...".format(output_file_a))
+ cv2.imwrite( output_file_h ,heatmap)
+ cv2.imwrite( output_file_a ,img_add)
logger.info("Saved ok.")
# clear temp
utils.delete_dire(subimages_dir)
os.rmdir(subimages_dir)
- return heatmap_dir
+# return heatmap_dir
def semantic_localization(request_data):
@@ -138,23 +144,23 @@ def semantic_localization(request_data):
    # check that the request is complete
if not isinstance(request_data, dict):
- return utils.get_stand_return(False, "Request must be dicts, and have keys: image_path, text, and params.")
- if 'image_path' not in request_data.keys():
- return utils.get_stand_return(False, "Request must have keys: str image_path.")
- if 'text' not in request_data.keys():
- return utils.get_stand_return(False, "Request must have keys: str text.")
+ return utils.get_stand_return(False, "Request must be dicts, and have keys: input_file, output_file, params.")
+ if 'input_file' not in request_data.keys():
+ return utils.get_stand_return(False, "Request must have keys: list input_file.")
+ if 'output_file' not in request_data.keys():
+ return utils.get_stand_return(False, "Request must have keys: list output_file.")
if ('params' in request_data.keys()) and ('steps' in request_data['params'].keys()):
steps = request_data['params']['steps']
else:
steps = [128,256,512]
    # parse
- image_path, text, params = request_data['image_path'], request_data['text'], request_data['params']
+ image_path, text, params, output_file_h, output_file_a = request_data['input_file'][0], request_data['params']['text'], request_data['params'], request_data['output_file'][0], request_data['output_file'][1]
    # check the file format
if not (image_path.endswith('.tif') or image_path.endswith('.jpg') or image_path.endswith('.tiff') or image_path.endswith('.png')):
return utils.get_stand_return(False, "File format is uncorrect: only support .tif, .tiff, .jpg, and .png .")
else:
split_image(image_path, steps)
- heatmap_dir = generate_heatmap(image_path, text)
- return utils.get_stand_return(True, "Generate successfully in {}".format(heatmap_dir))
\ No newline at end of file
+ generate_heatmap(image_path, text, output_file_h, output_file_a)
+ return utils.get_stand_return(True, "Generate successfully.")
diff --git a/code/api_controlers/crossmodal_search_ctl.py b/code/api_controlers/text_search_ctl.py
similarity index 55%
rename from code/api_controlers/crossmodal_search_ctl.py
rename to code/api_controlers/text_search_ctl.py
index 71cface..1a8805e 100644
--- a/code/api_controlers/crossmodal_search_ctl.py
+++ b/code/api_controlers/text_search_ctl.py
@@ -15,35 +15,44 @@ def text_search(request_data):
    # check that the request is complete
if not isinstance(request_data, dict):
- return utils.get_stand_return(False, "Request must be dicts, and have keys: text, retrieved_ids, start, end.")
+ return utils.get_stand_return(False, "Request must be dicts, and have keys: text, user_id, page_no, page_size.")
if 'text' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: str text.")
- if 'retrieved_ids' not in request_data.keys():
- return utils.get_stand_return(False, "Request must have keys: list retrieved_ids, default = *.")
- if 'start' not in request_data.keys():
- return utils.get_stand_return(False, "Request must have keys: int start.")
- if 'end' not in request_data.keys():
- return utils.get_stand_return(False, "Request must have keys: int end.")
+ if 'user_id' not in request_data.keys():
+ return utils.get_stand_return(False, "Request must have keys: int user_id")
+ if 'page_no' not in request_data.keys():
+ return utils.get_stand_return(False, "Request must have keys: int page_no.")
+ if 'page_size' not in request_data.keys():
+ return utils.get_stand_return(False, "Request must have keys: int page_size.")
    # parse
- text, retrieved_ids, start, end = request_data['text'], request_data['retrieved_ids'], request_data['start'], request_data['end']
+ text, user_id, page_no, page_size = request_data['text'], int(request_data['user_id']), int(request_data['page_no']), int(request_data['page_size'])
+
+    # error handling
+    if page_no <= 0 or page_size <= 0:
+        return utils.get_stand_return(False, "Request page_no and page_size must be >= 1.")
    # encode the text
text_vector = base_function.text_encoder_api(model, vocab_word, text)
- # 向量比对
+    # start vector comparison
logger.info("Parse request correct, start cross-modal retrieval ...")
time_start = time.time()
- retrieval_results = {}
+    # collect the retrievable entries
rsd = globalvar.get_value("rsd")
- if retrieved_ids == "*": # 检索所有影像
- for k in rsd.keys():
- retrieval_results[k] = base_function.cosine_sim_api(text_vector, rsd[k])
- else:
- for k in retrieved_ids: # 检索指定影像
- retrieval_results[k] = base_function.cosine_sim_api(text_vector, rsd[k])
- sorted_keys = utils.sort_based_values(retrieval_results)[start:end] # 排序
+ rsd_retrieved, retrieval_results = {}, {}
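+    # keep only entries that are public (privilege == 1) or owned by the requesting user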
+ for k,v in rsd.items():
+ if (rsd[k]["privilege"] == 1) or (rsd[k]["user_id"] == user_id):
+ rsd_retrieved[k] = v
+
+    # compute similarities
+ for k in rsd_retrieved.keys():
+ retrieval_results[k] = base_function.cosine_sim_api(text_vector, rsd[k]["image_vector"])
+
+    # sort
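+    # page_no is 1-based: page 1 maps to slice [0, page_size) of the sorted keys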
+ start, end = page_size * (page_no-1), page_size * page_no
+ sorted_keys = utils.sort_based_values(retrieval_results)[start:end]
time_end = time.time()
logger.info("Retrieval finished in {:.4f}s.".format(time_end - time_start))
diff --git a/code/common/config.yaml b/code/common/config.yaml
index cb7b201..d5af281 100644
--- a/code/common/config.yaml
+++ b/code/common/config.yaml
@@ -1,7 +1,7 @@
apis:
hosts:
- ip: '192.168.43.216'
- port: 49205
+ ip: '192.168.140.241'
+ port: 33133
image_encode:
route: '/api/image_encode/'
delete_encode:
@@ -14,6 +14,7 @@ apis:
route: '/api/semantic_localization/'
data_paths:
images_dir: '../data/test_data'
+ rsd_dir_path: '../data/retrieval_system_data/rsd'
rsd_path: '../data/retrieval_system_data/rsd/rsd.pkl'
semantic_localization_path: '../data/retrieval_system_data/semantic_localization_data'
temp_path: '../data/retrieval_system_data/tmp'
diff --git a/code/main.py b/code/main.py
index 8ac0877..7753ab9 100644
--- a/code/main.py
+++ b/code/main.py
@@ -18,8 +18,8 @@
    # create initial variables
logger.info("Create init variables")
globalvar.set_value("unembeded_images", value={})
- globalvar.set_value("rsd", value=utils.init_rsd(cfg['data_paths']['rsd_dir_path']))
- utils.create_dirs(cfg['data_paths']['rsd_path'])
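+    # load the encode pool from rsd.pkl (rsd_path) and make sure the rsd/ directory (rsd_dir_path) exists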
+ globalvar.set_value("rsd", value=utils.init_rsd(cfg['data_paths']['rsd_path']))
+ utils.create_dirs(cfg['data_paths']['rsd_dir_path'])
utils.create_dirs(cfg['data_paths']['semantic_localization_path'])
utils.create_dirs(cfg['data_paths']['temp_path'])
@@ -36,4 +36,4 @@
    # start the apis
from api_controlers import apis
logger.info("Start apis and running ...\n")
- apis.api_run(cfg)
\ No newline at end of file
+ apis.api_run(cfg)
diff --git a/code/models/__pycache__/__init__.cpython-37.pyc b/code/models/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..43ed7ae
Binary files /dev/null and b/code/models/__pycache__/__init__.cpython-37.pyc differ
diff --git a/code/models/__pycache__/encoder.cpython-37.pyc b/code/models/__pycache__/encoder.cpython-37.pyc
new file mode 100644
index 0000000..89c28e9
Binary files /dev/null and b/code/models/__pycache__/encoder.cpython-37.pyc differ
diff --git a/code/models/__pycache__/init.cpython-37.pyc b/code/models/__pycache__/init.cpython-37.pyc
new file mode 100644
index 0000000..cf4bf2f
Binary files /dev/null and b/code/models/__pycache__/init.cpython-37.pyc differ
diff --git a/code/models/__pycache__/vocab.cpython-37.pyc b/code/models/__pycache__/vocab.cpython-37.pyc
new file mode 100644
index 0000000..1bb7928
Binary files /dev/null and b/code/models/__pycache__/vocab.cpython-37.pyc differ
diff --git a/code/models/layers/__pycache__/GaLR.cpython-37.pyc b/code/models/layers/__pycache__/GaLR.cpython-37.pyc
new file mode 100644
index 0000000..8c42ea3
Binary files /dev/null and b/code/models/layers/__pycache__/GaLR.cpython-37.pyc differ
diff --git a/code/models/layers/__pycache__/GaLR_utils.cpython-37.pyc b/code/models/layers/__pycache__/GaLR_utils.cpython-37.pyc
new file mode 100644
index 0000000..9c90816
Binary files /dev/null and b/code/models/layers/__pycache__/GaLR_utils.cpython-37.pyc differ
diff --git a/code/models/layers/__pycache__/__init__.cpython-37.pyc b/code/models/layers/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..e0c3d3d
Binary files /dev/null and b/code/models/layers/__pycache__/__init__.cpython-37.pyc differ
diff --git a/code/models/layers/__pycache__/mca.cpython-37.pyc b/code/models/layers/__pycache__/mca.cpython-37.pyc
new file mode 100644
index 0000000..2668c29
Binary files /dev/null and b/code/models/layers/__pycache__/mca.cpython-37.pyc differ
diff --git a/code/models/layers/__pycache__/seq2vec.cpython-37.pyc b/code/models/layers/__pycache__/seq2vec.cpython-37.pyc
new file mode 100644
index 0000000..c934578
Binary files /dev/null and b/code/models/layers/__pycache__/seq2vec.cpython-37.pyc differ
diff --git a/data/retrieval_system_data/rsd/rsd.pkl b/data/retrieval_system_data/rsd/rsd.pkl
index b3083ab..22dcee6 100644
Binary files a/data/retrieval_system_data/rsd/rsd.pkl and b/data/retrieval_system_data/rsd/rsd.pkl differ
diff --git a/data/retrieval_system_data/semantic_localization_data/8e9yjvt4/heatmap.png b/data/retrieval_system_data/semantic_localization_data/8e9yjvt4/heatmap.png
deleted file mode 100644
index d739048..0000000
Binary files a/data/retrieval_system_data/semantic_localization_data/8e9yjvt4/heatmap.png and /dev/null differ
diff --git a/data/retrieval_system_data/semantic_localization_data/8e9yjvt4/heatmap_add.png b/data/retrieval_system_data/semantic_localization_data/8e9yjvt4/heatmap_add.png
deleted file mode 100644
index 210602f..0000000
Binary files a/data/retrieval_system_data/semantic_localization_data/8e9yjvt4/heatmap_add.png and /dev/null differ
diff --git a/readme.md b/readme.md
index 4e40a7a..1a71b10 100644
--- a/readme.md
+++ b/readme.md
@@ -1,6 +1,11 @@
-# Backend of cross-modal retrieval system
+# Backend of Cross-modal Retrieval System
##### Author: Zhiqiang Yuan
+
+
+
+
+
### -------------------------------------------------------------------------------------
### Welcome :+1:_`Fork and Star`_:+1:, then we'll let you know when we update
@@ -9,6 +14,29 @@ The back-end of cross-modal retrieval system,wihch will contain services such
The purpose of this project is to provide an applicable retrieval framework for retrieval models.
We use RS image data as the development baseline, and demonstrate the potential of the project through services such as semantic localization and cross-modal retrieval.
+
+#### Summary
+
+* [Requirements](#requirements)
+* [Apis](#apis)
+* [Architecture](#architecture)
+* [Three Steps to Use This Framework](#three-steps-to-use-this-framework)
+* [Customize Your Retrieval Model](#customize-your-retrieval-model)
+* [Citation](#citation)
+### -------------------------------------------------------------------------------------
+### Requirements
+```bash
+numpy>=1.7.1
+six>=1.1.0
+PyTorch > 0.3
+flask >= 1.1.1
+h5py
+nltk
+yaml
+```
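+
+One possible way to install the Python packages above (the pip package names are assumptions, e.g. `yaml` is provided by `pyyaml`; install PyTorch separately for your setup):
+
+```bash
+pip install "numpy>=1.7.1" "six>=1.1.0" "flask>=1.1.1" h5py nltk pyyaml
+```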
+------------------------------------------
+
### -------------------------------------------------------------------------------------
### Apis
```bash
@@ -16,13 +44,27 @@ We will use RS image data as the baseline for development, and demonstrate the p
#/api/image_encode/ [POST]
# FUNC: encode images
-data = {
- # image_id: file_path
- 11:"../data/test_data/images/00013.jpg",
- 33: "../data/test_data/images/00013.jpg",
- 32: "../data/test_data/images/00013.jpg",
-}
-url = 'http://192.168.43.216:49205/api/image_encode/'
+data = [
+ {
+ "image_id": 11,
+ "image_path": "../data/test_data/images/00013.jpg",
+ "user_id": 1,
+ "privilege": 1
+ },
+ {
+ "image_id": 33,
+ "image_path": "../data/test_data/images/00013.jpg",
+ "user_id": 1,
+ "privilege": 1
+ },
+ {
+ "image_id": 32,
+ "image_path": "../data/test_data/images/00013.jpg",
+ "user_id": 2,
+ "privilege": 1
+ }
+]
+url = 'http://192.168.97.241:33133/api/image_encode/'
r = requests.post(url, data=json.dumps(data))
print(r.json())
@@ -34,8 +76,9 @@ print(r.json())
# FUNC: delete encodes
# image_id
-data = [3, 4]
-url = 'http://192.168.43.216:49205/api/delete_encode/'
+data = {"deleteID":"32"}
+url = 'http://192.168.140.241:33133/api/delete_encode/'
+
r = requests.post(url, data=json.dumps(data))
print(r.json())
```
@@ -46,12 +89,13 @@ print(r.json())
# FUNC: cross-modal retrieval
data = {
- 'text': "One block has a cross shaped roof church.", # retrieved text
- 'retrieved_ids': "*", # retrieved images pool
- 'start': 0, # from top
- 'end': 100 # to end
- }
-url = 'http://192.168.43.216:49205/api/text_search/'
+ 'text': "One block has a cross shaped roof church.",
+ 'user_id': 1,
+ 'page_no': 1,
+ 'page_size': 10
+}
+url = 'http://0.0.0.0:33133/api/text_search/'
+
r = requests.post(url, data=json.dumps(data))
print(r.json())
```
@@ -62,12 +106,13 @@ print(r.json())
# FUNC: image-image retrieval
data = {
- 'image_path': "../data/test_data/images/00013.jpg",, # retrieved image
- 'retrieved_ids': "*", # retrieved images pool: 1) * represents all, 2) [1, 2, 4] represent images pool
- 'start': 0, # from top
- 'end': 100 # to end
- }
-url = 'http://192.168.43.216:49205/api/image_search/'
+ 'image_path': "/data/test_data/images/00013.jpg",
+ 'user_id': 1,
+ 'page_no': 1,
+ 'page_size': 10
+}
+url = 'http://0.0.0.0:33133/api/image_search/'
+
r = requests.post(url, data=json.dumps(data))
print(r.json())
```
@@ -78,13 +123,17 @@ print(r.json())
# FUNC: semantic localization
data = {
- 'image_path': "../data/test_data/images/demo1.tif",
- 'text': "there are two tennis courts beside the playground",
- 'params': {
- 'steps': [64, 128,256,512]
- },
-}
-url = 'http://192.168.43.216:49205/api/semantic_localization/'
+ "input_file": ["../data/test_data/images/demo1.tif"],
+ "output_file": [
+ "../data/retrieval_system_data/semantic_localization_data/heatmap.png",
+ "../data/retrieval_system_data/semantic_localization_data/heatmap_add.png"],
+ "params": {
+ "text": "there are two tennis courts beside the playground",
+ "steps": [128,256,512]
+ }
+ }
+url = 'http://192.168.97.241:33133/api/semantic_localization/'
+
r = requests.post(url, data=json.dumps(data))
print(r.json())
```
@@ -110,5 +159,28 @@ print(r.json())
```
### -------------------------------------------------------------------------------------
-### Environments
+### Three Steps to Use This Framework
+
+Step 1. Install the environment, download the code locally, and adjust the path settings in ./code/common/config.yaml. You also need to update the yaml path file under ./code/models/options/ .
+
+Step 2. Enter the ./code directory and run main.py to start the flask service.
+
+Step 3. Use Postman or Python's requests library to send sample requests. Sample calls for each interface are shown in ./test/test_api.py, and a minimal sketch is given below.
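+
+A minimal sketch of Steps 2 and 3, assuming you start from the repository root and keep the default host/port in ./code/common/config.yaml:
+
+```bash
+# Step 2: start the flask service
+cd code
+python main.py
+
+# Step 3: in another shell, send sample requests
+cd test
+python test_api.py
+```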
+
+
+### -------------------------------------------------------------------------------------
+### Customize Your Retrieval Model
+
+To run your own retrieval model in this service, you only need to change the ./code/models folder: provide an encoding interface for each modality and a model initialization interface. For more details, see the README file under ./code/models/ .
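+
+As a rough reference, these are the calls the API controllers make against the model layer; the names below are taken from ./code/api_controlers/, the helper `rank_score` is only for illustration, and the code under ./code/models/ is the source of truth for the exact signatures:
+
+```python
+# Sketch only: "model" and "vocab_word" come from model initialization
+# (model_init_ctl), as in the existing controllers.
+from api_controlers import base_function
+
+def rank_score(model, vocab_word, text, image_path):
+    """Return the ranking score between a text query and one image."""
+    image_vector = base_function.image_encoder_api(model, image_path)      # image -> vector
+    text_vector = base_function.text_encoder_api(model, vocab_word, text)  # text -> vector
+    return base_function.cosine_sim_api(text_vector, image_vector)         # cosine similarity
+```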
+
## Under Updating
+
+## Citation
+If you find this code helpful, or if you use this code or dataset, please cite:
+```
+Z. Yuan et al., "Exploring a Fine-Grained Multiscale Method for Cross-Modal Remote Sensing Image Retrieval," in IEEE Transactions on Geoscience and Remote Sensing, doi: 10.1109/TGRS.2021.3078451.
+
+Z. Yuan et al., "A Lightweight Multi-scale Crossmodal Text-Image Retrieval Method In Remote Sensing," in IEEE Transactions on Geoscience and Remote Sensing, doi: 10.1109/TGRS.2021.3124252.
+```
+
+
diff --git a/test/test_api.py b/test/test_api.py
index a0de14f..515714e 100644
--- a/test/test_api.py
+++ b/test/test_api.py
@@ -3,12 +3,27 @@
def post_encode_image():
    # encode request
# image_id: file_path
- data = {
- 11:"../data/test_data/images/00013.jpg",
- 33: "../data/test_data/images/00013.jpg",
- 32: "../data/test_data/images/00013.jpg",
- }
- url = 'http://192.168.43.216:49205/api/image_encode/'
+ data = [
+ {
+ "image_id": 12,
+ "image_path": "../data/test_data/images/00013.jpg",
+ "user_id": 1,
+ "privilege": 1
+ },
+ {
+ "image_id": 33,
+ "image_path": "../data/test_data/images/00013.jpg",
+ "user_id": 1,
+ "privilege": 1
+ },
+ {
+ "image_id": 32,
+ "image_path": "../data/test_data/images/00013.jpg",
+ "user_id": 2,
+ "privilege": 1
+ }
+ ]
+ url = 'http://192.168.140.241:33133/api/image_encode/'
r = requests.post(url, data=json.dumps(data))
print(r.json())
@@ -16,8 +31,8 @@ def post_encode_image():
def post_delete_encode():
    # delete-encode request
# image_id
- data = ['3']
- url = 'http://192.168.43.216:49205/api/delete_encode/'
+ data = {"deleteID":"32"}
+ url = 'http://192.168.140.241:33133/api/delete_encode/'
r = requests.post(url, data=json.dumps(data))
print(r.json())
@@ -27,11 +42,11 @@ def post_t2i_rerieval():
# text
data = {
'text': "One block has a cross shaped roof church.",
- 'retrieved_ids': "*",
- 'start': 0,
- 'end': 100
+ 'user_id': 2,
+ 'page_no': 2,
+ 'page_size': 0
}
- url = 'http://192.168.43.216:49205/api/text_search/'
+ url = 'http://192.168.140.241:33133/api/text_search/'
r = requests.post(url, data=json.dumps(data))
print(r.json())
@@ -41,11 +56,11 @@ def post_i2i_retrieval():
# image
data = {
'image_path': "../data/test_data/images/00013.jpg",
- 'retrieved_ids': "*",
- 'start': 0,
- 'end': 100
+ 'user_id': 1,
+ 'page_no': 1,
+ 'page_size': 10
}
- url = 'http://192.168.43.216:49205/api/image_search/'
+ url = 'http://192.168.140.241:33133/api/image_search/'
r = requests.post(url, data=json.dumps(data))
print(r.json())
@@ -54,16 +69,23 @@ def post_semantic_localization():
    # semantic localization request
# # semantic localization
data = {
- 'image_path': "../data/test_data/images/demo1.tif",
- 'text': "there are two tennis courts beside the playground",
- 'params': {
- 'steps': [128,256,512]
- },
- }
- url = 'http://192.168.43.216:49205/api/semantic_localization/'
+ "input_file": ["../data/test_data/images/demo1.tif"],
+ "output_file": [
+ "../data/retrieval_system_data/semantic_localization_data/heatmap.png",
+ "../data/retrieval_system_data/semantic_localization_data/heatmap_add.png"],
+ "params": {
+ "text": "there are two tennis courts beside the playground",
+ "steps": [128,256,512]
+ }
+ }
+ url = 'http://192.168.140.241:33133/api/semantic_localization/'
r = requests.post(url, data=json.dumps(data))
print(r.json())
if __name__=="__main__":
- post_semantic_localization()
\ No newline at end of file
+ # post_semantic_localization()
+ # post_encode_image()
+ # post_delete_encode()
+ # post_t2i_rerieval()
+ post_i2i_retrieval()
diff --git a/test/test_data.py b/test/test_data.py
index b5b8ae3..d11861c 100644
--- a/test/test_data.py
+++ b/test/test_data.py
@@ -5,5 +5,9 @@ def dict_load(name="rsd.pkl"):
with open(name, 'rb') as f:
return pickle.load(f)
-pkl = dict_load("../data/retrieval_system_data/rsd.pkl")
-print(pkl.keys())
\ No newline at end of file
+pkl = dict_load("../data/retrieval_system_data/rsd/rsd.pkl")
+
+for k,v in pkl.items():
+ print(k)
+ print(v)
+ print("========")
\ No newline at end of file
diff --git a/test/test_function.py b/test/test_function.py
index 65a3d2c..ffcc475 100644
--- a/test/test_function.py
+++ b/test/test_function.py
@@ -1,16 +1,24 @@
import numpy as np
-def l2norm(X, eps=1e-8):
- """L2-normalize columns of X
- """
- norm = np.array([sum([i**2 for i in X]) + eps for ii in X])
- X = np.divide(X, norm)
- return X
+data = [
+ {
+ "image_id": 11,
+ "image_path": "../data/test_data/images/00013.jpg",
+ "user_id": 1,
+ "privilege": 1
+ },
+ {
+ "image_id": 33,
+ "image_path": "../data/test_data/images/00013.jpg",
+ "user_id": 1,
+ "privilege": 1
+ },
+ {
+ "image_id": 32,
+ "image_path": "../data/test_data/images/00013.jpg",
+ "user_id": 2,
+ "privilege": 1
+ }
+]
-X = np.random.random_sample((512))
-
-a = l2norm(X)
-print(a)
-
-b = X/np.linalg.norm(X)
-print(b)
\ No newline at end of file
+print(data)
\ No newline at end of file