Skip to content

Commit

Permalink
complete implementation of dataset SDK (infiniflow#2147)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Complete implementation of dataset SDK.
infiniflow#1102

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Feiue <[email protected]>
Co-authored-by: Kevin Hu <[email protected]>
  • Loading branch information
3 people authored Aug 29, 2024
1 parent 7d44054 commit 3eec945
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 85 deletions.
162 changes: 118 additions & 44 deletions api/apps/sdk/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,82 +15,156 @@
#
from flask import request

from api.db import StatusEnum
from api.db.db_models import APIToken
from api.db import StatusEnum, FileSource
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result
from api.utils.api_utils import get_json_result
from api.utils.api_utils import get_json_result, token_required, get_data_error_result


@manager.route('/save', methods=['POST'])
def save():
@token_required
def save(tenant_id):
req = request.json
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
tenant_id = objs[0].tenant_id
e, t = TenantService.get_by_id(tenant_id)
if not e:
return get_data_error_result(retmsg="Tenant not found.")
if "id" not in req:
if "tenant_id" in req or "embd_id" in req:
return get_data_error_result(
retmsg="Tenant_id or embedding_model must not be provided")
if "name" not in req:
return get_data_error_result(
retmsg="Name is not empty!")
req['id'] = get_uuid()
req["name"] = req["name"].strip()
if req["name"] == "":
return get_data_error_result(
retmsg="Name is not empty")
if KnowledgebaseService.query(name=req["name"]):
retmsg="Name is not empty string!")
if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_data_error_result(
retmsg="Duplicated knowledgebase name")
retmsg="Duplicated knowledgebase name in creating dataset.")
req["tenant_id"] = tenant_id
req['created_by'] = tenant_id
req['embd_id'] = t.embd_id
if not KnowledgebaseService.save(**req):
return get_data_error_result(retmsg="Data saving error")
req.pop('created_by')
keys_to_rename = {'embd_id': "embedding_model", 'parser_id': 'parser_method',
'chunk_num': 'chunk_count', 'doc_num': 'document_count'}
for old_key,new_key in keys_to_rename.items():
if old_key in req:
req[new_key]=req.pop(old_key)
return get_data_error_result(retmsg="Create dataset error.(Database error)")
return get_json_result(data=req)
else:
if req["tenant_id"] != tenant_id or req["embd_id"] != t.embd_id:
return get_data_error_result(
retmsg="Can't change tenant_id or embedding_model")
if "tenant_id" in req:
if req["tenant_id"] != tenant_id:
return get_data_error_result(
retmsg="Can't change tenant_id.")

e, kb = KnowledgebaseService.get_by_id(req["id"])
if not e:
return get_data_error_result(
retmsg="Can't find this knowledgebase!")
if "embd_id" in req:
if req["embd_id"] != t.embd_id:
return get_data_error_result(
retmsg="Can't change embedding_model.")

if not KnowledgebaseService.query(
created_by=tenant_id, id=req["id"]):
return get_json_result(
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
data=False, retmsg='You do not own the dataset.',
retcode=RetCode.OPERATING_ERROR)

if req["chunk_num"] != kb.chunk_num or req['doc_num'] != kb.doc_num:
return get_data_error_result(
retmsg="Can't change document_count or chunk_count ")
e, kb = KnowledgebaseService.get_by_id(req["id"])

if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
return get_data_error_result(
retmsg="if chunk count is not 0, parser method is not changable. ")
if "chunk_num" in req:
if req["chunk_num"] != kb.chunk_num:
return get_data_error_result(
retmsg="Can't change chunk_count.")

if "doc_num" in req:
if req['doc_num'] != kb.doc_num:
return get_data_error_result(
retmsg="Can't change document_count.")

if req["name"].lower() != kb.name.lower() \
and len(KnowledgebaseService.query(name=req["name"], tenant_id=req['tenant_id'],
status=StatusEnum.VALID.value)) > 0:
return get_data_error_result(
retmsg="Duplicated knowledgebase name.")
if "parser_id" in req:
if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
return get_data_error_result(
retmsg="if chunk count is not 0, parse method is not changable.")
if "name" in req:
if req["name"].lower() != kb.name.lower() \
and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
status=StatusEnum.VALID.value)) > 0:
return get_data_error_result(
retmsg="Duplicated knowledgebase name in updating dataset.")

del req["id"]
req['created_by'] = tenant_id
if not KnowledgebaseService.update_by_id(kb.id, req):
return get_data_error_result(retmsg="Data update error ")
return get_data_error_result(retmsg="Update dataset error.(Database error)")
return get_json_result(data=True)


@manager.route('/delete', methods=['DELETE'])
@token_required
def delete(tenant_id):
    """Delete a dataset created by the calling tenant, along with its documents.

    Query args:
        id: ID of the dataset to delete (required).

    Returns:
        A JSON result with ``data=True`` on success; an ``OPERATING_ERROR``
        result when no dataset with that id was created by ``tenant_id``; a
        data-error result when ``id`` is missing or when a removal step
        reports failure.
    """
    dataset_id = request.args.get("id")
    if dataset_id is None:
        # Answer in the API's JSON envelope instead of letting the missing
        # key abort the request with a bare 400.
        return get_data_error_result(retmsg="`id` is required.")

    kbs = KnowledgebaseService.query(created_by=tenant_id, id=dataset_id)
    if not kbs:
        return get_json_result(
            data=False, retmsg='You do not own the dataset',
            retcode=RetCode.OPERATING_ERROR)

    for doc in DocumentService.query(kb_id=dataset_id):
        if not DocumentService.remove_document(doc, kbs[0].tenant_id):
            return get_data_error_result(
                retmsg="Remove document error.(Database error)")
        f2d = File2DocumentService.get_by_document_id(doc.id)
        # Only touch the file table when a file mapping was actually found;
        # indexing an empty lookup result would raise IndexError mid-delete.
        if f2d:
            FileService.filter_delete(
                [File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
        File2DocumentService.delete_by_document_id(doc.id)

    if not KnowledgebaseService.delete_by_id(dataset_id):
        return get_data_error_result(
            retmsg="Delete dataset error.(Database error)")
    return get_json_result(data=True)


@manager.route('/list', methods=['GET'])
@token_required
def list_datasets(tenant_id):
    """List datasets visible to the calling tenant, one page at a time.

    Query args:
        page: page number, defaults to 1.
        page_size: items per page, defaults to 1024.
        orderby: field name to sort by, defaults to ``create_time``.
        desc: sort descending unless given as ``false``/``0``/``no``/``off``
            (case-insensitive) or an empty string; defaults to descending.

    Returns:
        A JSON result whose ``data`` is whatever
        ``KnowledgebaseService.get_by_tenant_ids`` returns.
    """
    args = request.args
    page_number = int(args.get("page", 1))
    items_per_page = int(args.get("page_size", 1024))
    orderby = args.get("orderby", "create_time")

    # Query-string values are strings, so bool("false") would be True.
    # Interpret the text explicitly; an absent parameter keeps the default.
    raw_desc = args.get("desc")
    if raw_desc is None:
        desc = True
    else:
        desc = raw_desc.strip().lower() not in ("", "0", "false", "no", "off")

    # NOTE(review): tenant_id is passed where the callee's name suggests a
    # user id — confirm the two are interchangeable here.
    tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
    kbs = KnowledgebaseService.get_by_tenant_ids(
        [m["tenant_id"] for m in tenants], tenant_id, page_number, items_per_page, orderby, desc)
    return get_json_result(data=kbs)


@manager.route('/detail', methods=['GET'])
@token_required
def detail(tenant_id):
    """Return one dataset of the calling tenant, looked up by id and/or name.

    Query args:
        id: dataset ID. When present, the dataset must have been created by
            ``tenant_id``; if ``name`` is also given it must match the
            dataset's name.
        name: dataset name. Used for the lookup only when ``id`` is absent.

    Returns:
        A JSON result carrying the dataset's ``to_dict()``; an
        ``OPERATING_ERROR`` result when the dataset cannot be resolved for
        this tenant; a data-error result when neither argument is supplied.
    """
    req = request.args

    def not_owned():
        # Built per call: every failure path below shares this response.
        return get_json_result(
            data=False, retmsg='You do not own the dataset',
            retcode=RetCode.OPERATING_ERROR)

    if "id" in req:
        dataset_id = req["id"]
        owned = KnowledgebaseService.query(created_by=tenant_id, id=dataset_id)
        if not owned:
            return not_owned()
        if "name" in req and owned[0].name != req["name"]:
            return not_owned()
        found, kb = KnowledgebaseService.get_by_id(dataset_id)
        # Honour the lookup flag rather than dereferencing the record
        # unconditionally (it may have vanished since the ownership check).
        if not found:
            return not_owned()
        return get_json_result(data=kb.to_dict())

    if "name" in req:
        found, kb = KnowledgebaseService.get_by_name(
            kb_name=req["name"], tenant_id=tenant_id)
        if not found:
            return not_owned()
        return get_json_result(data=kb.to_dict())

    return get_data_error_result(
        retmsg="At least one of `id` or `name` must be provided.")
47 changes: 34 additions & 13 deletions api/utils/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import json
import random
import time
from base64 import b64encode
from functools import wraps
from hmac import HMAC
from io import BytesIO
from urllib.parse import quote, urlencode
from uuid import uuid1

import requests
from flask import (
Response, jsonify, send_file, make_response,
request as flask_request,
)
from werkzeug.http import HTTP_STATUS_CODES

from api.utils import json_dumps
from api.settings import RetCode
from api.db.db_models import APIToken
from api.settings import (
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
)
import requests
import functools
from api.settings import RetCode
from api.utils import CustomJSONEncoder
from uuid import uuid1
from base64 import b64encode
from hmac import HMAC
from urllib.parse import quote, urlencode
from api.utils import json_dumps

requests.models.complexjson.dumps = functools.partial(
json.dumps, cls=CustomJSONEncoder)
Expand Down Expand Up @@ -96,7 +98,6 @@ def get_exponential_backoff_interval(retries, full_jitter=False):

def get_json_result(retcode=RetCode.SUCCESS, retmsg='success',
data=None, job_id=None, meta=None):
import re
result_dict = {
"retcode": retcode,
"retmsg": retmsg,
Expand Down Expand Up @@ -145,7 +146,8 @@ def server_error_response(e):
return get_json_result(
retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
if repr(e).find("index_not_found_exception") >= 0:
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg="No chunk found, please upload file and parse it.")
return get_json_result(retcode=RetCode.EXCEPTION_ERROR,
retmsg="No chunk found, please upload file and parse it.")

return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))

Expand Down Expand Up @@ -190,7 +192,9 @@ def decorated_function(*_args, **_kwargs):
return get_json_result(
retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
return func(*_args, **_kwargs)

return decorated_function

return wrapper


Expand All @@ -217,7 +221,7 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None):


def construct_response(retcode=RetCode.SUCCESS,
retmsg='success', data=None, auth=None):
retmsg='success', data=None, auth=None):
result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
response_dict = {}
for key, value in result_dict.items():
Expand All @@ -235,6 +239,7 @@ def construct_response(retcode=RetCode.SUCCESS,
response.headers["Access-Control-Expose-Headers"] = "Authorization"
return response


def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
import re
result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
Expand Down Expand Up @@ -263,7 +268,23 @@ def construct_error_response(e):
pass
if len(e.args) > 1:
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
if repr(e).find("index_not_found_exception") >=0:
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
if repr(e).find("index_not_found_exception") >= 0:
return construct_json_result(code=RetCode.EXCEPTION_ERROR,
message="No chunk found, please upload file and parse it.")

return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))


def token_required(func):
    """Decorate a view so it only runs for requests carrying a valid API token.

    The token is read as the second whitespace-separated field of the
    ``Authorization`` header (e.g. ``Bearer <token>``) and looked up through
    ``APIToken.query``. On success the wrapped view is called with the
    matching ``tenant_id`` injected as a keyword argument.

    A missing header, a header without a second field, or an unknown token
    all produce the same ``AUTHENTICATION_ERROR`` JSON result instead of an
    unhandled exception.
    """
    @wraps(func)
    def decorated_function(*args, **kwargs):
        def invalid_token():
            return get_json_result(
                data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR
            )

        header = flask_request.headers.get('Authorization')
        # Guard both the absent header (None) and a header with no token part.
        fields = header.split() if header else []
        if len(fields) < 2:
            return invalid_token()

        objs = APIToken.query(token=fields[1])
        if not objs:
            return invalid_token()

        kwargs['tenant_id'] = objs[0].tenant_id
        return func(*args, **kwargs)

    return decorated_function
12 changes: 8 additions & 4 deletions sdk/python/ragflow/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ def to_json(self):
pr[name] = value
return pr


def post(self, path, param):
res = self.rag.post(path,param)
res = self.rag.post(path, param)
return res

def get(self, path, params=''):
res = self.rag.get(path,params)
def get(self, path, params):
res = self.rag.get(path, params)
return res

def rm(self, path, params):
    """Forward a delete call for ``path`` with ``params`` to ``self.rag``.

    Returns whatever ``self.rag.delete`` returns, unmodified.
    """
    return self.rag.delete(path, params)

def __str__(self):
    """Return the string form of this object's ``to_json()`` result."""
    as_json = self.to_json()
    return str(as_json)
28 changes: 23 additions & 5 deletions sdk/python/ragflow/modules/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,36 @@ def __init__(self, rag, res_dict):
self.permission = "me"
self.document_count = 0
self.chunk_count = 0
self.parser_method = "naive"
self.parse_method = "naive"
self.parser_config = None
for k in list(res_dict.keys()):
if k == "embd_id":
res_dict["embedding_model"] = res_dict[k]
if k == "parser_id":
res_dict['parse_method'] = res_dict[k]
if k == "doc_num":
res_dict["document_count"] = res_dict[k]
if k == "chunk_num":
res_dict["chunk_count"] = res_dict[k]
if k not in self.__dict__:
res_dict.pop(k)
super().__init__(rag, res_dict)

def save(self):
def save(self) -> bool:
res = self.post('/dataset/save',
{"id": self.id, "name": self.name, "avatar": self.avatar, "tenant_id": self.tenant_id,
"description": self.description, "language": self.language, "embd_id": self.embedding_model,
"permission": self.permission,
"doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parser_method,
"doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parse_method,
"parser_config": self.parser_config.to_json()
})
res = res.json()
if not res.get("retmsg"): return True
raise Exception(res["retmsg"])
if res.get("retmsg") == "success": return True
raise Exception(res["retmsg"])

def delete(self) -> bool:
    """Ask the server to delete this dataset.

    Sends ``self.id`` to ``/dataset/delete`` and decodes the JSON reply.

    Returns:
        True when the reply's ``retmsg`` is ``"success"``.

    Raises:
        Exception: carrying the reply's ``retmsg`` for any other reply.
    """
    reply = self.rm('/dataset/delete', {"id": self.id}).json()
    if reply.get("retmsg") != "success":
        raise Exception(reply["retmsg"])
    return True
Loading

0 comments on commit 3eec945

Please sign in to comment.