Skip to content

Commit

Permalink
fix:solve the dup retrieval issue (#575)
Browse files Browse the repository at this point in the history
* feat: merge partial env variable and skip validation in the dev

* fix: fix dup retrieval

* feat: update petercat-utils
  • Loading branch information
MadratJerry authored Dec 16, 2024
1 parent de3dc77 commit 61bd4ef
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 121 deletions.
27 changes: 19 additions & 8 deletions petercat_utils/rag_helper/retrieval.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import json
from typing import Any, Dict


from langchain_community.vectorstores import SupabaseVectorStore
from langchain_openai import OpenAIEmbeddings

from .github_file_loader import GithubFileLoader
from ..data_class import GitDocConfig, RAGGitDocConfig, S3Config
from ..db.client.supabase import get_client


TABLE_NAME = "rag_docs"
QUERY_NAME = "match_embedding_docs"
CHUNK_SIZE = 2000
Expand Down Expand Up @@ -118,15 +116,16 @@ def add_knowledge_by_doc(config: RAGGitDocConfig):
supabase = get_client()
is_doc_added_query = (
supabase.table(TABLE_NAME)
.select("id, repo_name, commit_id, file_path")
.select("id")
.eq("repo_name", config.repo_name)
.eq("commit_id", loader.commit_id)
.eq("file_path", config.file_path)
.limit(1)
.execute()
)
if not is_doc_added_query.data:
is_doc_equal_query = (
supabase.table(TABLE_NAME).select("*").eq("file_sha", loader.file_sha)
supabase.table(TABLE_NAME).select("id").eq("file_sha", loader.file_sha).limit(1)
).execute()
if not is_doc_equal_query.data:
# If there is no file with the same file_sha, perform embedding.
Expand All @@ -139,14 +138,26 @@ def add_knowledge_by_doc(config: RAGGitDocConfig):
)
return store
else:
# Prioritize obtaining the minimal set of records to avoid overlapping with the original records.
minimum_repeat_result = supabase.rpc('count_rag_docs_by_sha', {'file_sha_input': loader.file_sha}).execute()
target_filter = minimum_repeat_result.data[0]
# Copy the minimal set
insert_docs = (
supabase.table(TABLE_NAME)
.select("*")
.eq("repo_name", target_filter['repo_name'])
.eq("file_path", target_filter['file_path'])
.eq("file_sha", target_filter['file_sha'])
.execute()
)
new_commit_list = [
{
**{k: v for k, v in item.items() if k != "id"},
"repo_name": config.repo_name,
"commit_id": loader.commit_id,
"file_path": config.file_path,
}
for item in is_doc_equal_query.data
for item in insert_docs.data
]
insert_result = supabase.table(TABLE_NAME).insert(new_commit_list).execute()
return insert_result
Expand All @@ -169,9 +180,9 @@ def reload_knowledge(config: RAGGitDocConfig):


def search_knowledge(
query: str,
repo_name: str,
meta_filter: Dict[str, Any] = {},
query: str,
repo_name: str,
meta_filter: Dict[str, Any] = {},
):
retriever = init_retriever(
{"filter": {"metadata": meta_filter, "repo_name": repo_name}}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "petercat_utils"
version = "0.1.39"
version = "0.1.40"
description = ""
authors = ["raoha.rh <[email protected]>"]
readme = "README.md"
Expand Down
137 changes: 68 additions & 69 deletions server/auth/middleware.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,87 @@
import traceback
from typing import Awaitable, Callable

from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse
from petercat_utils import get_env_variable
from fastapi.security import OAuth2PasswordBearer
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from fastapi.security import OAuth2PasswordBearer

from core.dao.botDAO import BotDAO

WEB_URL = get_env_variable("WEB_URL")
ENVRIMENT = get_env_variable("PETERCAT_ENV", "development")
from env import ENVIRONMENT, WEB_URL

ALLOW_LIST = [
"/",
"/favicon.ico",
"/api/health_checker",
"/api/bot/list",
"/api/bot/detail",
"/api/github/app/webhook",
"/app/installation/callback",
"/",
"/favicon.ico",
"/api/health_checker",
"/api/bot/list",
"/api/bot/detail",
"/api/github/app/webhook",
"/app/installation/callback",
]

ANONYMOUS_USER_ALLOW_LIST = [
"/api/auth/userinfo",
"/api/chat/qa",
"/api/chat/stream_qa",
"/api/auth/userinfo",
"/api/chat/qa",
"/api/chat/stream_qa",
]

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")


class AuthMiddleWare(BaseHTTPMiddleware):

async def oauth(self, request: Request):
try:
referer = request.headers.get('referer')
origin = request.headers.get('origin')
if referer and referer.startswith(WEB_URL):
return True
token = await oauth2_scheme(request=request)
if token:
bot_dao = BotDAO()
bot = bot_dao.get_bot(bot_id=token)
return bot and (
"*" in bot.domain_whitelist
or
origin in bot.domain_whitelist
)
except HTTPException:
return False
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
try:
# if ENVRIMENT == "development":
# return await call_next(request)
# Auth 相关的直接放过
if request.url.path.startswith("/api/auth"):
return await call_next(request)
if request.url.path in ALLOW_LIST:
return await call_next(request)
if await self.oauth(request=request):
return await call_next(request)

# 获取 session 中的用户信息
user = request.session.get("user")
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
if user['sub'].startswith("client|"):
if request.url.path in ANONYMOUS_USER_ALLOW_LIST:
return await call_next(request)
else:
# 如果没有用户信息,返回 401 Unauthorized 错误
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Anonymous User Not Allow")
return await call_next(request)
except HTTPException as e:
print(traceback.format_exception(e))
# 处理 HTTP 异常
return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
except Exception as e:
# 处理其他异常
return JSONResponse(status_code=500, content={"detail": f"Internal Server Error: {e}"})
async def oauth(self, request: Request):
    """Decide whether a request is authorized without a user session.

    The request is accepted when either:
      * its ``referer`` header starts with ``WEB_URL``, or
      * its bearer token (extracted by ``oauth2_scheme``) resolves through
        ``BotDAO.get_bot`` to a bot whose ``domain_whitelist`` contains
        ``"*"`` or the request's ``origin`` header.

    Returns:
        bool: ``True`` when accepted, ``False`` otherwise, including when
        an ``HTTPException`` is raised while performing the checks.
    """
    try:
        referer = request.headers.get('referer')
        origin = request.headers.get('origin')
        # NOTE(review): this is a plain string-prefix match on a
        # client-supplied header, so any value that merely begins with
        # WEB_URL passes. Confirm this is an acceptable trust boundary.
        if referer and referer.startswith(WEB_URL):
            return True

        token = await oauth2_scheme(request=request)
        if not token:
            return False

        # The bearer token is used directly as the bot id.
        bot = BotDAO().get_bot(bot_id=token)
        if not bot:
            return False

        # Treat a missing whitelist as empty so the membership tests below
        # cannot raise TypeError (assumes domain_whitelist is a collection
        # of origins or None - TODO confirm against the bot model).
        whitelist = bot.domain_whitelist or ()
        return "*" in whitelist or origin in whitelist
    except HTTPException:
        return False

async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
    """Apply the authentication rules to a request before forwarding it.

    The request is forwarded to ``call_next`` at the first rule that matches:
      1. ``ENVIRONMENT`` equals ``"development"`` (no checks at all).
      2. The path starts with ``/api/auth``.
      3. The path is listed in ``ALLOW_LIST``.
      4. ``self.oauth`` accepts the request.
      5. The session holds a ``user``. A user whose ``sub`` starts with
         ``client|`` is only let through on ``ANONYMOUS_USER_ALLOW_LIST``
         paths.

    Returns:
        The downstream response, or a JSON error response: 401 when the
        session has no user, 403 for an anonymous user on a restricted
        path, and 500 for any other exception raised inside the ``try``
        block (including ones raised by ``call_next``).
    """
    try:
        # NOTE(review): server/env.py passes "development" as the second
        # argument when reading PETERCAT_ENV; if that is a fallback default,
        # an unset variable disables every check below. Confirm deployments
        # always set PETERCAT_ENV explicitly.
        if ENVIRONMENT == "development":
            return await call_next(request)

        # Auth endpoints are always let through.
        if request.url.path.startswith("/api/auth"):
            return await call_next(request)

        if request.url.path in ALLOW_LIST:
            return await call_next(request)

        if await self.oauth(request=request):
            return await call_next(request)

        # Read the user info stored in the session.
        user = request.session.get("user")
        if not user:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")

        if user['sub'].startswith("client|"):
            if request.url.path in ANONYMOUS_USER_ALLOW_LIST:
                return await call_next(request)
            # Anonymous users may only reach the allow-listed paths above;
            # everything else is rejected with 403 Forbidden.
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Anonymous User Not Allow")

        return await call_next(request)
    except HTTPException as e:
        # format_exception returns a list of lines; join them so the log
        # shows a readable traceback instead of the list's repr.
        print("".join(traceback.format_exception(e)))
        # Convert the HTTP error into a JSON response.
        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
    except Exception as e:
        # Log the traceback before converting to a 500 so unexpected
        # failures are not swallowed silently.
        print("".join(traceback.format_exception(e)))
        return JSONResponse(status_code=500, content={"detail": f"Internal Server Error: {e}"})
22 changes: 11 additions & 11 deletions server/auth/router.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from github import Github
from core.dao.profilesDAO import ProfilesDAO
import secrets
from typing import Annotated, Optional

from authlib.integrations.starlette_client import OAuth
from fastapi import APIRouter, Request, HTTPException, status, Depends
from fastapi.responses import RedirectResponse, JSONResponse
import secrets
from petercat_utils import get_client, get_env_variable
from github import Github
from starlette.config import Config
from authlib.integrations.starlette_client import OAuth
from typing import Annotated, Optional

from auth.get_user_info import generateAnonymousUser, getUserInfoByToken, get_user_id
from auth.get_user_info import (
generateAnonymousUser,
getUserAccessToken,
getUserInfoByToken,
get_user_id,
)
from core.dao.profilesDAO import ProfilesDAO
from petercat_utils import get_client, get_env_variable

AUTH0_DOMAIN = get_env_variable("AUTH0_DOMAIN")

Expand All @@ -26,6 +25,7 @@
LOGIN_URL = f"{API_URL}/api/auth/login"

WEB_URL = get_env_variable("WEB_URL")

WEB_LOGIN_SUCCESS_URL = f"{WEB_URL}/user/login"
MARKET_URL = f"{WEB_URL}/market"

Expand Down Expand Up @@ -133,8 +133,8 @@ async def get_agreement_status(user_id: Optional[str] = Depends(get_user_id)):

@router.post("/accept/agreement", status_code=200)
async def bot_generator(
request: Request,
user_id: Annotated[str | None, Depends(get_user_id)] = None,
request: Request,
user_id: Annotated[str | None, Depends(get_user_id)] = None,
):
if not user_id:
raise HTTPException(status_code=401, detail="User not found")
Expand Down
6 changes: 6 additions & 0 deletions server/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Environment variables read by the server, gathered in one module so other
# modules can import the values from here.
from petercat_utils import get_env_variable

# Value of the WEB_URL environment variable.
WEB_URL = get_env_variable("WEB_URL")
# Value of PETERCAT_ENV; "development" is passed as the second argument
# (presumably the fallback when the variable is unset - verify against
# get_env_variable).
ENVIRONMENT = get_env_variable("PETERCAT_ENV", "development")
# Value of the API_URL environment variable.
API_URL = get_env_variable("API_URL")
21 changes: 10 additions & 11 deletions server/github_app/router.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
from typing import Annotated

from fastapi import (
APIRouter,
BackgroundTasks,
Expand All @@ -8,27 +10,24 @@
Request,
status,
)
import logging
from fastapi.responses import RedirectResponse

from github import Auth, Github

from auth.get_user_info import get_user
from core.dao.repositoryConfigDAO import RepositoryConfigDAO
from core.models.bot import RepoBindBotRequest
from core.models.user import User

from env import WEB_URL
from github_app.handlers import get_handler
from github_app.purchased import PurchaseServer
from github_app.utils import (
get_private_key,
)

from petercat_utils import get_env_variable

REGIN_NAME = get_env_variable("AWS_REGION")
AWS_GITHUB_SECRET_NAME = get_env_variable("AWS_GITHUB_SECRET_NAME")
APP_ID = get_env_variable("X_GITHUB_APP_ID")
WEB_URL = get_env_variable("WEB_URL")

logger = logging.getLogger()
logger.setLevel("INFO")
Expand All @@ -51,9 +50,9 @@ def github_app_callback(code: str, installation_id: str, setup_action: str):

@router.post("/app/webhook")
async def github_app_webhook(
request: Request,
background_tasks: BackgroundTasks,
x_github_event: str = Header(...),
request: Request,
background_tasks: BackgroundTasks,
x_github_event: str = Header(...),
):
payload = await request.json()
if x_github_event == "marketplace_purchase":
Expand Down Expand Up @@ -86,7 +85,7 @@ async def github_app_webhook(

@router.get("/user/repos_installed_app")
def get_user_repos_installed_app(
user: Annotated[User | None, Depends(get_user)] = None
user: Annotated[User | None, Depends(get_user)] = None
):
"""
Get github user installed app repositories which saved in platform database.
Expand Down Expand Up @@ -116,8 +115,8 @@ def get_user_repos_installed_app(

@router.post("/repo/bind_bot", status_code=200)
def bind_bot_to_repo(
request: RepoBindBotRequest,
user: Annotated[User | None, Depends(get_user)] = None,
request: RepoBindBotRequest,
user: Annotated[User | None, Depends(get_user)] = None,
):
if user is None:
raise HTTPException(
Expand Down
Loading

0 comments on commit 61bd4ef

Please sign in to comment.