
Commit

Added user and permission mechanism
xiaoyuan1996 committed Nov 12, 2021
1 parent d71b6d5 commit be49bf0
Showing 35 changed files with 269 additions and 135 deletions.
Binary file added code/__pycache__/globalvar.cpython-37.pyc
5 changes: 2 additions & 3 deletions code/api_controlers/apis.py
@@ -1,10 +1,9 @@
from flask import Flask, request, jsonify
import json, os
import _thread
from threading import Timer

from api_controlers import base_function, image_encode_ctl, delete_encode_ctl,\
crossmodal_search_ctl, image_search_ctl, semantic_localization_ctl, utils
text_search_ctl, image_search_ctl, semantic_localization_ctl, utils

def api_run(cfg):
app = Flask(__name__) # Flask initialization
@@ -28,7 +27,7 @@ def delete_encode():
@app.route(cfg['apis']['text_search']['route'], methods=['post'])
def text_search():
request_data = json.loads(request.data.decode('utf-8'))
return_json = crossmodal_search_ctl.text_search(request_data)
return_json = text_search_ctl.text_search(request_data)
return return_json

# image retrieval
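For illustration (not part of this diff): with the import switched from crossmodal_search_ctl to text_search_ctl, the text_search route is now served by text_search_ctl.text_search. A minimal client-side sketch of calling it, assuming the route configured at cfg['apis']['text_search']['route'] is '/api/text_search/' (the actual string is in config.yaml, not shown here) and using the host/port from the config.yaml change below; the request fields come from the search handlers later in this commit and the values are illustrative:

import json
import requests  # third-party HTTP client, assumed available

payload = {"text": "two planes parked beside the runway",
           "user_id": 1, "page_no": 1, "page_size": 10}
resp = requests.post("http://192.168.140.241:33133/api/text_search/",
                     data=json.dumps(payload))  # server does json.loads(request.data)
print(resp.json())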
4 changes: 3 additions & 1 deletion code/api_controlers/delete_encode_ctl.py
@@ -15,14 +15,16 @@
def delete_encode(request_data):
logger.info("\nRequest json: {}".format(request_data))


request_data = [int(i) for i in request_data['deleteID'].split(",")]
if request_data != []:
# delete data
for k in request_data:
rsd = globalvar.get_value("rsd")
if k not in rsd.keys():
return utils.get_stand_return(False, "Key {} not found in encode pool.".format(k))
else:
rsd = utils.dict_delete(k, rsd)
rsd = utils.dict_delete(int(k), rsd)
globalvar.set_value("rsd", rsd)

utils.dict_save(rsd, cfg['data_paths']['rsd_path'])
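For illustration (not part of this diff): the delete handler receives the IDs as one comma-separated string under deleteID and now coerces each one to int before looking it up in the encode pool. A minimal sketch of the parsing shown in this hunk, with illustrative IDs:

request_data = {"deleteID": "3,7,12"}
ids = [int(i) for i in request_data["deleteID"].split(",")]  # -> [3, 7, 12]
# each id is then removed from the rsd pool via utils.dict_delete(int(k), rsd)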
10 changes: 6 additions & 4 deletions code/api_controlers/image_encode_ctl.py
@@ -17,9 +17,9 @@ def image_encode_append(request_data):
logger.info("Request json: {}".format(request_data))

# add unencoded data to the pool
for k,v in request_data.items():
for item in request_data:
unembeded_images = globalvar.get_value("unembeded_images")
unembeded_images = utils.dict_insert(k, v, unembeded_images)
unembeded_images = utils.dict_insert(int(item["image_id"]), item, unembeded_images)
globalvar.set_value("unembeded_images", value=unembeded_images)

logger.info("Request append successfully for above request.\n")
@@ -33,13 +33,15 @@ def image_encode_runner():
if unembeded_images != {}:
logger.info("{} images in unembeded image pool have been detected ...".format(len(unembeded_images.keys())))
for img_id in list(unembeded_images.keys()):
img_path = unembeded_images[img_id]
img_path = unembeded_images[img_id]["image_path"]

image_vector = base_function.image_encoder_api(model, img_path)

# update rsd data
rsd = globalvar.get_value("rsd")
rsd = utils.dict_insert(img_id, image_vector, rsd)
unembeded_images[img_id]["image_vector"] = image_vector
rsd = utils.dict_insert(img_id, unembeded_images[img_id], rsd)

globalvar.set_value("rsd", value=rsd)

# remove from the unencoded pool
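For illustration (not part of this diff): each entry appended for encoding is now a dict keyed by its image_id, and the background runner adds the computed image_vector to that dict before moving it into the rsd pool. A sketch of the resulting data shapes; the user_id and privilege fields are an assumption inferred from the permission checks in the search handlers below, since the full append payload is not shown in this hunk:

request_data = [
    {"image_id": "101", "image_path": "../data/test_data/00013.jpg",
     "user_id": 1, "privilege": 1},
]
# after image_encode_runner() processes the pool:
# rsd[101] == {"image_id": "101", "image_path": "../data/test_data/00013.jpg",
#              "user_id": 1, "privilege": 1, "image_vector": <encoded vector>}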
43 changes: 26 additions & 17 deletions code/api_controlers/image_search_ctl.py
@@ -15,35 +15,44 @@ def image_search(request_data):

# check request completeness
if not isinstance(request_data, dict):
return utils.get_stand_return(False, "Request must be dicts, and have keys: image_path, retrieved_ids, start, end.")
return utils.get_stand_return(False, "Request must be dicts, and have keys: image_path, user_id, page_no, page_size.")
if 'image_path' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: str image_path.")
if 'retrieved_ids' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: list retrieved_ids, default = *.")
if 'start' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: int start.")
if 'end' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: int end.")
if 'user_id' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: int user_id")
if 'page_no' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: int page_no.")
if 'page_size' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: int page_size.")

# parse
image_path, retrieved_ids, start, end = request_data['image_path'], request_data['retrieved_ids'], request_data['start'], request_data['end']
image_path, user_id, page_no, page_size = request_data['image_path'], int(request_data['user_id']), int(request_data['page_no']), int(request_data['page_size'])

# encode text
# error handling
if page_no <=0 or page_size <=0 :
return utils.get_stand_return(False, "Request page_no and page_size must >= 1.")

# encode image
image_vector = base_function.image_encoder_api(model, image_path)

# vector comparison
logger.info("Parse request correct, start image retrieval ...")
time_start = time.time()

retrieval_results = {}
# collect matching data
rsd = globalvar.get_value("rsd")
if retrieved_ids == "*": # retrieve all images
for k in rsd.keys():
retrieval_results[k] = base_function.cosine_sim_api(image_vector, rsd[k])
else:
for k in retrieved_ids: # retrieve specified images
retrieval_results[k] = base_function.cosine_sim_api(image_vector, rsd[k])
sorted_keys = utils.sort_based_values(retrieval_results)[start:end] # sort
rsd_retrieved, retrieval_results = {}, {}
for k,v in rsd.items():
if (rsd[k]["privilege"] == 1) or (rsd[k]["user_id"] == user_id):
rsd_retrieved[k] = v

# compute similarities
for k in rsd_retrieved.keys():
retrieval_results[k] = base_function.cosine_sim_api(image_vector, rsd[k]["image_vector"])

# sort
start, end = page_size * (page_no-1), page_size * page_no
sorted_keys = utils.sort_based_values(retrieval_results)[start:end]

time_end = time.time()
logger.info("Retrieval finished in {:.4f}s.".format(time_end - time_start))
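For illustration (not part of this diff): image search now filters the encode pool by permission before computing similarities, and pagination replaces the raw start/end indices. A condensed sketch of the logic introduced in this hunk (base_function.cosine_sim_api and utils.sort_based_values are the project's own helpers):

rsd_retrieved = {k: v for k, v in rsd.items()
                 if v["privilege"] == 1 or v["user_id"] == user_id}  # public entries or the requester's own
retrieval_results = {k: base_function.cosine_sim_api(image_vector, v["image_vector"])
                     for k, v in rsd_retrieved.items()}
start, end = page_size * (page_no - 1), page_size * page_no  # e.g. page_no=2, page_size=10 -> slice [10:20]
sorted_keys = utils.sort_based_values(retrieval_results)[start:end]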
44 changes: 25 additions & 19 deletions code/api_controlers/semantic_localization_ctl.py
@@ -57,17 +57,18 @@ def split_image(img_path, steps):

logger.info("Image {} has been split successfully.".format(img_path))

def generate_heatmap(img_path, text):
def generate_heatmap(img_path, text, output_file_h, output_file_a):
subimages_dir = os.path.join(cfg['data_paths']['temp_path'], os.path.basename(img_path).split(".")[0]) +'_subimages'

heatmap_subdir = utils.create_random_dirs_name(cfg['data_paths']['temp_path'])
heatmap_dir = os.path.join(cfg['data_paths']['semantic_localization_path'], heatmap_subdir)
# heatmap_subdir = utils.create_random_dirs_name(cfg['data_paths']['temp_path'])
# heatmap_dir = os.path.join(cfg['data_paths']['semantic_localization_path'], heatmap_subdir)
# heatmap_dir = output_path

# clear cache
if os.path.exists(heatmap_dir):
utils.delete_dire(heatmap_dir)
else:
os.makedirs(heatmap_dir)
# if os.path.exists(heatmap_dir):
# utils.delete_dire(heatmap_dir)
# else:
# os.makedirs(heatmap_dir)

logger.info("Start calculate similarities ...")
cal_start = time.time()
@@ -121,40 +121,45 @@ def generate_heatmap(img_path, text):
logger.info("Generate heatmap in {}s".format(generate_end-generate_start))

# save
logger.info("Saving heatmap in {} ...".format(heatmap_dir))
cv2.imwrite(os.path.join(heatmap_dir, "heatmap.png"),heatmap)
cv2.imwrite(os.path.join(heatmap_dir, "heatmap_add.png"),img_add)
# logger.info("Saving heatmap in {} ...".format(heatmap_dir))
# cv2.imwrite(os.path.join(heatmap_dir, "heatmap.png"),heatmap)
# cv2.imwrite(os.path.join(heatmap_dir, "heatmap_add.png"),img_add)

logger.info("Saving heatmap in {} ...".format(output_file_h))
logger.info("Saving heatmap in {} ...".format(output_file_a))
cv2.imwrite( output_file_h ,heatmap)
cv2.imwrite( output_file_a ,img_add)
logger.info("Saved ok.")

# clear temp
utils.delete_dire(subimages_dir)
os.rmdir(subimages_dir)

return heatmap_dir
# return heatmap_dir


def semantic_localization(request_data):
logger.info("Request json: {}".format(request_data))

# check request completeness
if not isinstance(request_data, dict):
return utils.get_stand_return(False, "Request must be dicts, and have keys: image_path, text, and params.")
if 'image_path' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: str image_path.")
if 'text' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: str text.")
return utils.get_stand_return(False, "Request must be dicts, and have keys: input_file, output_file, params.")
if 'input_file' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: list input_file.")
if 'output_file' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: list output_file.")
if ('params' in request_data.keys()) and ('steps' in request_data['params'].keys()):
steps = request_data['params']['steps']
else:
steps = [128,256,512]

# parse
image_path, text, params = request_data['image_path'], request_data['text'], request_data['params']
image_path, text, params, output_file_h, output_file_a = request_data['input_file'][0], request_data['params']['text'], request_data['params'], request_data['output_file'][0], request_data['output_file'][1]

# check file format
if not (image_path.endswith('.tif') or image_path.endswith('.jpg') or image_path.endswith('.tiff') or image_path.endswith('.png')):
return utils.get_stand_return(False, "File format is uncorrect: only support .tif, .tiff, .jpg, and .png .")
else:
split_image(image_path, steps)
heatmap_dir = generate_heatmap(image_path, text)
return utils.get_stand_return(True, "Generate successfully in {}".format(heatmap_dir))
generate_heatmap(image_path, text, output_file_h, output_file_a)
return utils.get_stand_return(True, "Generate successfully.")
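For illustration (not part of this diff): the semantic localization request now takes input/output file lists and carries the query text inside params, and the two heatmaps are written directly to the caller-supplied output paths instead of a random temp directory. A sketch of a request that satisfies the new checks; the paths, text, and steps are illustrative:

request_data = {
    "input_file": ["../data/test_data/demo.tif"],
    "output_file": ["../data/retrieval_system_data/tmp/heatmap.png",
                    "../data/retrieval_system_data/tmp/heatmap_add.png"],
    "params": {"text": "there is a large area of bare land", "steps": [128, 256, 512]},
}
# generate_heatmap(image_path, text, output_file_h, output_file_a) writes the raw heatmap
# to output_file[0] and the heatmap overlaid on the image to output_file[1]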
code/api_controlers/text_search_ctl.py
@@ -15,35 +15,44 @@ def text_search(request_data):

# check request completeness
if not isinstance(request_data, dict):
return utils.get_stand_return(False, "Request must be dicts, and have keys: text, retrieved_ids, start, end.")
return utils.get_stand_return(False, "Request must be dicts, and have keys: text, user_id, page_no, page_size.")
if 'text' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: str text.")
if 'retrieved_ids' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: list retrieved_ids, default = *.")
if 'start' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: int start.")
if 'end' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: int end.")
if 'user_id' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: int user_id")
if 'page_no' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: int page_no.")
if 'page_size' not in request_data.keys():
return utils.get_stand_return(False, "Request must have keys: int page_size.")

# parse
text, retrieved_ids, start, end = request_data['text'], request_data['retrieved_ids'], request_data['start'], request_data['end']
text, user_id, page_no, page_size = request_data['text'], int(request_data['user_id']), int(request_data['page_no']), int(request_data['page_size'])

# error handling
if page_no <=0 or page_size <=0 :
return utils.get_stand_return(False, "Request page_no and page_size must >= 1.")

# encode text
text_vector = base_function.text_encoder_api(model, vocab_word, text)

# vector comparison
# start vector comparison
logger.info("Parse request correct, start cross-modal retrieval ...")
time_start = time.time()

retrieval_results = {}
# collect matching data
rsd = globalvar.get_value("rsd")
if retrieved_ids == "*": # retrieve all images
for k in rsd.keys():
retrieval_results[k] = base_function.cosine_sim_api(text_vector, rsd[k])
else:
for k in retrieved_ids: # retrieve specified images
retrieval_results[k] = base_function.cosine_sim_api(text_vector, rsd[k])
sorted_keys = utils.sort_based_values(retrieval_results)[start:end] # sort
rsd_retrieved, retrieval_results = {}, {}
for k,v in rsd.items():
if (rsd[k]["privilege"] == 1) or (rsd[k]["user_id"] == user_id):
rsd_retrieved[k] = v

# compute similarities
for k in rsd_retrieved.keys():
retrieval_results[k] = base_function.cosine_sim_api(text_vector, rsd[k]["image_vector"])

# sort
start, end = page_size * (page_no-1), page_size * page_no
sorted_keys = utils.sort_based_values(retrieval_results)[start:end]

time_end = time.time()
logger.info("Retrieval finished in {:.4f}s.".format(time_end - time_start))
5 changes: 3 additions & 2 deletions code/common/config.yaml
@@ -1,7 +1,7 @@
apis:
hosts:
ip: '192.168.43.216'
port: 49205
ip: '192.168.140.241'
port: 33133
image_encode:
route: '/api/image_encode/'
delete_encode:
@@ -14,6 +14,7 @@ apis:
route: '/api/semantic_localization/'
data_paths:
images_dir: '../data/test_data'
rsd_dir_path: '../data/retrieval_system_data/rsd'
rsd_path: '../data/retrieval_system_data/rsd/rsd.pkl'
semantic_localization_path: '../data/retrieval_system_data/semantic_localization_data'
temp_path: '../data/retrieval_system_data/tmp'
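For illustration (not part of this diff): the config gains rsd_dir_path (the directory holding the encode pool) alongside the existing rsd_path pickle, and main.py below switches to initializing the pool from the directory. A minimal sketch of reading these keys, assuming the project loads config.yaml into cfg with PyYAML:

import yaml

with open("common/config.yaml") as f:  # path relative to code/, as in this repo layout
    cfg = yaml.safe_load(f)

cfg["data_paths"]["rsd_dir_path"]  # '../data/retrieval_system_data/rsd'         (directory, created at startup)
cfg["data_paths"]["rsd_path"]      # '../data/retrieval_system_data/rsd/rsd.pkl' (pickled encode pool)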
6 changes: 3 additions & 3 deletions code/main.py
@@ -18,8 +18,8 @@
# create initial variables
logger.info("Create init variables")
globalvar.set_value("unembeded_images", value={})
globalvar.set_value("rsd", value=utils.init_rsd(cfg['data_paths']['rsd_dir_path']))
utils.create_dirs(cfg['data_paths']['rsd_path'])
globalvar.set_value("rsd", value=utils.init_rsd(cfg['data_paths']['rsd_path']))
utils.create_dirs(cfg['data_paths']['rsd_dir_path'])
utils.create_dirs(cfg['data_paths']['semantic_localization_path'])
utils.create_dirs(cfg['data_paths']['temp_path'])

@@ -36,4 +36,4 @@
# start the API server
from api_controlers import apis
logger.info("Start apis and running ...\n")
apis.api_run(cfg)
apis.api_run(cfg)
Binary file added code/models/__pycache__/__init__.cpython-37.pyc
Binary file added code/models/__pycache__/encoder.cpython-37.pyc
Binary file added code/models/__pycache__/init.cpython-37.pyc
Binary file added code/models/__pycache__/vocab.cpython-37.pyc
Binary file added code/models/layers/__pycache__/mca.cpython-37.pyc
Binary file modified data/retrieval_system_data/rsd/rsd.pkl