complete implementation of dataset SDK
root committed Aug 29, 2024
1 parent fc172b4 commit 80f3d14
Showing 6 changed files with 232 additions and 71 deletions.
156 changes: 112 additions & 44 deletions api/apps/sdk/dataset.py
@@ -15,82 +15,150 @@
#
from flask import request

from api.db import StatusEnum
from api.db.db_models import APIToken
from api.db import StatusEnum, FileSource
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result
from api.utils.api_utils import get_json_result
from api.utils.api_utils import get_json_result,token_required,get_data_error_result



@manager.route('/save', methods=['POST'])
def save():
@token_required
def save(tenant_id):
req = request.json
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR)
tenant_id = objs[0].tenant_id
e, t = TenantService.get_by_id(tenant_id)
if not e:
return get_data_error_result(retmsg="Tenant not found.")
if "id" not in req:
if "tenant_id" in req or "embd_id" in req:
return get_data_error_result(
retmsg="tenant_id or embedding_model must not be provided")
if "name" not in req:
return get_data_error_result(
retmsg="Name cannot be empty!")
req['id'] = get_uuid()
req["name"] = req["name"].strip()
if req["name"] == "":
return get_data_error_result(
retmsg="Name is not empty")
if KnowledgebaseService.query(name=req["name"]):
retmsg="Name cannot be an empty string!")
if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_data_error_result(
retmsg="Duplicated knowledgebase name")
retmsg="Duplicated knowledgebase name in creating dataset.")
req["tenant_id"] = tenant_id
req['created_by'] = tenant_id
req['embd_id'] = t.embd_id
if not KnowledgebaseService.save(**req):
return get_data_error_result(retmsg="Data saving error")
req.pop('created_by')
keys_to_rename = {'embd_id': "embedding_model", 'parser_id': 'parser_method',
'chunk_num': 'chunk_count', 'doc_num': 'document_count'}
for old_key,new_key in keys_to_rename.items():
if old_key in req:
req[new_key]=req.pop(old_key)
return get_data_error_result(retmsg="Create dataset error.(Database error)")
return get_json_result(data=req)
else:
if req["tenant_id"] != tenant_id or req["embd_id"] != t.embd_id:
return get_data_error_result(
retmsg="Can't change tenant_id or embedding_model")
if "tenant_id" in req:
if req["tenant_id"] != tenant_id:
return get_data_error_result(
retmsg="Can't change tenant_id.")

e, kb = KnowledgebaseService.get_by_id(req["id"])
if not e:
return get_data_error_result(
retmsg="Can't find this knowledgebase!")
if "embd_id" in req:
if req["embd_id"] != t.embd_id:
return get_data_error_result(
retmsg="Can't change embedding_model.")

if not KnowledgebaseService.query(
created_by=tenant_id, id=req["id"]):
return get_json_result(
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
data=False, retmsg='You do not own the dataset.',
retcode=RetCode.OPERATING_ERROR)

if req["chunk_num"] != kb.chunk_num or req['doc_num'] != kb.doc_num:
return get_data_error_result(
retmsg="Can't change document_count or chunk_count ")
e, kb = KnowledgebaseService.get_by_id(req["id"])

if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
return get_data_error_result(
retmsg="if chunk count is not 0, parser method is not changeable.")
if "chunk_num" in req:
if req["chunk_num"] != kb.chunk_num:
return get_data_error_result(
retmsg="Can't change chunk_count.")

if "doc_num" in req:
if req['doc_num'] != kb.doc_num:
return get_data_error_result(
retmsg="Can't change document_count.")

if req["name"].lower() != kb.name.lower() \
and len(KnowledgebaseService.query(name=req["name"], tenant_id=req['tenant_id'],
status=StatusEnum.VALID.value)) > 0:
return get_data_error_result(
retmsg="Duplicated knowledgebase name.")
if "parser_id" in req:
if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
return get_data_error_result(
retmsg="if chunk count is not 0, parse method is not changeable.")
if "name" in req:
if req["name"].lower() != kb.name.lower() \
and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
status=StatusEnum.VALID.value)) > 0:
return get_data_error_result(
retmsg="Duplicated knowledgebase name in updating dataset.")

del req["id"]
req['created_by'] = tenant_id
if not KnowledgebaseService.update_by_id(kb.id, req):
return get_data_error_result(retmsg="Data update error ")
return get_data_error_result(retmsg="Update dataset error.(Database error)")
return get_json_result(data=True)


@manager.route('/delete', methods=['DELETE'])
@token_required
def delete(tenant_id):
req = request.args
kbs = KnowledgebaseService.query(
created_by=tenant_id, id=req["id"])
if not kbs:
return get_json_result(
data=False, retmsg='You do not own the dataset',
retcode=RetCode.OPERATING_ERROR)

for doc in DocumentService.query(kb_id=req["id"]):
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
return get_data_error_result(
retmsg="Remove document error.(Database error)")
f2d = File2DocumentService.get_by_document_id(doc.id)
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc.id)

if not KnowledgebaseService.delete_by_id(req["id"]):
return get_data_error_result(
retmsg="Delete dataset error.(Database error)")
return get_json_result(data=True)


@manager.route('/list', methods=['GET'])
@token_required
def list_datasets(tenant_id):
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 150))
orderby = request.args.get("orderby", "create_time")
desc = bool(request.args.get("desc", True))
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
kbs = KnowledgebaseService.get_by_tenant_ids(
[m["tenant_id"] for m in tenants], tenant_id, page_number, items_per_page, orderby, desc)
return get_json_result(data=kbs)


@manager.route('/detail', methods=['GET'])
@token_required
def detail(tenant_id):
req = request.args
if "id" in req:
id = req["id"]
if "name" in req:
name = req["name"]
if not KnowledgebaseService.query(id=id, name=name, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_json_result(data=None)
if not KnowledgebaseService.query(
created_by=tenant_id, id=req["id"]):
return get_json_result(
data=False, retmsg='You do not own the dataset',
retcode=RetCode.OPERATING_ERROR)
e, k = KnowledgebaseService.get_by_id(id)
return get_json_result(data=k.to_dict())
else:
if "name" in req:
name = req["name"]
e, k = KnowledgebaseService.get_by_name(kb_name=name, tenant_id=tenant_id)
return get_json_result(data=k.to_dict())
else:
return get_json_result(data=None)
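
The rewritten endpoints are all guarded by token_required and exchange plain JSON, so the dataset lifecycle can be exercised directly over HTTP. A minimal sketch, assuming a server reachable at http://127.0.0.1:9380; the host, port, and API key below are placeholders, not values from this commit:

import requests

BASE_URL = "http://127.0.0.1:9380/api/v1"         # placeholder host/port
HEADERS = {"Authorization": "Bearer <API_KEY>"}   # placeholder token checked by token_required

# POST /dataset/save without an "id" creates a new dataset.
res = requests.post(BASE_URL + "/dataset/save",
                    json={"name": "kb_demo"}, headers=HEADERS)
dataset = res.json()["data"]

# The same route with "id" present updates the existing record.
requests.post(BASE_URL + "/dataset/save",
              json={"id": dataset["id"], "name": "kb_demo_renamed"}, headers=HEADERS)

# Enumerate, inspect by name, and finally remove the dataset.
requests.get(BASE_URL + "/dataset/list", headers=HEADERS)
requests.get(BASE_URL + "/dataset/detail", params={"name": "kb_demo_renamed"}, headers=HEADERS)
requests.delete(BASE_URL + "/dataset/delete", params={"id": dataset["id"]}, headers=HEADERS)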
15 changes: 15 additions & 0 deletions api/utils/api_utils.py
@@ -24,6 +24,7 @@
)
from werkzeug.http import HTTP_STATUS_CODES

from api.db.db_models import APIToken
from api.utils import json_dumps
from api.settings import RetCode
from api.settings import (
@@ -267,3 +268,17 @@ def construct_error_response(e):
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")

return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))

def token_required(func):
@wraps(func)
def decorated_function(*args, **kwargs):
token = flask_request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR
)
kwargs['tenant_id'] = objs[0].tenant_id
return func(*args, **kwargs)

return decorated_function
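
token_required reads the bearer token from the Authorization header, resolves the owning tenant via APIToken, and injects it as the tenant_id keyword argument, so a protected route only has to declare that parameter. A minimal sketch of the pattern on a hypothetical route (the route and app below are illustrative, not part of this commit):

from flask import Flask, jsonify
from api.utils.api_utils import token_required   # the decorator added above

app = Flask(__name__)

@app.route("/example", methods=["GET"])
@token_required                      # validates the token before the body runs
def example(tenant_id):              # tenant_id arrives through kwargs
    return jsonify({"tenant_id": tenant_id})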
12 changes: 8 additions & 4 deletions sdk/python/ragflow/modules/base.py
@@ -18,13 +18,17 @@ def to_json(self):
pr[name] = value
return pr


def post(self, path, param):
res = self.rag.post(path,param)
res = self.rag.post(path, param)
return res

def get(self, path, params=''):
res = self.rag.get(path,params)
def get(self, path, params):
res = self.rag.get(path, params)
return res

def rm(self, path, params):
res = self.rag.delete(path, params)
return res

def __str__(self):
return str(self.to_json())
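
Base now funnels every request through the shared rag client: post and get already existed, and rm is added so modules can issue DELETEs without importing requests themselves. A rough sketch of how a subclass leans on these helpers (the Tag module and its route are hypothetical, shown only to illustrate the pattern):

class Tag(Base):
    def __init__(self, rag, res_dict):
        self.id = ""
        self.name = ""
        super().__init__(rag, res_dict)

    def delete(self) -> bool:
        # rm() delegates to the client's HTTP DELETE, mirroring DataSet.delete below.
        res = self.rm("/tag/delete", {"id": self.id}).json()
        if res.get("retmsg") == "success":
            return True
        raise Exception(res["retmsg"])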
26 changes: 22 additions & 4 deletions sdk/python/ragflow/modules/dataset.py
@@ -21,18 +21,36 @@ def __init__(self, rag, res_dict):
self.permission = "me"
self.document_count = 0
self.chunk_count = 0
self.parser_method = "naive"
self.parse_method = "naive"
self.parser_config = None
for k in list(res_dict.keys()):
if k == "embd_id":
res_dict["embedding_model"]=res_dict[k]
if k == "parser_id":
res_dict['parse_method']=res_dict[k]
if k == "doc_num":
res_dict["document_count"]=res_dict[k]
if k == "chunk_num":
res_dict["chunk_count"]=res_dict[k]
if k not in self.__dict__:
res_dict.pop(k)
super().__init__(rag, res_dict)

def save(self):
def save(self) -> bool :
res = self.post('/dataset/save',
{"id": self.id, "name": self.name, "avatar": self.avatar, "tenant_id": self.tenant_id,
"description": self.description, "language": self.language, "embd_id": self.embedding_model,
"permission": self.permission,
"doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parser_method,
"doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parse_method,
"parser_config": self.parser_config.to_json()
})
res = res.json()
if not res.get("retmsg"): return True
if res.get("retmsg") == "success": return True
raise Exception(res["retmsg"])

def delete(self) -> bool:
res = self.rm('/dataset/delete',
{"id": self.id})
res = res.json()
if res.get("retmsg") == "success": return True
raise Exception(res["retmsg"])
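
With the key remapping in __init__ (embd_id to embedding_model, parser_id to parse_method, doc_num to document_count, chunk_num to chunk_count), a DataSet can be edited locally and pushed back with save(), or removed with delete(). A brief usage sketch, assuming rag is an authenticated RAGFlow client (see the next file):

ds = rag.create_dataset(name="kb_demo")

ds.name = "kb_demo_renamed"    # mutate locally ...
assert ds.save() is True       # ... then persist through /dataset/save

assert ds.delete() is True     # issues DELETE /dataset/delete with this dataset's id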
53 changes: 39 additions & 14 deletions sdk/python/ragflow/ragflow.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import requests

from .modules.dataset import DataSet
@@ -25,30 +27,53 @@ def __init__(self, user_key, base_url, version='v1'):
"""
self.user_key = user_key
self.api_url = f"{base_url}/api/{version}"
self.authorization_header = {"Authorization": "{} {}".format("Bearer",self.user_key)}
self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}

def post(self, path, param):
res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header)
return res

def get(self, path, params=''):
res = requests.get(self.api_url + path, params=params, headers=self.authorization_header)
def get(self, path, params=None):
res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header)
return res

def delete(self, path, params):
res = requests.delete(url=self.api_url + path, params=params, headers=self.authorization_header)
return res

def create_dataset(self, name:str,avatar:str="",description:str="",language:str="English",permission:str="me",
document_count:int=0,chunk_count:int=0,parser_method:str="naive",
parser_config:DataSet.ParserConfig=None):
def create_dataset(self, name: str, avatar: str = "", description: str = "", language: str = "English",
permission: str = "me",
document_count: int = 0, chunk_count: int = 0, parse_method: str = "naive",
parser_config: DataSet.ParserConfig = None) -> DataSet:
if parser_config is None:
parser_config = DataSet.ParserConfig(self, {"chunk_token_count":128,"layout_recognize": True, "delimiter":"\n!?。;!?","task_page_size":12})
parser_config=parser_config.to_json()
res=self.post("/dataset/save",{"name":name,"avatar":avatar,"description":description,"language":language,"permission":permission,
"doc_num": document_count,"chunk_num":chunk_count,"parser_id":parser_method,
"parser_config":parser_config
}
)
parser_config = DataSet.ParserConfig(self, {"chunk_token_count": 128, "layout_recognize": True,
"delimiter": "\n!?。;!?", "task_page_size": 12})
parser_config = parser_config.to_json()
res = self.post("/dataset/save",
{"name": name, "avatar": avatar, "description": description, "language": language,
"permission": permission,
"doc_num": document_count, "chunk_num": chunk_count, "parser_id": parse_method,
"parser_config": parser_config
}
)
res = res.json()
if not res.get("retmsg"):
if res.get("retmsg") == "success":
return DataSet(self, res["data"])
raise Exception(res["retmsg"])

def list_datasets(self) -> List[DataSet]:
res = self.get("/dataset/list")
res = res.json()
result_list = []
if res['data']:
for data in res['data']:
result_list.append(DataSet(self, data))
return result_list

def get_dataset(self, id: str = None, name: str = None) -> DataSet:
res = self.get("/dataset/detail", {"id": id, "name": name})
res = res.json()
print(res)
if res['data']:
return DataSet(self, res['data'])
return None
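
Taken together, the client now covers the full dataset lifecycle: create, list, fetch, and (through the returned DataSet objects) update and delete. A minimal end-to-end sketch, assuming the package exports RAGFlow and with the server address and API key below as placeholders:

from ragflow import RAGFlow

rag = RAGFlow(user_key="<API_KEY>", base_url="http://127.0.0.1:9380")

ds = rag.create_dataset(name="kb_demo")        # POST /dataset/save

for item in rag.list_datasets():               # GET /dataset/list
    print(item.id, item.name)

same_ds = rag.get_dataset(name="kb_demo")      # GET /dataset/detail
same_ds.delete()                               # DELETE /dataset/delete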
