diff --git a/api/apps/__init__.py b/api/apps/__init__.py
index 4fdeb9630b1..3d02b847bdc 100644
--- a/api/apps/__init__.py
+++ b/api/apps/__init__.py
@@ -18,40 +18,68 @@
 import sys
 from importlib.util import module_from_spec, spec_from_file_location
 from pathlib import Path
 
-from flask import Blueprint, Flask
-from werkzeug.wrappers.request import Request
+from typing import Union
+
+from apiflask import APIFlask, APIBlueprint, HTTPTokenAuth
 from flask_cors import CORS
+from flask_login import LoginManager
+from flask_session import Session
+from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
+from werkzeug.wrappers.request import Request
 
 from api.db import StatusEnum
-from api.db.db_models import close_connection
+from api.db.db_models import close_connection, APIToken
 from api.db.services import UserService
-from api.utils import CustomJSONEncoder, commands
-
-from flask_session import Session
-from flask_login import LoginManager
+from api.settings import API_VERSION, access_logger, RAG_FLOW_SERVICE_NAME
 from api.settings import SECRET_KEY, stat_logger
-from api.settings import API_VERSION, access_logger
+from api.utils import CustomJSONEncoder, commands
 from api.utils.api_utils import server_error_response
-from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
 
 __all__ = ['app']
 
-
 logger = logging.getLogger('flask.app')
 for h in access_logger.handlers:
     logger.addHandler(h)
 
 Request.json = property(lambda self: self.get_json(force=True, silent=True))
 
-app = Flask(__name__)
-CORS(app, supports_credentials=True,max_age=2592000)
+# Integrate APIFlask: Flask class -> APIFlask class.
+app = APIFlask(__name__, title=RAG_FLOW_SERVICE_NAME, version=API_VERSION, docs_path=f'/{API_VERSION}/docs')
+# Integrate APIFlask: use apiflask.HTTPTokenAuth for HTTP Bearer / API-key authentication.
+http_token_auth = HTTPTokenAuth()
+
+
+# The currently logged-in user, as resolved from an API token.
+class AuthUser:
+    def __init__(self, tenant_id, token):
+        self.id = tenant_id
+        self.token = token
+
+    def get_token(self):
+        return self.token
+
+
+# Verify the token; returning None makes APIFlask reject the request with a 401.
+@http_token_auth.verify_token
+def verify_token(token: str) -> Union[AuthUser, None]:
+    try:
+        objs = APIToken.query(token=token)
+        if objs:
+            api_token = objs[0]
+            user = AuthUser(api_token.tenant_id, api_token.token)
+            return user
+    except Exception as e:
+        server_error_response(e)
+    return None
+
+
+CORS(app, supports_credentials=True, max_age=2592000)
 app.url_map.strict_slashes = False
 app.json_encoder = CustomJSONEncoder
 app.errorhandler(Exception)(server_error_response)
-
 ## convince for dev and debug
-#app.config["LOGIN_DISABLED"] = True
+# app.config["LOGIN_DISABLED"] = True
 app.config["SESSION_PERMANENT"] = False
 app.config["SESSION_TYPE"] = "filesystem"
 app.config['MAX_CONTENT_LENGTH'] = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
 
@@ -66,7 +94,9 @@ def search_pages_path(pages_dir):
     app_path_list = [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
     api_path_list = [path for path in pages_dir.glob('*sdk/*.py') if not path.name.startswith('.')]
+    restful_api_path_list = [path for path in pages_dir.glob('*apis/*.py') if not path.name.startswith('.')]
     app_path_list.extend(api_path_list)
+    app_path_list.extend(restful_api_path_list)
     return app_path_list
 
@@ -79,11 +109,15 @@ def register_page(page_path):
     spec = spec_from_file_location(module_name, page_path)
     page = module_from_spec(spec)
     page.app = app
-    page.manager = Blueprint(page_name, module_name)
+    # Integrate APIFlask: Blueprint class -> APIBlueprint class
+    page.manager = APIBlueprint(page_name, module_name)
     sys.modules[module_name] = page
     spec.loader.exec_module(page)
     page_name = getattr(page, 'page_name', page_name)
-    url_prefix = f'/api/{API_VERSION}/{page_name}' if "/sdk/" in path else f'/{API_VERSION}/{page_name}'
+    if "/sdk/" in path or "/apis/" in path:
+        url_prefix = f'/api/{API_VERSION}/{page_name}'
+    else:
+        url_prefix = f'/{API_VERSION}/{page_name}'
     app.register_blueprint(page.manager, url_prefix=url_prefix)
     return url_prefix
 
@@ -93,6 +127,7 @@
     Path(__file__).parent,
     Path(__file__).parent.parent / 'api' / 'apps',
     Path(__file__).parent.parent / 'api' / 'apps' / 'sdk',
+    Path(__file__).parent.parent / 'api' / 'apps' / 'apis',
 ]
 
 client_urls_prefix = [
@@ -123,4 +158,4 @@ def load_user(web_request):
 
 @app.teardown_request
 def _db_close(exc):
-    close_connection()
\ No newline at end of file
+    close_connection()
diff --git a/api/apps/apis/__init__.py b/api/apps/apis/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
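
With this in place, every module under api/apps/apis/ is loaded as an APIBlueprint, mounted beneath /api/<version>/, and can guard its views with http_token_auth. A minimal client sketch (host, port, and the token value are assumptions; the token must exist in the APIToken table):

    import os
    import requests

    BASE_URL = "http://localhost:9380"           # assumed RAGFlow address
    API_KEY = os.environ["RAGFLOW_API_KEY"]      # a token from the APIToken table

    # apiflask's HTTPTokenAuth defaults to the Bearer scheme, so verify_token()
    # above receives the part after "Bearer " and maps it to an AuthUser.
    resp = requests.get(
        f"{BASE_URL}/api/v1/datasets",
        headers={"Authorization": f"Bearer {API_KEY}"},
        params={"page": 1, "page_size": 10},
    )
    print(resp.status_code, resp.json())
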
diff --git a/api/apps/apis/datasets.py b/api/apps/apis/datasets.py
new file mode 100644
index 00000000000..73c71af981c
--- /dev/null
+++ b/api/apps/apis/datasets.py
@@ -0,0 +1,96 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from api.apps import http_token_auth
+from api.apps.services import dataset_service
+from api.utils.api_utils import server_error_response, http_basic_auth_required
+
+
+@manager.post('')
+@manager.input(dataset_service.CreateDatasetReq, location='json')
+@manager.auth_required(http_token_auth)
+def create_dataset(json_data):
+    """Creates a new Dataset(Knowledgebase)."""
+    try:
+        tenant_id = http_token_auth.current_user.id
+        return dataset_service.create_dataset(tenant_id, json_data)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.put('')
+@manager.input(dataset_service.UpdateDatasetReq, location='json')
+@manager.auth_required(http_token_auth)
+def update_dataset(json_data):
+    """Updates a Dataset(Knowledgebase)."""
+    try:
+        tenant_id = http_token_auth.current_user.id
+        return dataset_service.update_dataset(tenant_id, json_data)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.get('/<kb_id>')
+@manager.auth_required(http_token_auth)
+def get_dataset_by_id(kb_id):
+    """Queries a Dataset(Knowledgebase) by its ID."""
+    try:
+        tenant_id = http_token_auth.current_user.id
+        return dataset_service.get_dataset_by_id(tenant_id, kb_id)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.get('/search')
+@manager.input(dataset_service.SearchDatasetReq, location='query')
+@manager.auth_required(http_token_auth)
+def get_dataset_by_name(query_data):
+    """Queries a Dataset(Knowledgebase) by its name."""
+    try:
+        tenant_id = http_token_auth.current_user.id
+        return dataset_service.get_dataset_by_name(tenant_id, query_data["name"])
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.get('')
+@manager.input(dataset_service.QueryDatasetReq, location='query')
+@http_basic_auth_required
+@manager.auth_required(http_token_auth)
+def get_all_datasets(query_data):
+    """Queries all Datasets(Knowledgebases)."""
+    try:
+        tenant_id = http_token_auth.current_user.id
+        return dataset_service.get_all_datasets(
+            tenant_id,
+            query_data['page'],
+            query_data['page_size'],
+            query_data['orderby'],
+            query_data['desc'],
+        )
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.delete('/<kb_id>')
+@manager.auth_required(http_token_auth)
+def delete_dataset(kb_id):
+    """Deletes a Dataset(Knowledgebase)."""
+    try:
+        tenant_id = http_token_auth.current_user.id
+        return dataset_service.delete_dataset(tenant_id, kb_id)
+    except Exception as e:
+        return server_error_response(e)
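
An end-to-end sketch of these routes (address, token, and names are placeholders; responses follow get_json_result's {retcode, retmsg, data} envelope):

    import requests

    BASE = "http://localhost:9380/api/v1/datasets"       # assumed deployment address
    HEADERS = {"Authorization": "Bearer <api-token>"}    # token from the APIToken table

    # Create: body is validated against CreateDatasetReq, returns {"data": {"kb_id": ...}}
    kb_id = requests.post(BASE, json={"name": "demo_kb"}, headers=HEADERS).json()["data"]["kb_id"]

    # Fetch by id, then look it up by name.
    print(requests.get(f"{BASE}/{kb_id}", headers=HEADERS).json())
    print(requests.get(f"{BASE}/search", params={"name": "demo_kb"}, headers=HEADERS).json())

    # Rename via UpdateDatasetReq, then delete.
    requests.put(BASE, json={"kb_id": kb_id, "name": "demo_kb_v2"}, headers=HEADERS)
    requests.delete(f"{BASE}/{kb_id}", headers=HEADERS)
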
diff --git a/api/apps/apis/documents.py b/api/apps/apis/documents.py
new file mode 100644
index 00000000000..312a4cdff9d
--- /dev/null
+++ b/api/apps/apis/documents.py
@@ -0,0 +1,64 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from api.apps import http_token_auth
+from api.apps.services import document_service
+from api.utils.api_utils import server_error_response
+
+
+@manager.post('/change_parser')
+@manager.input(document_service.ChangeDocumentParserReq, location='json')
+@manager.auth_required(http_token_auth)
+def change_document_parser(json_data):
+    """Changes the file parser of a document."""
+    try:
+        return document_service.change_document_parser(json_data)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.post('/run')
+@manager.input(document_service.RunParsingReq, location='json')
+@manager.auth_required(http_token_auth)
+def run_parsing(json_data):
+    """Triggers parsing for the given documents."""
+    try:
+        return document_service.run_parsing(json_data)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.post('/upload')
+@manager.input(document_service.UploadDocumentsReq, location='form_and_files')
+@manager.auth_required(http_token_auth)
+def upload_documents_2_dataset(form_and_files_data):
+    """Uploads document files to a Dataset(Knowledgebase)."""
+    try:
+        tenant_id = http_token_auth.current_user.id
+        return document_service.upload_documents_2_dataset(form_and_files_data, tenant_id)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.get('')
+@manager.input(document_service.QueryDocumentsReq, location='query')
+@manager.auth_required(http_token_auth)
+def get_all_documents(query_data):
+    """Queries the document files in a Dataset(Knowledgebase)."""
+    try:
+        tenant_id = http_token_auth.current_user.id
+        return document_service.get_all_documents(query_data, tenant_id)
+    except Exception as e:
+        return server_error_response(e)
diff --git a/api/apps/services/__init__.py b/api/apps/services/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
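
A sketch of driving these endpoints directly (address, token, and IDs are placeholders; the upload is a multipart form matching UploadDocumentsReq, with kb_id as an ordinary form field):

    import requests

    BASE = "http://localhost:9380/api/v1/documents"      # assumed deployment address
    HEADERS = {"Authorization": "Bearer <api-token>"}

    # Multipart upload: one or more parts named 'file' plus the kb_id form field.
    with open("paper.pdf", "rb") as f:
        up = requests.post(f"{BASE}/upload", data={"kb_id": "<kb-id>"},
                           files=[("file", ("paper.pdf", f))], headers=HEADERS)
    print(up.json())

    # List the dataset's documents, then queue them all for parsing (run=1).
    docs = requests.get(BASE, params={"kb_id": "<kb-id>"}, headers=HEADERS).json()["data"]["docs"]
    requests.post(f"{BASE}/run",
                  json={"doc_ids": [d["id"] for d in docs], "run": 1},
                  headers=HEADERS)
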
diff --git a/api/apps/services/dataset_service.py b/api/apps/services/dataset_service.py
new file mode 100644
index 00000000000..d6c53cef5b0
--- /dev/null
+++ b/api/apps/services/dataset_service.py
@@ -0,0 +1,161 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from apiflask import Schema, fields, validators
+
+from api.db import StatusEnum, FileSource, ParserType
+from api.db.db_models import File
+from api.db.services import duplicate_name
+from api.db.services.document_service import DocumentService
+from api.db.services.file2document_service import File2DocumentService
+from api.db.services.file_service import FileService
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.user_service import TenantService
+from api.settings import RetCode
+from api.utils import get_uuid
+from api.utils.api_utils import get_json_result, get_data_error_result
+
+
+class QueryDatasetReq(Schema):
+    page = fields.Integer(load_default=1)
+    page_size = fields.Integer(load_default=150)
+    orderby = fields.String(load_default='create_time')
+    desc = fields.Boolean(load_default=True)
+
+
+class SearchDatasetReq(Schema):
+    name = fields.String(required=True)
+
+
+class CreateDatasetReq(Schema):
+    name = fields.String(required=True)
+
+
+class UpdateDatasetReq(Schema):
+    kb_id = fields.String(required=True)
+    name = fields.String(validate=validators.Length(min=1, max=128))
+    description = fields.String(allow_none=True)
+    permission = fields.String(validate=validators.OneOf(['me', 'team']))
+    embd_id = fields.String(validate=validators.Length(min=1, max=128))
+    language = fields.String(validate=validators.OneOf(['Chinese', 'English']))
+    parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]))
+    parser_config = fields.Dict()
+    avatar = fields.String()
+
+
+def get_all_datasets(user_id, offset, count, orderby, desc):
+    tenants = TenantService.get_joined_tenants_by_user_id(user_id)
+    datasets = KnowledgebaseService.get_by_tenant_ids_by_offset(
+        [m["tenant_id"] for m in tenants], user_id, int(offset), int(count), orderby, desc)
+    return get_json_result(data=datasets)
+
+
+def get_tenant_dataset_by_id(tenant_id, kb_id):
+    kbs = KnowledgebaseService.query(tenant_id=tenant_id, id=kb_id)
+    if not kbs:
+        return get_data_error_result(retmsg="Can't find this knowledgebase!")
+    return get_json_result(data=kbs[0].to_dict())
+
+
+def get_dataset_by_id(tenant_id, kb_id):
+    kbs = KnowledgebaseService.query(created_by=tenant_id, id=kb_id)
+    if not kbs:
+        return get_data_error_result(retmsg="Can't find this knowledgebase!")
+    return get_json_result(data=kbs[0].to_dict())
+
+
+def get_dataset_by_name(tenant_id, kb_name):
+    e, kb = KnowledgebaseService.get_by_name(kb_name=kb_name, tenant_id=tenant_id)
+    if not e:
+        return get_json_result(
+            data=False, retmsg='You do not own the dataset.',
+            retcode=RetCode.OPERATING_ERROR)
+    return get_json_result(data=kb.to_dict())
+
+
+def create_dataset(tenant_id, data):
+    kb_name = data["name"].strip()
+    kb_name = duplicate_name(
+        KnowledgebaseService.query,
+        name=kb_name,
+        tenant_id=tenant_id,
+        status=StatusEnum.VALID.value
+    )
+    e, t = TenantService.get_by_id(tenant_id)
+    if not e:
+        return get_data_error_result(retmsg="Tenant not found.")
+    kb = {
+        "id": get_uuid(),
+        "name": kb_name,
+        "tenant_id": tenant_id,
+        "created_by": tenant_id,
+        "embd_id": t.embd_id,
+    }
+    if not KnowledgebaseService.save(**kb):
+        return get_data_error_result()
+    return get_json_result(data={"kb_id": kb["id"]})
+
+
+def update_dataset(tenant_id, data):
+    # "name" is optional in UpdateDatasetReq, so don't assume it is present.
+    kb_name = data.get("name", "").strip()
+    kb_id = data["kb_id"].strip()
+    if not KnowledgebaseService.query(
+            created_by=tenant_id, id=kb_id):
+        return get_json_result(
+            data=False, retmsg='Only the owner of the knowledgebase is authorized for this operation.',
+            retcode=RetCode.OPERATING_ERROR)
+
+    e, kb = KnowledgebaseService.get_by_id(kb_id)
+    if not e:
+        return get_data_error_result(
+            retmsg="Can't find this knowledgebase!")
+
+    if kb_name and kb_name.lower() != kb.name.lower() and len(
+            KnowledgebaseService.query(name=kb_name, tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 1:
+        return get_data_error_result(
+            retmsg="Duplicated knowledgebase name.")
+
+    del data["kb_id"]
+    if not KnowledgebaseService.update_by_id(kb.id, data):
+        return get_data_error_result()
+
+    e, kb = KnowledgebaseService.get_by_id(kb.id)
+    if not e:
+        return get_data_error_result(
+            retmsg="Database error (Knowledgebase rename)!")
+
+    return get_json_result(data=kb.to_json())
+
+
+def delete_dataset(tenant_id, kb_id):
+    kbs = KnowledgebaseService.query(created_by=tenant_id, id=kb_id)
+    if not kbs:
+        return get_json_result(
+            data=False, retmsg='Only the owner of the knowledgebase is authorized for this operation.',
+            retcode=RetCode.OPERATING_ERROR)
+
+    for doc in DocumentService.query(kb_id=kb_id):
+        if not DocumentService.remove_document(doc, kbs[0].tenant_id):
+            return get_data_error_result(
+                retmsg="Database error (Document removal)!")
+        f2d = File2DocumentService.get_by_document_id(doc.id)
+        FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
+        File2DocumentService.delete_by_document_id(doc.id)
+
+    if not KnowledgebaseService.delete_by_id(kb_id):
+        return get_data_error_result(
+            retmsg="Database error (Knowledgebase removal)!")
+    return get_json_result(data=True)
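
These apiflask schemas are plain marshmallow schemas underneath, so their validation can be exercised without a running server. A quick standalone sketch (the schema copy is trimmed for brevity):

    from apiflask import Schema, fields, validators
    from marshmallow import ValidationError

    class UpdateDatasetReq(Schema):          # trimmed copy of the schema above
        kb_id = fields.String(required=True)
        permission = fields.String(validate=validators.OneOf(['me', 'team']))

    try:
        # An unknown 'permission' value fails before the view ever runs.
        UpdateDatasetReq().load({"kb_id": "abc123", "permission": "everyone"})
    except ValidationError as err:
        print(err.messages)   # {'permission': ['Must be one of: me, team.']}
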
diff --git a/api/apps/services/document_service.py b/api/apps/services/document_service.py
new file mode 100644
index 00000000000..12be2599d03
--- /dev/null
+++ b/api/apps/services/document_service.py
@@ -0,0 +1,161 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import re
+
+from apiflask import Schema, fields, validators
+from elasticsearch_dsl import Q
+
+from api.db import FileType, TaskStatus, ParserType
+from api.db.db_models import Task
+from api.db.services.document_service import DocumentService
+from api.db.services.file2document_service import File2DocumentService
+from api.db.services.file_service import FileService
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.task_service import TaskService, queue_tasks
+from api.db.services.user_service import UserTenantService
+from api.settings import RetCode
+from api.utils.api_utils import get_data_error_result
+from api.utils.api_utils import get_json_result
+from rag.nlp import search
+from rag.utils.es_conn import ELASTICSEARCH
+
+
+class QueryDocumentsReq(Schema):
+    kb_id = fields.String(required=True, error_messages={'required': 'Invalid kb_id parameter!'})
+    keywords = fields.String(load_default='')
+    page = fields.Integer(load_default=1)
+    page_size = fields.Integer(load_default=150)
+    orderby = fields.String(load_default='create_time')
+    desc = fields.Boolean(load_default=True)
+
+
+class ChangeDocumentParserReq(Schema):
+    doc_id = fields.String(required=True)
+    parser_id = fields.String(
+        required=True, validate=validators.OneOf([parser_type.value for parser_type in ParserType])
+    )
+    parser_config = fields.Dict()
+
+
+class RunParsingReq(Schema):
+    doc_ids = fields.List(fields.String(), required=True)  # List needs an inner field type
+    run = fields.Integer(load_default=1)
+
+
+class UploadDocumentsReq(Schema):
+    kb_id = fields.String(required=True)
+    file = fields.File(required=True)
+
+
+def get_all_documents(query_data, tenant_id):
+    kb_id = query_data["kb_id"]
+    tenants = UserTenantService.query(user_id=tenant_id)
+    for tenant in tenants:
+        if KnowledgebaseService.query(
+                tenant_id=tenant.tenant_id, id=kb_id):
+            break
+    else:
+        return get_json_result(
+            data=False, retmsg='Only the owner of the knowledgebase is authorized for this operation.',
+            retcode=RetCode.OPERATING_ERROR)
+    keywords = query_data["keywords"]
+
+    page_number = query_data["page"]
+    items_per_page = query_data["page_size"]
+    orderby = query_data["orderby"]
+    desc = query_data["desc"]
+    docs, tol = DocumentService.get_by_kb_id(
+        kb_id, page_number, items_per_page, orderby, desc, keywords)
+    return get_json_result(data={"total": tol, "docs": docs})
+
+
+def upload_documents_2_dataset(form_and_files_data, tenant_id):
+    file_objs = form_and_files_data['file']
+    dataset_id = form_and_files_data['kb_id']
+    for file_obj in file_objs:
+        if file_obj.filename == '':
+            return get_json_result(
+                data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
+    e, kb = KnowledgebaseService.get_by_id(dataset_id)
+    if not e:
+        raise LookupError(f"Can't find the knowledgebase with ID {dataset_id}!")
+    err, _ = FileService.upload_document(kb, file_objs, tenant_id)
+    if err:
+        return get_json_result(
+            data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
+    return get_json_result(data=True)
+
+
+def change_document_parser(json_data):
+    e, doc = DocumentService.get_by_id(json_data["doc_id"])
+    if not e:
+        return get_data_error_result(retmsg="Document not found!")
+    if doc.parser_id.lower() == json_data["parser_id"].lower():
+        if "parser_config" in json_data:
+            if json_data["parser_config"] == doc.parser_config:
+                return get_json_result(data=True)
+        else:
+            return get_json_result(data=True)
+
+    if doc.type == FileType.VISUAL or re.search(
+            r"\.(ppt|pptx|pages)$", doc.name):
+        return get_data_error_result(retmsg="Not supported yet!")
+
+    e = DocumentService.update_by_id(doc.id,
+                                     {"parser_id": json_data["parser_id"], "progress": 0, "progress_msg": "",
+                                      "run": TaskStatus.UNSTART.value})
+    if not e:
+        return get_data_error_result(retmsg="Document not found!")
+    if "parser_config" in json_data:
+        DocumentService.update_parser_config(doc.id, json_data["parser_config"])
+    if doc.token_num > 0:
+        e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
+                                                doc.process_duation * -1)
+        if not e:
+            return get_data_error_result(retmsg="Document not found!")
+        tenant_id = DocumentService.get_tenant_id(json_data["doc_id"])
+        if not tenant_id:
+            return get_data_error_result(retmsg="Tenant not found!")
+        ELASTICSEARCH.deleteByQuery(
+            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
+
+    return get_json_result(data=True)
+
+
+def run_parsing(json_data):
+    for doc_id in json_data["doc_ids"]:
+        run = str(json_data["run"])
+        info = {"run": run, "progress": 0}
+        if run == TaskStatus.RUNNING.value:
+            info["progress_msg"] = ""
+            info["chunk_num"] = 0
+            info["token_num"] = 0
+        DocumentService.update_by_id(doc_id, info)
+        tenant_id = DocumentService.get_tenant_id(doc_id)
+        if not tenant_id:
+            return get_data_error_result(retmsg="Tenant not found!")
+        ELASTICSEARCH.deleteByQuery(
+            Q("match", doc_id=doc_id), idxnm=search.index_name(tenant_id))
+
+        if run == TaskStatus.RUNNING.value:
+            TaskService.filter_delete([Task.doc_id == doc_id])
+            e, doc = DocumentService.get_by_id(doc_id)
+            doc = doc.to_dict()
+            doc["tenant_id"] = tenant_id
+            bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
+            queue_tasks(doc, bucket, name)
+
+    return get_json_result(data=True)
diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py
index c5b93d56f0a..cbe0343b35a 100644
--- a/api/utils/api_utils.py
+++ b/api/utils/api_utils.py
@@ -27,8 +27,10 @@
 import requests
 from flask import (
     Response, jsonify, send_file, make_response,
-    request as flask_request,
+    request as flask_request, current_app,
 )
+from flask_login import current_user
+from flask_login.config import EXEMPT_METHODS
 from werkzeug.http import HTTP_STATUS_CODES
 
 from api.db.db_models import APIToken
@@ -288,3 +290,21 @@ def decorated_function(*args, **kwargs):
         return func(*args, **kwargs)
 
     return decorated_function
+
+
+def http_basic_auth_required(func):
+    @wraps(func)
+    def decorated_view(*args, **kwargs):
+        if 'Authorization' in flask_request.headers:
+            # A token is present, so skip the username/password (session) check
+            # and leave verification to the route's token auth.
+            return func(*args, **kwargs)
+        if flask_request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"):
+            pass
+        elif not current_user.is_authenticated:
+            return current_app.login_manager.unauthorized()
+
+        if callable(getattr(current_app, "ensure_sync", None)):
+            return current_app.ensure_sync(func)(*args, **kwargs)
+        return func(*args, **kwargs)
+
+    return decorated_view
diff --git a/requirements.txt b/requirements.txt
index 01e31520dad..70853a889bf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -102,3 +102,4 @@ xgboost==2.1.0
 xpinyin==0.7.6
 yfinance==0.1.96
 zhipuai==2.0.1
+apiflask==2.2.1
diff --git a/requirements_arm.txt b/requirements_arm.txt
index 4a2b35f7400..96839f3c7bc 100644
--- a/requirements_arm.txt
+++ b/requirements_arm.txt
@@ -173,3 +173,4 @@ yfinance==0.1.96
 pywencai==0.12.2
 akshare==1.14.72
 ranx==0.3.20
+apiflask==2.2.1
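
The decorator's contract: requests that already carry an Authorization header pass straight through (token verification is left to the route, as in get_all_datasets above), while header-less requests must hold an authenticated Flask-Login session. A sketch of standalone use (the route is hypothetical; `manager` is the APIBlueprint that register_page() injects into page modules):

    from api.utils.api_utils import get_json_result, http_basic_auth_required


    @manager.get('/ping')        # hypothetical route in a module under api/apps/apis/
    @http_basic_auth_required
    def ping():
        # Reached either with a valid Flask-Login session cookie or with any
        # Authorization header (whose token the view must then check itself).
        return get_json_result(data='pong')
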
diff --git a/sdk/python/ragflow/ragflow.py b/sdk/python/ragflow/ragflow.py
index d4fc6de3643..09ea56037fc 100644
--- a/sdk/python/ragflow/ragflow.py
+++ b/sdk/python/ragflow/ragflow.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
+from typing import List, Union
 
 import requests
 
@@ -23,6 +23,7 @@
 from .modules.chunk import Chunk
 
+
 class RAGFlow:
     def __init__(self, user_key, base_url, version='v1'):
         """
@@ -75,6 +76,75 @@ def list_datasets(self, page: int = 1, page_size: int = 1024, orderby: str = "cr
             return result_list
         raise Exception(res["retmsg"])
 
+    def get_all_datasets(
+            self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True
+    ) -> List[DataSet]:
+        res = self.get("/datasets",
+                       {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
+        res = res.json()
+        if res.get("retmsg") == "success":
+            return res['data']
+        raise Exception(res["retmsg"])
+
+    def get_dataset_by_name(self, name: str) -> List[DataSet]:
+        res = self.get("/datasets/search",
+                       {"name": name})
+        res = res.json()
+        if res.get("retmsg") == "success":
+            return res['data']
+        raise Exception(res["retmsg"])
+
+    def change_document_parser(self, doc_id: str, parser_id: str, parser_config: dict):
+        res = self.post(
+            "/documents/change_parser",
+            {
+                "doc_id": doc_id,
+                "parser_id": parser_id,
+                "parser_config": parser_config,
+            }
+        )
+        res = res.json()
+        if res.get("retmsg") == "success":
+            return res['data']
+        raise Exception(res["retmsg"])
+
+    def upload_documents_2_dataset(self, kb_id: str, files: Union[dict, List[bytes]]):
+        # Every part must be named 'file' so it matches UploadDocumentsReq on the server.
+        if isinstance(files, dict):
+            files_data = [('file', file) for file in files.values()]
+        elif isinstance(files, list):
+            files_data = [('file', file) for file in files]
+        else:
+            files_data = [('file', files)]
+        data = {
+            'kb_id': kb_id,
+        }
+        res = requests.post(url=self.api_url + "/documents/upload", data=data, files=files_data,
+                            headers=self.authorization_header)
+        res = res.json()
+        if res.get("retmsg") == "success":
+            return res['data']
+        raise Exception(res["retmsg"])
+
+    def documents_run_parsing(self, doc_ids: list):
+        res = self.post("/documents/run",
+                        {"doc_ids": doc_ids})
+        res = res.json()
+        if res.get("retmsg") == "success":
+            return res['data']
+        raise Exception(res["retmsg"])
+
+    def get_all_documents(
+            self, kb_id: str, keywords: str = '', page: int = 1, page_size: int = 1024,
+            orderby: str = "create_time", desc: bool = True):
+        # The server-side QueryDocumentsReq requires kb_id, so it must be sent along.
+        res = self.get("/documents",
+                       {"kb_id": kb_id, "keywords": keywords, "page": page, "page_size": page_size,
+                        "orderby": orderby, "desc": desc})
+        res = res.json()
+        if res.get("retmsg") == "success":
+            return res['data']
+        raise Exception(res["retmsg"])
+
     def get_dataset(self, id: str = None, name: str = None) -> DataSet:
         res = self.get("/dataset/detail", {"id": id, "name": name})
         res = res.json()
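
Putting the new SDK surface together (key and address are placeholders; assumes at least one dataset already exists):

    from ragflow.ragflow import RAGFlow

    client = RAGFlow("<api-token>", "http://localhost:9380")

    datasets = client.get_all_datasets(page=1, page_size=10)
    kb_id = datasets[0]['id']                  # each entry is a knowledgebase dict

    # Upload a file, list the dataset's documents, then queue them for parsing.
    with open('paper.pdf', 'rb') as f:
        client.upload_documents_2_dataset(kb_id, [f])
    docs = client.get_all_documents(kb_id)
    client.documents_run_parsing([d['id'] for d in docs['docs']])
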
""" ragflow = RAGFlow(API_KEY, HOST_ADDRESS) - listed_data = ragflow.list_dataset() - listed_data = listed_data['data'] + # listed_data = ragflow.list_datasets() + # listed_data = listed_data['data'] - listed_names = {d['name'] for d in listed_data} - for name in listed_names: - ragflow.delete_dataset(name) + # listed_names = {d['name'] for d in listed_data} + # for name in listed_names: + # print(f'--dataset-- {name}') + # ragflow.delete_dataset(name) # -----------------------create_dataset--------------------------------- def test_create_dataset_with_success(self): @@ -146,7 +147,7 @@ def test_list_dataset_success(self): """ ragflow = RAGFlow(API_KEY, HOST_ADDRESS) # Call the list_datasets method - response = ragflow.list_dataset() + response = ragflow.list_datasets() assert response['code'] == RetCode.SUCCESS def test_list_dataset_with_checking_size_and_name(self): @@ -163,7 +164,7 @@ def test_list_dataset_with_checking_size_and_name(self): dataset_name = response['data']['dataset_name'] real_name_to_create.add(dataset_name) - response = ragflow.list_dataset(0, 3) + response = ragflow.list_datasets(0, 3) listed_data = response['data'] listed_names = {d['name'] for d in listed_data} @@ -185,7 +186,7 @@ def test_list_dataset_with_getting_empty_result(self): dataset_name = response['data']['dataset_name'] real_name_to_create.add(dataset_name) - response = ragflow.list_dataset(0, 0) + response = ragflow.list_datasets(0, 0) listed_data = response['data'] listed_names = {d['name'] for d in listed_data} @@ -208,7 +209,7 @@ def test_list_dataset_with_creating_100_knowledge_bases(self): dataset_name = response['data']['dataset_name'] real_name_to_create.add(dataset_name) - res = ragflow.list_dataset(0, 100) + res = ragflow.list_datasets(0, 100) listed_data = res['data'] listed_names = {d['name'] for d in listed_data} @@ -221,7 +222,7 @@ def test_list_dataset_with_showing_one_dataset(self): Test listing one dataset and verify the size of the dataset. """ ragflow = RAGFlow(API_KEY, HOST_ADDRESS) - response = ragflow.list_dataset(0, 1) + response = ragflow.list_datasets(0, 1) datasets = response['data'] assert len(datasets) == 1 and response['code'] == RetCode.SUCCESS @@ -230,7 +231,7 @@ def test_list_dataset_failure(self): Test listing datasets with IndexError. """ ragflow = RAGFlow(API_KEY, HOST_ADDRESS) - response = ragflow.list_dataset(-1, -1) + response = ragflow.list_datasets(-1, -1) assert "IndexError" in response['message'] and response['code'] == RetCode.EXCEPTION_ERROR def test_list_dataset_for_empty_datasets(self): @@ -238,7 +239,7 @@ def test_list_dataset_for_empty_datasets(self): Test listing datasets when the datasets are empty. """ ragflow = RAGFlow(API_KEY, HOST_ADDRESS) - response = ragflow.list_dataset() + response = ragflow.list_datasets() datasets = response['data'] assert len(datasets) == 0 and response['code'] == RetCode.SUCCESS @@ -263,7 +264,8 @@ def test_delete_dataset_with_not_existing_dataset(self): """ ragflow = RAGFlow(API_KEY, HOST_ADDRESS) res = ragflow.delete_dataset("weird_dataset") - assert res['code'] == RetCode.OPERATING_ERROR and res['message'] == 'The dataset cannot be found for your current account.' + assert res['code'] == RetCode.OPERATING_ERROR and res[ + 'message'] == 'The dataset cannot be found for your current account.' 
 
     def test_delete_dataset_with_creating_100_datasets_and_deleting_100_datasets(self):
         """
@@ -346,7 +348,7 @@ def test_delete_dataset_with_name_with_space_in_the_head_and_tail_and_length_exc
         assert (res['code'] == RetCode.OPERATING_ERROR
                 and res['message'] == 'The dataset cannot be found for your current account.')
 
-# ---------------------------------get_dataset-----------------------------------------
+    # ---------------------------------get_dataset-----------------------------------------
 
     def test_get_dataset_with_success(self):
         """
@@ -366,7 +368,7 @@ def test_get_dataset_with_failure(self):
         res = ragflow.get_dataset("weird_dataset")
         assert res['code'] == RetCode.DATA_ERROR and res['message'] == "Can't find this dataset!"
 
-# ---------------------------------update a dataset-----------------------------------
+    # ---------------------------------update a dataset-----------------------------------
 
     def test_update_dataset_without_existing_dataset(self):
         """
@@ -435,7 +437,7 @@ def test_update_dataset_with_empty_parameter(self):
         assert (res['code'] == RetCode.DATA_ERROR
                 and res['message'] == 'Please input at least one parameter that you want to update!')
 
-# ---------------------------------mix the different methods--------------------------
+    # ---------------------------------mix the different methods--------------------------
 
     def test_create_and_delete_dataset_together(self):
         """
@@ -466,3 +468,11 @@ def test_create_and_delete_dataset_together(self):
             res = ragflow.delete_dataset(name)
             assert res["code"] == RetCode.SUCCESS
+
+    def test_get_all_datasets_success(self):
+        """
+        Test listing datasets via the new RESTful endpoint with a successful outcome.
+        """
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        # Call the get_all_datasets method
+        response = ragflow.get_all_datasets()
+        assert isinstance(response, list)
diff --git a/web/tsconfig.json b/web/tsconfig.json
index 824e6cc8df0..bc880a023f2 100644
--- a/web/tsconfig.json
+++ b/web/tsconfig.json
@@ -1,4 +1,4 @@
 {
   "extends": "./src/.umi/tsconfig.json",
-  "@@/*": ["src/.umi/*"],
+  "@@/*": ["src/.umi/*"]
 }
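
The renamed test (get_all_datasets goes through the token-authenticated endpoint, so the old name would have shadowed the existing session-based test_list_dataset_success) rides on the same API_KEY and HOST_ADDRESS constants as the rest of the suite; assuming those come from sdk/python/test/common.py as in the existing tests, it can be run in isolation from that directory:

    # Assumes a running RAGFlow instance reachable via HOST_ADDRESS.
    import pytest

    pytest.main(["test_dataset.py", "-k", "test_get_all_datasets_success", "-v"])
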