diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index ff7f0ae135a..c8086aa4209 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -29,6 +29,7 @@ from api.utils.api_utils import get_json_result from api import settings from rag.nlp import search +from api.constants import DATASET_NAME_LIMIT @manager.route('/create', methods=['post']) @@ -36,10 +37,19 @@ @validate_request("name") def create(): req = request.json - req["name"] = req["name"].strip() - req["name"] = duplicate_name( + dataset_name = req["name"] + if not isinstance(dataset_name, str): + return get_data_error_result(message="Dataset name must be a string.") + if dataset_name == "": + return get_data_error_result(message="Dataset name can't be empty.") + if len(dataset_name) >= DATASET_NAME_LIMIT: + return get_data_error_result( + message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}") + + dataset_name = dataset_name.strip() + dataset_name = duplicate_name( KnowledgebaseService.query, - name=req["name"], + name=dataset_name, tenant_id=current_user.id, status=StatusEnum.VALID.value) try: @@ -73,7 +83,8 @@ def update(): if not KnowledgebaseService.query( created_by=current_user.id, id=req["kb_id"]): return get_json_result( - data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR) + data=False, message='Only owner of knowledgebase authorized for this operation.', + code=settings.RetCode.OPERATING_ERROR) e, kb = KnowledgebaseService.get_by_id(req["kb_id"]) if not e: @@ -81,7 +92,8 @@ def update(): message="Can't find this knowledgebase!") if req["name"].lower() != kb.name.lower() \ - and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1: + and len( + KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1: return get_data_error_result( message="Duplicated knowledgebase name.") @@ 
-152,10 +164,11 @@ def rm(): ) try: kbs = KnowledgebaseService.query( - created_by=current_user.id, id=req["kb_id"]) + created_by=current_user.id, id=req["kb_id"]) if not kbs: return get_json_result( - data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR) + data=False, message='Only owner of knowledgebase authorized for this operation.', + code=settings.RetCode.OPERATING_ERROR) for doc in DocumentService.query(kb_id=req["kb_id"]): if not DocumentService.remove_document(doc, kbs[0].tenant_id): diff --git a/api/constants.py b/api/constants.py index 8d72c7e85aa..e6a97e2c1b1 100644 --- a/api/constants.py +++ b/api/constants.py @@ -23,3 +23,5 @@ RAG_FLOW_SERVICE_NAME = "ragflow" REQUEST_WAIT_SEC = 2 REQUEST_MAX_WAIT_SEC = 300 + +DATASET_NAME_LIMIT = 128 diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index 699e279e3b2..6f508e8bdde 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -310,7 +310,9 @@ def get( table_name = f"{indexName}_{knowledgebaseId}" table_instance = db_instance.get_table(table_name) kb_res = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl() - df_list.append(kb_res) + if len(kb_res) != 0 and kb_res.shape[0] > 0: + df_list.append(kb_res) + self.connPool.release_conn(inf_conn) res = concat_dataframes(df_list, ["id"]) res_fields = self.getFields(res, res.columns) diff --git a/sdk/python/test/test_frontend_api/common.py b/sdk/python/test/test_frontend_api/common.py index 4e44812b256..aa6e258e089 100644 --- a/sdk/python/test/test_frontend_api/common.py +++ b/sdk/python/test/test_frontend_api/common.py @@ -3,6 +3,8 @@ HOST_ADDRESS = os.getenv('HOST_ADDRESS', 'http://127.0.0.1:9380') +DATASET_NAME_LIMIT = 128 + def create_dataset(auth, dataset_name): authorization = {"Authorization": auth} url = f"{HOST_ADDRESS}/v1/kb/create" @@ -24,3 +26,9 @@ def rm_dataset(auth, dataset_id): json = {"kb_id": dataset_id} res = requests.post(url=url, 
headers=authorization, json=json) return res.json() + +def update_dataset(auth, json_req): + authorization = {"Authorization": auth} + url = f"{HOST_ADDRESS}/v1/kb/update" + res = requests.post(url=url, headers=authorization, json=json_req) + return res.json() diff --git a/sdk/python/test/test_frontend_api/test_dataset.py b/sdk/python/test/test_frontend_api/test_dataset.py index c78d8e0df07..c6e62fc2a7a 100644 --- a/sdk/python/test/test_frontend_api/test_dataset.py +++ b/sdk/python/test/test_frontend_api/test_dataset.py @@ -1,6 +1,8 @@ -from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset -import requests - +from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, DATASET_NAME_LIMIT +import re +import pytest +import random +import string def test_dataset(get_auth): # create dataset @@ -56,8 +58,76 @@ def test_dataset_1k_dataset(get_auth): assert res.get("code") == 0, f"{res.get('message')}" print(f"{len(dataset_list)} datasets are deleted") -# delete dataset -# create invalid name dataset +def test_duplicated_name_dataset(get_auth): + # create dataset + for i in range(20): + res = create_dataset(get_auth, "test_create_dataset") + assert res.get("code") == 0, f"{res.get('message')}" + + # list dataset + res = list_dataset(get_auth, 1) + data = res.get("data") + dataset_list = [] + pattern = r'^test_create_dataset.*' + for item in data: + dataset_name = item.get("name") + dataset_id = item.get("id") + dataset_list.append(dataset_id) + match = re.match(pattern, dataset_name) + assert match != None + + for dataset_id in dataset_list: + res = rm_dataset(get_auth, dataset_id) + assert res.get("code") == 0, f"{res.get('message')}" + print(f"{len(dataset_list)} datasets are deleted") + +def test_invalid_name_dataset(get_auth): + # create dataset + # with pytest.raises(Exception) as e: + res = create_dataset(get_auth, 0) + assert res['code'] == 102 + + res = create_dataset(get_auth, "") + assert res['code'] == 102 + + 
long_string = "" + + while len(long_string) <= DATASET_NAME_LIMIT: + long_string += random.choice(string.ascii_letters + string.digits) + + res = create_dataset(get_auth, long_string) + assert res['code'] == 102 + print(res) + +def test_update_different_params_dataset(get_auth): + # create dataset + res = create_dataset(get_auth, "test_create_dataset") + assert res.get("code") == 0, f"{res.get('message')}" + + # list dataset + page_number = 1 + dataset_list = [] + while True: + res = list_dataset(get_auth, page_number) + data = res.get("data") + for item in data: + dataset_id = item.get("id") + dataset_list.append(dataset_id) + if len(dataset_list) < page_number * 150: + break + page_number += 1 + + print(f"found {len(dataset_list)} datasets") + dataset_id = dataset_list[0] + + json_req = {"kb_id": dataset_id, "name": "test_update_dataset", "description": "test", "permission": "me", "parser_id": "presentation"} + res = update_dataset(get_auth, json_req) + assert res.get("code") == 0, f"{res.get('message')}" + + # delete dataset + for dataset_id in dataset_list: + res = rm_dataset(get_auth, dataset_id) + assert res.get("code") == 0, f"{res.get('message')}" + print(f"{len(dataset_list)} datasets are deleted") + # update dataset with different parameters -# create duplicated name dataset -# diff --git a/printEnvironment.sh b/show_env.sh similarity index 95% rename from printEnvironment.sh rename to show_env.sh index 28bf3db6f3d..83c47635cbf 100644 --- a/printEnvironment.sh +++ b/show_env.sh @@ -15,7 +15,7 @@ get_distro_info() { echo "$distro_id $distro_version (Kernel version: $kernel_version)" } -# get Git repo name +# get Git repository name git_repo_name='' if git rev-parse --is-inside-work-tree > /dev/null 2>&1; then git_repo_name=$(basename "$(git rev-parse --show-toplevel)") @@ -48,8 +48,8 @@ else python_version="Python not installed" fi -# Print all infomation -echo "Current Repo: $git_repo_name" +# Print all information +echo "Current Repository: 
$git_repo_name" # get Commit ID git_version=$(git log -1 --pretty=format:'%h')