diff --git a/code/__pycache__/globalvar.cpython-37.pyc b/code/__pycache__/globalvar.cpython-37.pyc
new file mode 100644
index 0000000..e33dd9f
Binary files /dev/null and b/code/__pycache__/globalvar.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/__init__.cpython-37.pyc b/code/api_controlers/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..4db31ff
Binary files /dev/null and b/code/api_controlers/__pycache__/__init__.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/apis.cpython-37.pyc b/code/api_controlers/__pycache__/apis.cpython-37.pyc
new file mode 100644
index 0000000..9fdf7e3
Binary files /dev/null and b/code/api_controlers/__pycache__/apis.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/base_function.cpython-37.pyc b/code/api_controlers/__pycache__/base_function.cpython-37.pyc
new file mode 100644
index 0000000..af56423
Binary files /dev/null and b/code/api_controlers/__pycache__/base_function.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/delete_encode_ctl.cpython-37.pyc b/code/api_controlers/__pycache__/delete_encode_ctl.cpython-37.pyc
new file mode 100644
index 0000000..bf04c05
Binary files /dev/null and b/code/api_controlers/__pycache__/delete_encode_ctl.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/image_encode_ctl.cpython-37.pyc b/code/api_controlers/__pycache__/image_encode_ctl.cpython-37.pyc
new file mode 100644
index 0000000..b00822c
Binary files /dev/null and b/code/api_controlers/__pycache__/image_encode_ctl.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/image_search_ctl.cpython-37.pyc b/code/api_controlers/__pycache__/image_search_ctl.cpython-37.pyc
new file mode 100644
index 0000000..074d5c8
Binary files /dev/null and b/code/api_controlers/__pycache__/image_search_ctl.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/model_init_ctl.cpython-37.pyc b/code/api_controlers/__pycache__/model_init_ctl.cpython-37.pyc
new file mode 100644
index 0000000..b7e257c
Binary files /dev/null and b/code/api_controlers/__pycache__/model_init_ctl.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/semantic_localization_ctl.cpython-37.pyc b/code/api_controlers/__pycache__/semantic_localization_ctl.cpython-37.pyc
new file mode 100644
index 0000000..a4b72b2
Binary files /dev/null and b/code/api_controlers/__pycache__/semantic_localization_ctl.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/text_search_ctl.cpython-37.pyc b/code/api_controlers/__pycache__/text_search_ctl.cpython-37.pyc
new file mode 100644
index 0000000..b50f6ed
Binary files /dev/null and b/code/api_controlers/__pycache__/text_search_ctl.cpython-37.pyc differ
diff --git a/code/api_controlers/__pycache__/utils.cpython-37.pyc b/code/api_controlers/__pycache__/utils.cpython-37.pyc
new file mode 100644
index 0000000..ddcca41
Binary files /dev/null and b/code/api_controlers/__pycache__/utils.cpython-37.pyc differ
diff --git a/code/api_controlers/apis.py b/code/api_controlers/apis.py
index 4343c77..b361819 100644
--- a/code/api_controlers/apis.py
+++ b/code/api_controlers/apis.py
@@ -1,10 +1,9 @@
 from flask import Flask, request, jsonify
 import json, os
-import _thread
 from threading import Timer
 from api_controlers import base_function, image_encode_ctl, delete_encode_ctl,\
-    crossmodal_search_ctl, image_search_ctl, semantic_localization_ctl, utils
+    text_search_ctl, image_search_ctl, semantic_localization_ctl, utils
 
 def api_run(cfg):
     app = Flask(__name__)    # Flask 初始化
@@ -28,7 +27,7 @@ def delete_encode():
     @app.route(cfg['apis']['text_search']['route'], methods=['post'])
     def text_search():
         request_data = json.loads(request.data.decode('utf-8'))
-        return_json = crossmodal_search_ctl.text_search(request_data)
+        return_json = text_search_ctl.text_search(request_data)
         return return_json
 
     # 图像检索
diff --git a/code/api_controlers/delete_encode_ctl.py b/code/api_controlers/delete_encode_ctl.py
index e1ed738..bfadf50 100644
--- a/code/api_controlers/delete_encode_ctl.py
+++ b/code/api_controlers/delete_encode_ctl.py
@@ -15,6 +15,8 @@ def delete_encode(request_data):
     logger.info("\nRequest json: {}".format(request_data))
 
+
+    request_data = [int(i) for i in request_data['deleteID'].split(",")]
     if request_data != []:
         # 删除数据
         for k in request_data:
@@ -22,7 +24,7 @@
             if k not in rsd.keys():
                 return utils.get_stand_return(False, "Key {} not found in encode pool.".format(k))
             else:
-                rsd = utils.dict_delete(k, rsd)
+                rsd = utils.dict_delete(int(k), rsd)
                 globalvar.set_value("rsd", rsd)
                 utils.dict_save(rsd, cfg['data_paths']['rsd_path'])
diff --git a/code/api_controlers/image_encode_ctl.py b/code/api_controlers/image_encode_ctl.py
index db841b8..593e76f 100644
--- a/code/api_controlers/image_encode_ctl.py
+++ b/code/api_controlers/image_encode_ctl.py
@@ -17,9 +17,9 @@ def image_encode_append(request_data):
     logger.info("Request json: {}".format(request_data))
 
     # 加入未编码数据
-    for k,v in request_data.items():
+    for item in request_data:
         unembeded_images = globalvar.get_value("unembeded_images")
-        unembeded_images = utils.dict_insert(k, v, unembeded_images)
+        unembeded_images = utils.dict_insert(int(item["image_id"]), item, unembeded_images)
         globalvar.set_value("unembeded_images", value=unembeded_images)
     logger.info("Request append successfully for above request.\n")
 
@@ -33,13 +33,15 @@ def image_encode_runner():
     if unembeded_images != {}:
         logger.info("{} images in unembeded image pool have been detected ...".format(len(unembeded_images.keys())))
         for img_id in list(unembeded_images.keys()):
-            img_path = unembeded_images[img_id]
+            img_path = unembeded_images[img_id]["image_path"]
             image_vector = base_function.image_encoder_api(model, img_path)
 
             # 更新rsd数据
             rsd = globalvar.get_value("rsd")
-            rsd = utils.dict_insert(img_id, image_vector, rsd)
+            unembeded_images[img_id]["image_vector"] = image_vector
+            rsd = utils.dict_insert(img_id, unembeded_images[img_id], rsd)
+
             globalvar.set_value("rsd", value=rsd)
 
             # 删除未编码池
diff --git a/code/api_controlers/image_search_ctl.py b/code/api_controlers/image_search_ctl.py
index 026833b..28ba91d 100644
--- a/code/api_controlers/image_search_ctl.py
+++ b/code/api_controlers/image_search_ctl.py
@@ -15,35 +15,44 @@ def image_search(request_data):
 
     # 检测请求完备性
     if not isinstance(request_data, dict):
-        return utils.get_stand_return(False, "Request must be dicts, and have keys: image_path, retrieved_ids, start, end.")
+        return utils.get_stand_return(False, "Request must be dicts, and have keys: image_path, user_id, page_no, page_size.")
     if 'image_path' not in request_data.keys():
         return utils.get_stand_return(False, "Request must have keys: str image_path.")
-    if 'retrieved_ids' not in request_data.keys():
-        return utils.get_stand_return(False, "Request must have keys: list retrieved_ids, default = *.")
-    if 'start' not in request_data.keys():
-        return utils.get_stand_return(False, "Request must have keys: int start.")
-    if 'end' not in request_data.keys():
-        return utils.get_stand_return(False, "Request must have keys: int end.")
+    if 'user_id' not in request_data.keys():
+        return utils.get_stand_return(False, "Request must have keys: int user_id")
+    if 'page_no' not in request_data.keys():
+        return utils.get_stand_return(False, "Request must have keys: int page_no.")
+    if 'page_size' not in request_data.keys():
+        return utils.get_stand_return(False, "Request must have keys: int page_size.")
 
     # 解析
-    image_path, retrieved_ids, start, end = request_data['image_path'], request_data['retrieved_ids'], request_data['start'], request_data['end']
+    image_path, user_id, page_no, page_size = request_data['image_path'], int(request_data['user_id']), int(request_data['page_no']), int(request_data['page_size'])
 
-    #编码文本
+    # 出错机制
+    if page_no <=0 or page_size <=0 :
+        return utils.get_stand_return(False, "Request page_no and page_size must >= 1.")
+
+    #编码图像
     image_vector = base_function.image_encoder_api(model, image_path)
 
     # 向量比对
     logger.info("Parse request correct, start image retrieval ...")
     time_start = time.time()
-    retrieval_results = {}
+
+    # 统计匹配数据
     rsd = globalvar.get_value("rsd")
-    if retrieved_ids == "*":    # 检索所有影像
-        for k in rsd.keys():
-            retrieval_results[k] = base_function.cosine_sim_api(image_vector, rsd[k])
-    else:
-        for k in retrieved_ids:    # 检索指定影像
-            retrieval_results[k] = base_function.cosine_sim_api(image_vector, rsd[k])
-    sorted_keys = utils.sort_based_values(retrieval_results)[start:end]    # 排序
+    rsd_retrieved, retrieval_results = {}, {}
+    for k,v in rsd.items():
+        if (rsd[k]["privilege"] == 1) or (rsd[k]["user_id"] == user_id):
+            rsd_retrieved[k] = v
+
+    # 计算
+    for k in rsd_retrieved.keys():
+        retrieval_results[k] = base_function.cosine_sim_api(image_vector, rsd[k]["image_vector"])
+
+    # 排序
+    start, end = page_size * (page_no-1), page_size * page_no
+    sorted_keys = utils.sort_based_values(retrieval_results)[start:end]
     time_end = time.time()
     logger.info("Retrieval finished in {:.4f}s.".format(time_end - time_start))
diff --git a/code/api_controlers/semantic_localization_ctl.py b/code/api_controlers/semantic_localization_ctl.py
index c934c08..17b8a60 100644
--- a/code/api_controlers/semantic_localization_ctl.py
+++ b/code/api_controlers/semantic_localization_ctl.py
@@ -57,17 +57,18 @@ def split_image(img_path, steps):
 
     logger.info("Image {} has been split successfully.".format(img_path))
 
-def generate_heatmap(img_path, text):
+def generate_heatmap(img_path, text, output_file_h, output_file_a):
     subimages_dir = os.path.join(cfg['data_paths']['temp_path'], os.path.basename(img_path).split(".")[0]) +'_subimages'
-    heatmap_subdir = utils.create_random_dirs_name(cfg['data_paths']['temp_path'])
-    heatmap_dir = os.path.join(cfg['data_paths']['semantic_localization_path'], heatmap_subdir)
+#    heatmap_subdir = utils.create_random_dirs_name(cfg['data_paths']['temp_path'])
+#    heatmap_dir = os.path.join(cfg['data_paths']['semantic_localization_path'], heatmap_subdir)
+#    heatmap_dir = output_path
 
     # 清除缓存
-    if os.path.exists(heatmap_dir):
-        utils.delete_dire(heatmap_dir)
-    else:
-        os.makedirs(heatmap_dir)
+#    if os.path.exists(heatmap_dir):
+#        utils.delete_dire(heatmap_dir)
+#    else:
+#        os.makedirs(heatmap_dir)
 
     logger.info("Start calculate similarities ...")
     cal_start = time.time()
@@ -121,16 +122,21 @@ def generate_heatmap(img_path, text):
     logger.info("Generate heatmap in {}s".format(generate_end-generate_start))
 
     # save
-    logger.info("Saving heatmap in {} ...".format(heatmap_dir))
-    cv2.imwrite(os.path.join(heatmap_dir, "heatmap.png"),heatmap)
-    cv2.imwrite(os.path.join(heatmap_dir, "heatmap_add.png"),img_add)
+#    logger.info("Saving heatmap in {} ...".format(heatmap_dir))
+#    cv2.imwrite(os.path.join(heatmap_dir, "heatmap.png"),heatmap)
+#    cv2.imwrite(os.path.join(heatmap_dir, "heatmap_add.png"),img_add)
+
+    logger.info("Saving heatmap in {} ...".format(output_file_h))
+    logger.info("Saving heatmap in {} ...".format(output_file_a))
+    cv2.imwrite( output_file_h ,heatmap)
+    cv2.imwrite( output_file_a ,img_add)
     logger.info("Saved ok.")
 
     # clear temp
     utils.delete_dire(subimages_dir)
     os.rmdir(subimages_dir)
 
-    return heatmap_dir
+#    return heatmap_dir
 
 def semantic_localization(request_data):
@@ -138,23 +144,23 @@
 
     # 检测请求完备性
     if not isinstance(request_data, dict):
-        return utils.get_stand_return(False, "Request must be dicts, and have keys: image_path, text, and params.")
-    if 'image_path' not in request_data.keys():
-        return utils.get_stand_return(False, "Request must have keys: str image_path.")
-    if 'text' not in request_data.keys():
-        return utils.get_stand_return(False, "Request must have keys: str text.")
+        return utils.get_stand_return(False, "Request must be dicts, and have keys: input_file, output_file, params.")
+    if 'input_file' not in request_data.keys():
+        return utils.get_stand_return(False, "Request must have keys: list input_file.")
+    if 'output_file' not in request_data.keys():
+        return utils.get_stand_return(False, "Request must have keys: list output_file.")
     if ('params' in request_data.keys()) and ('steps' in request_data['params'].keys()):
         steps = request_data['params']['steps']
     else:
         steps = [128,256,512]
 
     # 解析
-    image_path, text, params = request_data['image_path'], request_data['text'], request_data['params']
+    image_path, text, params, output_file_h, output_file_a = request_data['input_file'][0], request_data['params']['text'], request_data['params'], request_data['output_file'][0], request_data['output_file'][1]
 
     # 判断文件格式
     if not (image_path.endswith('.tif') or image_path.endswith('.jpg') or image_path.endswith('.tiff') or image_path.endswith('.png')):
         return utils.get_stand_return(False, "File format is uncorrect: only support .tif, .tiff, .jpg, and .png .")
     else:
         split_image(image_path, steps)
-        heatmap_dir = generate_heatmap(image_path, text)
-        return utils.get_stand_return(True, "Generate successfully in {}".format(heatmap_dir))
\ No newline at end of file
+        generate_heatmap(image_path, text, output_file_h, output_file_a)
+        return utils.get_stand_return(True, "Generate successfully.")
diff --git a/code/api_controlers/crossmodal_search_ctl.py b/code/api_controlers/text_search_ctl.py
similarity index 55%
rename from code/api_controlers/crossmodal_search_ctl.py
rename to code/api_controlers/text_search_ctl.py
index 71cface..1a8805e 100644
--- a/code/api_controlers/crossmodal_search_ctl.py
+++ b/code/api_controlers/text_search_ctl.py
@@ -15,35 +15,44 @@ def text_search(request_data):
 
     # 检测请求完备性
     if not isinstance(request_data, dict):
-        return utils.get_stand_return(False, "Request must be dicts, and have keys: text, retrieved_ids, start, end.")
+        return utils.get_stand_return(False, "Request must be dicts, and have keys: text, user_id, page_no, page_size.")
     if 'text' not in request_data.keys():
         return utils.get_stand_return(False, "Request must have keys: str text.")
-    if 'retrieved_ids' not in request_data.keys():
-        return utils.get_stand_return(False, "Request must have keys: list retrieved_ids, default = *.")
-    if 'start' not in request_data.keys():
-        return utils.get_stand_return(False, "Request must have keys: int start.")
-    if 'end' not in request_data.keys():
-        return utils.get_stand_return(False, "Request must have keys: int end.")
+    if 'user_id' not in request_data.keys():
+        return utils.get_stand_return(False, "Request must have keys: int user_id")
+    if 'page_no' not in request_data.keys():
+        return utils.get_stand_return(False, "Request must have keys: int page_no.")
+    if 'page_size' not in request_data.keys():
+        return utils.get_stand_return(False, "Request must have keys: int page_size.")
 
     # 解析
-    text, retrieved_ids, start, end = request_data['text'], request_data['retrieved_ids'], request_data['start'], request_data['end']
+    text, user_id, page_no, page_size = request_data['text'], int(request_data['user_id']), int(request_data['page_no']), int(request_data['page_size'])
+
+    # 出错机制
+    if page_no <=0 or page_size <=0 :
+        return utils.get_stand_return(False, "Request page_no and page_size must >= 1.")
 
     #编码文本
     text_vector = base_function.text_encoder_api(model, vocab_word, text)
 
-    # 向量比对
+    # 开始向量比对
     logger.info("Parse request correct, start cross-modal retrieval ...")
     time_start = time.time()
-    retrieval_results = {}
+
+    # 统计匹配数据
     rsd = globalvar.get_value("rsd")
-    if retrieved_ids == "*":    # 检索所有影像
-        for k in rsd.keys():
-            retrieval_results[k] = base_function.cosine_sim_api(text_vector, rsd[k])
-    else:
-        for k in retrieved_ids:    # 检索指定影像
-            retrieval_results[k] = base_function.cosine_sim_api(text_vector, rsd[k])
-    sorted_keys = utils.sort_based_values(retrieval_results)[start:end]    # 排序
+    rsd_retrieved, retrieval_results = {}, {}
+    for k,v in rsd.items():
+        if (rsd[k]["privilege"] == 1) or (rsd[k]["user_id"] == user_id):
+            rsd_retrieved[k] = v
+
+    # 计算
+    for k in rsd_retrieved.keys():
+        retrieval_results[k] = base_function.cosine_sim_api(text_vector, rsd[k]["image_vector"])
+
+    # 排序
+    start, end = page_size * (page_no-1), page_size * page_no
+    sorted_keys = utils.sort_based_values(retrieval_results)[start:end]
     time_end = time.time()
     logger.info("Retrieval finished in {:.4f}s.".format(time_end - time_start))
diff --git a/code/common/config.yaml b/code/common/config.yaml
index cb7b201..d5af281 100644
--- a/code/common/config.yaml
+++ b/code/common/config.yaml
@@ -1,7 +1,7 @@
 apis:
   hosts:
-    ip: '192.168.43.216'
-    port: 49205
+    ip: '192.168.140.241'
+    port: 33133
   image_encode:
     route: '/api/image_encode/'
   delete_encode:
@@ -14,6 +14,7 @@ apis:
     route: '/api/semantic_localization/'
 data_paths:
   images_dir: '../data/test_data'
+  rsd_dir_path: '../data/retrieval_system_data/rsd'
   rsd_path: '../data/retrieval_system_data/rsd/rsd.pkl'
   semantic_localization_path: '../data/retrieval_system_data/semantic_localization_data'
   temp_path: '../data/retrieval_system_data/tmp'
diff --git a/code/main.py b/code/main.py
index 8ac0877..7753ab9 100644
--- a/code/main.py
+++ b/code/main.py
@@ -18,8 +18,8 @@
     # 创建起始变量
     logger.info("Create init variables")
     globalvar.set_value("unembeded_images", value={})
-    globalvar.set_value("rsd", value=utils.init_rsd(cfg['data_paths']['rsd_dir_path']))
-    utils.create_dirs(cfg['data_paths']['rsd_path'])
+    globalvar.set_value("rsd", value=utils.init_rsd(cfg['data_paths']['rsd_path']))
+    utils.create_dirs(cfg['data_paths']['rsd_dir_path'])
    utils.create_dirs(cfg['data_paths']['semantic_localization_path'])
     utils.create_dirs(cfg['data_paths']['temp_path'])
 
@@ -36,4 +36,4 @@
     # 开启接口
     from api_controlers import apis
     logger.info("Start apis and running ...\n")
-    apis.api_run(cfg)
\ No newline at end of file
+    apis.api_run(cfg)
diff --git a/code/models/__pycache__/__init__.cpython-37.pyc b/code/models/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..43ed7ae
Binary files /dev/null and b/code/models/__pycache__/__init__.cpython-37.pyc differ
diff --git a/code/models/__pycache__/encoder.cpython-37.pyc b/code/models/__pycache__/encoder.cpython-37.pyc
new file mode 100644
index 0000000..89c28e9
Binary files /dev/null and b/code/models/__pycache__/encoder.cpython-37.pyc differ
diff --git a/code/models/__pycache__/init.cpython-37.pyc b/code/models/__pycache__/init.cpython-37.pyc
new file mode 100644
index 0000000..cf4bf2f
Binary files /dev/null and b/code/models/__pycache__/init.cpython-37.pyc differ
diff --git a/code/models/__pycache__/vocab.cpython-37.pyc b/code/models/__pycache__/vocab.cpython-37.pyc
new file mode 100644
index 0000000..1bb7928
Binary files /dev/null and b/code/models/__pycache__/vocab.cpython-37.pyc differ
diff --git a/code/models/layers/__pycache__/GaLR.cpython-37.pyc b/code/models/layers/__pycache__/GaLR.cpython-37.pyc
new file mode 100644
index 0000000..8c42ea3
Binary files /dev/null and b/code/models/layers/__pycache__/GaLR.cpython-37.pyc differ
diff --git a/code/models/layers/__pycache__/GaLR_utils.cpython-37.pyc b/code/models/layers/__pycache__/GaLR_utils.cpython-37.pyc
new file mode 100644
index 0000000..9c90816
Binary files /dev/null and b/code/models/layers/__pycache__/GaLR_utils.cpython-37.pyc differ
diff --git a/code/models/layers/__pycache__/__init__.cpython-37.pyc b/code/models/layers/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..e0c3d3d
Binary files /dev/null and b/code/models/layers/__pycache__/__init__.cpython-37.pyc differ
diff --git a/code/models/layers/__pycache__/mca.cpython-37.pyc b/code/models/layers/__pycache__/mca.cpython-37.pyc
new file mode 100644
index 0000000..2668c29
Binary files /dev/null and b/code/models/layers/__pycache__/mca.cpython-37.pyc differ
diff --git a/code/models/layers/__pycache__/seq2vec.cpython-37.pyc b/code/models/layers/__pycache__/seq2vec.cpython-37.pyc
new file mode 100644
index 0000000..c934578
Binary files /dev/null and b/code/models/layers/__pycache__/seq2vec.cpython-37.pyc differ
diff --git a/data/retrieval_system_data/rsd/rsd.pkl b/data/retrieval_system_data/rsd/rsd.pkl
index b3083ab..22dcee6 100644
Binary files a/data/retrieval_system_data/rsd/rsd.pkl and b/data/retrieval_system_data/rsd/rsd.pkl differ
diff --git a/data/retrieval_system_data/semantic_localization_data/8e9yjvt4/heatmap.png b/data/retrieval_system_data/semantic_localization_data/8e9yjvt4/heatmap.png
deleted file mode 100644
index d739048..0000000
Binary files a/data/retrieval_system_data/semantic_localization_data/8e9yjvt4/heatmap.png and /dev/null differ
diff --git a/data/retrieval_system_data/semantic_localization_data/8e9yjvt4/heatmap_add.png b/data/retrieval_system_data/semantic_localization_data/8e9yjvt4/heatmap_add.png
deleted file mode 100644
index 210602f..0000000
Binary files a/data/retrieval_system_data/semantic_localization_data/8e9yjvt4/heatmap_add.png and /dev/null differ
diff --git a/readme.md b/readme.md
index 4e40a7a..1a71b10 100644
--- a/readme.md
+++ b/readme.md
@@ -1,6 +1,11 @@
-# Backend of cross-modal retrieval system
+# Backend of Cross-modal Retrieval System
 ##### Author: Zhiqiang Yuan
 
+
+![Supported Python versions](https://img.shields.io/badge/python-3.7-blue.svg)
+![Supported OS](https://img.shields.io/badge/Supported%20OS-Linux-yellow.svg)
+![npm License](https://img.shields.io/npm/l/mithril.svg)
+
 ### -------------------------------------------------------------------------------------
 ### Welcome :+1:_`Fork and Star`_:+1:, then we'll let you know when we update
@@ -9,6 +14,29 @@ The back-end of cross-modal retrieval system,wihch will contain services such
 The purpose of this project is to provide a set of applicable retrieval framework for the retrieval model. We will use RS image data as the baseline for development, and demonstrate the potential of the project through services such as semantic positioning and cross-modal retrieval.
 
+
+#### Summary
+
+* [Requirements](#requirements)
+* [Apis](#apis)
+* [Architecture](#architecture)
+* [Three Steps to Use This Framework](#three-steps-to-use-this-framework)
+* [Customize Your Retrieval Model](#customize-your-retrieval-model)
+* [Citation](#citation)
+### -------------------------------------------------------------------------------------
+### Requirements
+```bash
+numpy>=1.7.1
+six>=1.1.0
+PyTorch > 0.3
+flask >= 1.1.1
+Numpy
+h5py
+nltk
+yaml
+```
+------------------------------------------
+
 ### -------------------------------------------------------------------------------------
 ### Apis
 ```bash
@@ -16,13 +44,27 @@ We will use RS image data as the baseline for development, and demonstrate the p
 
 #/api/image_encode/ [POST]
 # FUNC: encode images
 
-data = {
-    # image_id: file_path
-    11:"../data/test_data/images/00013.jpg",
-    33: "../data/test_data/images/00013.jpg",
-    32: "../data/test_data/images/00013.jpg",
-}
-url = 'http://192.168.43.216:49205/api/image_encode/'
+data = [
+    {
+        "image_id": 11,
+        "image_path": "../data/test_data/images/00013.jpg",
+        "user_id": 1,
+        "privilege": 1
+    },
+    {
+        "image_id": 33,
+        "image_path": "../data/test_data/images/00013.jpg",
+        "user_id": 1,
+        "privilege": 1
+    },
+    {
+        "image_id": 32,
+        "image_path": "../data/test_data/images/00013.jpg",
+        "user_id": 2,
+        "privilege": 1
+    }
+]
+url = 'http://192.168.97.241:33133/api/image_encode/'
 
 r = requests.post(url, data=json.dumps(data))
 print(r.json())
@@ -34,8 +76,9 @@ print(r.json())
 
 # FUNC: delete encodes
 # image_id
-data = [3, 4]
-url = 'http://192.168.43.216:49205/api/delete_encode/'
+data = {"deleteID":"32"}
+url = 'http://192.168.140.241:33133/api/delete_encode/'
+
 r = requests.post(url, data=json.dumps(data))
 print(r.json())
 ```
@@ -46,12 +89,13 @@ print(r.json())
 
 # FUNC: cross-modal retrieval
 data = {
-        'text': "One block has a cross shaped roof church.",  # retrieved text
-        'retrieved_ids': "*",  # retrieved images pool
-        'start': 0,  # from top
-        'end': 100  # to end
-    }
-url = 'http://192.168.43.216:49205/api/text_search/'
+    'text': "One block has a cross shaped roof church.",
+    'user_id': 1,
+    'page_no': 1,
+    'page_size': 10
+}
+url = 'http://0.0.0.0:33133/api/text_search/'
+
 r = requests.post(url, data=json.dumps(data))
 print(r.json())
 ```
@@ -62,12 +106,13 @@ print(r.json())
 
 # FUNC: image-image retrieval
 data = {
-        'image_path': "../data/test_data/images/00013.jpg",,  # retrieved image
-        'retrieved_ids': "*",  # retrieved images pool: 1) * represents all, 2) [1, 2, 4] represent images pool
-        'start': 0,  # from top
-        'end': 100  # to end
-    }
-url = 'http://192.168.43.216:49205/api/image_search/'
+    'image_path': "/data/test_data/images/00013.jpg",
+    'user_id': 1,
+    'page_no': 1,
+    'page_size': 10
+}
+url = 'http://0.0.0.0:33133/api/image_search/'
+
 r = requests.post(url, data=json.dumps(data))
 print(r.json())
 ```
@@ -78,13 +123,17 @@ print(r.json())
 
 # FUNC: semantic localization
 data = {
-        'image_path': "../data/test_data/images/demo1.tif",
-        'text': "there are two tennis courts beside the playground",
-        'params': {
-            'steps': [64, 128,256,512]
-        },
-}
-url = 'http://192.168.43.216:49205/api/semantic_localization/'
+    "input_file": ["../data/test_data/images/demo1.tif"],
+    "output_file": [
+        "../data/retrieval_system_data/semantic_localization_data/heatmap.png",
+        "../data/retrieval_system_data/semantic_localization_data/heatmap_add.png"],
+    "params": {
+        "text": "there are two tennis courts beside the playground",
+        "steps": [128,256,512]
+    }
+  }
+url = 'http://192.168.97.241:33133/api/semantic_localization/'
+
 r = requests.post(url, data=json.dumps(data))
 print(r.json())
 ```
@@ -110,5 +159,28 @@ print(r.json())
 ```
 
 ### -------------------------------------------------------------------------------------
-### Environments
+### Three Steps to Use This Framework
+
+Step 1. Install the environment, download the code to the local machine, and change the path settings in the ./code/common/config.yaml file. You also need to change the yaml path file under ./code/models/options/ .
+
+Step 2. Enter the ./code directory and run main.py to start the flask service.
+
+Step 3. Use Postman etc. or Python's requests package for sample requests. Some interface samples are shown in ./test/test_api.py .
+
+
+### -------------------------------------------------------------------------------------
+### Customize Your Retrieval Model
+
+You only need to change the ./code/models folder to make your retrieval model run in the service. For this, you should provide encoding interfaces and model initialization interfaces for different modal data. For more information, please see the README file under ./code/models/ .
+
 ## Under Updating
+
+## Citation
+If you find this code helpful or use this code or dataset, please cite it as
+```
+Z. Yuan et al., "Exploring a Fine-Grained Multiscale Method for Cross-Modal Remote Sensing Image Retrieval," in IEEE Transactions on Geoscience and Remote Sensing, doi: 10.1109/TGRS.2021.3078451.
+
+Z. Yuan et al., "A Lightweight Multi-scale Crossmodal Text-Image Retrieval Method In Remote Sensing," in IEEE Transactions on Geoscience and Remote Sensing, doi: 10.1109/TGRS.2021.3124252.
+```
+
+
diff --git a/test/test_api.py b/test/test_api.py
index a0de14f..515714e 100644
--- a/test/test_api.py
+++ b/test/test_api.py
@@ -3,12 +3,27 @@
 def post_encode_image():    # 编码请求
     # image_id: file_path
-    data = {
-        11:"../data/test_data/images/00013.jpg",
-        33: "../data/test_data/images/00013.jpg",
-        32: "../data/test_data/images/00013.jpg",
-    }
-    url = 'http://192.168.43.216:49205/api/image_encode/'
+    data = [
+        {
+            "image_id": 12,
+            "image_path": "../data/test_data/images/00013.jpg",
+            "user_id": 1,
+            "privilege": 1
+        },
+        {
+            "image_id": 33,
+            "image_path": "../data/test_data/images/00013.jpg",
+            "user_id": 1,
+            "privilege": 1
+        },
+        {
+            "image_id": 32,
+            "image_path": "../data/test_data/images/00013.jpg",
+            "user_id": 2,
+            "privilege": 1
+        }
+    ]
+    url = 'http://192.168.140.241:33133/api/image_encode/'
 
     r = requests.post(url, data=json.dumps(data))
     print(r.json())
@@ -16,8 +31,8 @@ def post_encode_image():
 def post_delete_encode():    # 删除编码请求
     # image_id
-    data = ['3']
-    url = 'http://192.168.43.216:49205/api/delete_encode/'
+    data = {"deleteID":"32"}
+    url = 'http://192.168.140.241:33133/api/delete_encode/'
 
     r = requests.post(url, data=json.dumps(data))
     print(r.json())
@@ -27,11 +42,11 @@ def post_t2i_rerieval():    # text
     data = {
         'text': "One block has a cross shaped roof church.",
-        'retrieved_ids': "*",
-        'start': 0,
-        'end': 100
+        'user_id': 2,
+        'page_no': 2,
+        'page_size': 0
     }
-    url = 'http://192.168.43.216:49205/api/text_search/'
+    url = 'http://192.168.140.241:33133/api/text_search/'
 
     r = requests.post(url, data=json.dumps(data))
     print(r.json())
@@ -41,11 +56,11 @@ def post_i2i_retrieval():    # image
     data = {
         'image_path': "../data/test_data/images/00013.jpg",
-        'retrieved_ids': "*",
-        'start': 0,
-        'end': 100
+        'user_id': 1,
+        'page_no': 1,
+        'page_size': 10
    }
-    url = 'http://192.168.43.216:49205/api/image_search/'
+    url = 'http://192.168.140.241:33133/api/image_search/'
 
     r = requests.post(url, data=json.dumps(data))
     print(r.json())
@@ -54,16 +69,23 @@ def post_semantic_localization():    # 语义定位请求
 
     # # semantic localization
     data = {
-        'image_path': "../data/test_data/images/demo1.tif",
-        'text': "there are two tennis courts beside the playground",
-        'params': {
-            'steps': [128,256,512]
-        },
-    }
-    url = 'http://192.168.43.216:49205/api/semantic_localization/'
+        "input_file": ["../data/test_data/images/demo1.tif"],
+        "output_file": [
+            "../data/retrieval_system_data/semantic_localization_data/heatmap.png",
+            "../data/retrieval_system_data/semantic_localization_data/heatmap_add.png"],
+        "params": {
+            "text": "there are two tennis courts beside the playground",
+            "steps": [128,256,512]
+        }
+    }
+    url = 'http://192.168.140.241:33133/api/semantic_localization/'
 
     r = requests.post(url, data=json.dumps(data))
     print(r.json())
 
 if __name__=="__main__":
-    post_semantic_localization()
\ No newline at end of file
+    # post_semantic_localization()
+    # post_encode_image()
+    # post_delete_encode()
+    # post_t2i_rerieval()
+    post_i2i_retrieval()
diff --git a/test/test_data.py b/test/test_data.py
index b5b8ae3..d11861c 100644
--- a/test/test_data.py
+++ b/test/test_data.py
@@ -5,5 +5,9 @@ def dict_load(name="rsd.pkl"):
     with open(name, 'rb') as f:
         return pickle.load(f)
 
-pkl = dict_load("../data/retrieval_system_data/rsd.pkl")
-print(pkl.keys())
\ No newline at end of file
+pkl = dict_load("../data/retrieval_system_data/rsd/rsd.pkl")
+
+for k,v in pkl.items():
+    print(k)
+    print(v)
+    print("========")
\ No newline at end of file
diff --git a/test/test_function.py b/test/test_function.py
index 65a3d2c..ffcc475 100644
--- a/test/test_function.py
+++ b/test/test_function.py
@@ -1,16 +1,24 @@
 import numpy as np
 
-def l2norm(X, eps=1e-8):
-    """L2-normalize columns of X
-    """
-    norm = np.array([sum([i**2 for i in X]) + eps for ii in X])
-    X = np.divide(X, norm)
-    return X
+data = [
+    {
+        "image_id": 11,
+        "image_path": "../data/test_data/images/00013.jpg",
+        "user_id": 1,
+        "privilege": 1
+    },
+    {
+        "image_id": 33,
+        "image_path": "../data/test_data/images/00013.jpg",
+        "user_id": 1,
+        "privilege": 1
+    },
+    {
+        "image_id": 32,
+        "image_path": "../data/test_data/images/00013.jpg",
+        "user_id": 2,
+        "privilege": 1
+    }
+]
 
-X = np.random.random_sample((512))
-
-a = l2norm(X)
-print(a)
-
-b = X/np.linalg.norm(X)
-print(b)
\ No newline at end of file
+print(data)
\ No newline at end of file
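
---

For reference, below is a minimal standalone sketch (not part of the patch) of the data layout and the privilege-filter-plus-pagination flow that the new `text_search`/`image_search` controllers appear to implement. The record fields mirror the request schema and `image_vector` key used in the diff; the `cosine_sim` helper, the 512-dimensional random vectors, and the sample entries are illustrative assumptions, standing in for `base_function.cosine_sim_api` and real encoder output.

```python
import numpy as np

# Assumed shape of an entry in the "rsd" pool after image_encode_runner:
# the original encode request fields plus an "image_vector" added at encode time.
rsd = {
    11: {"image_id": 11, "image_path": "../data/test_data/images/00013.jpg",
         "user_id": 1, "privilege": 1, "image_vector": np.random.rand(512)},
    32: {"image_id": 32, "image_path": "../data/test_data/images/00013.jpg",
         "user_id": 2, "privilege": 0, "image_vector": np.random.rand(512)},
}

def cosine_sim(a, b):
    # Stand-in for base_function.cosine_sim_api (assumed cosine similarity).
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8))

def search(query_vector, user_id, page_no, page_size):
    # Keep public images (privilege == 1) plus the requesting user's own images.
    visible = {k: v for k, v in rsd.items()
               if v["privilege"] == 1 or v["user_id"] == user_id}
    # Score every visible image against the query vector.
    scores = {k: cosine_sim(query_vector, v["image_vector"]) for k, v in visible.items()}
    # Rank by similarity and return the requested page (page_no is 1-based).
    ranked = sorted(scores, key=scores.get, reverse=True)
    start, end = page_size * (page_no - 1), page_size * page_no
    return ranked[start:end]

print(search(np.random.rand(512), user_id=1, page_no=1, page_size=10))
```

With `page_no=1, page_size=10` this returns the top-10 visible image ids for user 1, which is the behaviour the updated README examples and `test/test_api.py` requests exercise.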