From 0d049f1be34bfbb438be79f6dd1704cc6ed7ec6a Mon Sep 17 00:00:00 2001 From: MadratJerry Date: Sun, 1 Dec 2024 16:15:03 +0800 Subject: [PATCH 1/3] feat: merge partial env variable and skip validation in the dev --- server/auth/middleware.py | 137 ++++++++++++++++++------------------ server/auth/router.py | 30 ++++---- server/env.py | 6 ++ server/github_app/router.py | 25 ++++--- server/main.py | 25 +++---- server/tests/test_main.py | 8 +-- 6 files changed, 117 insertions(+), 114 deletions(-) create mode 100644 server/env.py diff --git a/server/auth/middleware.py b/server/auth/middleware.py index 44a8912e..73244095 100644 --- a/server/auth/middleware.py +++ b/server/auth/middleware.py @@ -1,88 +1,87 @@ import traceback from typing import Awaitable, Callable + from fastapi import HTTPException, Request, status from fastapi.responses import JSONResponse -from petercat_utils import get_env_variable +from fastapi.security import OAuth2PasswordBearer from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import Response -from fastapi.security import OAuth2PasswordBearer from core.dao.botDAO import BotDAO - -WEB_URL = get_env_variable("WEB_URL") -ENVRIMENT = get_env_variable("PETERCAT_ENV", "development") +from env import ENVIRONMENT, WEB_URL ALLOW_LIST = [ - "/", - "/favicon.ico", - "/api/health_checker", - "/api/bot/list", - "/api/bot/detail", - "/api/github/app/webhook", - "/app/installation/callback", + "/", + "/favicon.ico", + "/api/health_checker", + "/api/bot/list", + "/api/bot/detail", + "/api/github/app/webhook", + "/app/installation/callback", ] ANONYMOUS_USER_ALLOW_LIST = [ - "/api/auth/userinfo", - "/api/chat/qa", - "/api/chat/stream_qa", + "/api/auth/userinfo", + "/api/chat/qa", + "/api/chat/stream_qa", ] oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token") + class AuthMiddleWare(BaseHTTPMiddleware): - async def oauth(self, request: Request): - try: - referer = request.headers.get('referer') - origin = request.headers.get('origin') - if referer and referer.startswith(WEB_URL): - return True - - token = await oauth2_scheme(request=request) - if token: - bot_dao = BotDAO() - bot = bot_dao.get_bot(bot_id=token) - return bot and ( - "*" in bot.domain_whitelist - or - origin in bot.domain_whitelist - ) - except HTTPException: - return False - - async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: - try: - # if ENVRIMENT == "development": - # return await call_next(request) - - # Auth 相关的直接放过 - if request.url.path.startswith("/api/auth"): - return await call_next(request) - - if request.url.path in ALLOW_LIST: - return await call_next(request) - - if await self.oauth(request=request): - return await call_next(request) - - # 获取 session 中的用户信息 - user = request.session.get("user") - if not user: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized") - - if user['sub'].startswith("client|"): - if request.url.path in ANONYMOUS_USER_ALLOW_LIST: - return await call_next(request) - else: - # 如果没有用户信息,返回 401 Unauthorized 错误 - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Anonymous User Not Allow") - - return await call_next(request) - except HTTPException as e: - print(traceback.format_exception(e)) - # 处理 HTTP 异常 - return JSONResponse(status_code=e.status_code, content={"detail": e.detail}) - except Exception as e: - # 处理其他异常 - return JSONResponse(status_code=500, content={"detail": f"Internal Server Error: {e}"}) + async def oauth(self, request: Request): + try: + referer = request.headers.get('referer') + origin = request.headers.get('origin') + if referer and referer.startswith(WEB_URL): + return True + + token = await oauth2_scheme(request=request) + if token: + bot_dao = BotDAO() + bot = bot_dao.get_bot(bot_id=token) + return bot and ( + "*" in bot.domain_whitelist + or + origin in bot.domain_whitelist + ) + except HTTPException: + return False + + async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: + try: + if ENVIRONMENT == "development": + return await call_next(request) + + # Auth 相关的直接放过 + if request.url.path.startswith("/api/auth"): + return await call_next(request) + + if request.url.path in ALLOW_LIST: + return await call_next(request) + + if await self.oauth(request=request): + return await call_next(request) + + # 获取 session 中的用户信息 + user = request.session.get("user") + if not user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized") + + if user['sub'].startswith("client|"): + if request.url.path in ANONYMOUS_USER_ALLOW_LIST: + return await call_next(request) + else: + # 如果没有用户信息,返回 401 Unauthorized 错误 + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Anonymous User Not Allow") + + return await call_next(request) + except HTTPException as e: + print(traceback.format_exception(e)) + # 处理 HTTP 异常 + return JSONResponse(status_code=e.status_code, content={"detail": e.detail}) + except Exception as e: + # 处理其他异常 + return JSONResponse(status_code=500, content={"detail": f"Internal Server Error: {e}"}) diff --git a/server/auth/router.py b/server/auth/router.py index bb6d4f8d..e1b53727 100644 --- a/server/auth/router.py +++ b/server/auth/router.py @@ -1,12 +1,14 @@ +import secrets +from typing import Annotated + +from authlib.integrations.starlette_client import OAuth from fastapi import APIRouter, Request, HTTPException, status, Depends from fastapi.responses import RedirectResponse, JSONResponse -import secrets -from petercat_utils import get_client, get_env_variable from starlette.config import Config -from authlib.integrations.starlette_client import OAuth -from typing import Annotated from auth.get_user_info import generateAnonymousUser, getUserInfoByToken, get_user_id +from env import API_URL, WEB_URL +from petercat_utils import get_client, get_env_variable AUTH0_DOMAIN = get_env_variable("AUTH0_DOMAIN") @@ -14,11 +16,9 @@ CLIENT_ID = get_env_variable("AUTH0_CLIENT_ID") CLIENT_SECRET = get_env_variable("AUTH0_CLIENT_SECRET") -API_URL = get_env_variable("API_URL") CALLBACK_URL = f"{API_URL}/api/auth/callback" LOGIN_URL = f"{API_URL}/api/auth/login" -WEB_URL = get_env_variable("WEB_URL") WEB_LOGIN_SUCCESS_URL = f"{WEB_URL}/user/login" MARKET_URL = f"{WEB_URL}/market" @@ -42,6 +42,7 @@ responses={404: {"description": "Not found"}}, ) + async def getAnonymousUser(request: Request): clientId = request.query_params.get("clientId") if not clientId: @@ -53,6 +54,7 @@ async def getAnonymousUser(request: Request): request.session['user'] = data return data + @router.get("/login") async def login(request: Request): if CLIENT_ID is None: @@ -62,13 +64,15 @@ async def login(request: Request): redirect_response = await oauth.auth0.authorize_redirect(request, redirect_uri=CALLBACK_URL) return redirect_response + @router.get('/logout') async def logout(request: Request): request.session.pop('user', None) redirect = request.query_params.get('redirect') if redirect: return RedirectResponse(url=f'{redirect}', status_code=302) - return { "success": True } + return {"success": True} + @router.get("/callback") async def callback(request: Request): @@ -90,18 +94,20 @@ async def callback(request: Request): supabase.table("profiles").upsert(data).execute() return RedirectResponse(url=f'{WEB_LOGIN_SUCCESS_URL}', status_code=302) + @router.get("/userinfo") async def userinfo(request: Request): user = request.session.get('user') if not user: data = await getAnonymousUser(request) - return { "data": data, "status": 200} - return { "data": user, "status": 200} + return {"data": data, "status": 200} + return {"data": user, "status": 200} + @router.post("/accept/agreement", status_code=200) async def bot_generator( - request: Request, - user_id: Annotated[str | None, Depends(get_user_id)] = None, + request: Request, + user_id: Annotated[str | None, Depends(get_user_id)] = None, ): if not user_id: return JSONResponse( @@ -114,7 +120,7 @@ async def bot_generator( try: supabase = get_client() response = supabase.table("profiles").update({"agreement_accepted": True}).match({"id": user_id}).execute() - + if not response.data: return JSONResponse( content={ diff --git a/server/env.py b/server/env.py new file mode 100644 index 00000000..72f79905 --- /dev/null +++ b/server/env.py @@ -0,0 +1,6 @@ +# list all env variables +from petercat_utils import get_env_variable + +WEB_URL = get_env_variable("WEB_URL") +ENVIRONMENT = get_env_variable("PETERCAT_ENV", "development") +API_URL = get_env_variable("API_URL") diff --git a/server/github_app/router.py b/server/github_app/router.py index d91a2860..fbcdd664 100644 --- a/server/github_app/router.py +++ b/server/github_app/router.py @@ -1,4 +1,7 @@ +import logging +import time from typing import Annotated + from fastapi import ( APIRouter, BackgroundTasks, @@ -8,19 +11,17 @@ Request, status, ) -import logging from fastapi.responses import RedirectResponse - -import time from github import Auth, Github + from auth.get_user_info import get_user from core.dao.authorizationDAO import AuthorizationDAO from core.dao.repositoryConfigDAO import RepositoryConfigDAO +from core.models.authorization import Authorization from core.models.bot import RepoBindBotRequest from core.models.repository import RepositoryConfig -from core.models.authorization import Authorization from core.models.user import User - +from env import WEB_URL from github_app.handlers import get_handler from github_app.purchased import PurchaseServer from github_app.utils import ( @@ -30,13 +31,11 @@ get_private_key, get_user_orgs, ) - from petercat_utils import get_env_variable REGIN_NAME = get_env_variable("AWS_REGION") AWS_GITHUB_SECRET_NAME = get_env_variable("AWS_GITHUB_SECRET_NAME") APP_ID = get_env_variable("X_GITHUB_APP_ID") -WEB_URL = get_env_variable("WEB_URL") logger = logging.getLogger() logger.setLevel("INFO") @@ -97,9 +96,9 @@ def github_app_callback(code: str, installation_id: str, setup_action: str): @router.post("/app/webhook") async def github_app_webhook( - request: Request, - background_tasks: BackgroundTasks, - x_github_event: str = Header(...), + request: Request, + background_tasks: BackgroundTasks, + x_github_event: str = Header(...), ): payload = await request.json() if x_github_event == "marketplace_purchase": @@ -132,7 +131,7 @@ async def github_app_webhook( @router.get("/user/repos_installed_app") def get_user_repos_installed_app_( - user: Annotated[User | None, Depends(get_user)] = None + user: Annotated[User | None, Depends(get_user)] = None ): """ Get github user installed app repositories which saved in platform database. @@ -162,8 +161,8 @@ def get_user_repos_installed_app_( @router.post("/repo/bind_bot", status_code=200) def bind_bot_to_repo( - request: RepoBindBotRequest, - user: Annotated[User | None, Depends(get_user)] = None, + request: RepoBindBotRequest, + user: Annotated[User | None, Depends(get_user)] = None, ): if user is None: raise HTTPException( diff --git a/server/main.py b/server/main.py index a5d747d6..0d4c82fc 100644 --- a/server/main.py +++ b/server/main.py @@ -1,34 +1,29 @@ import os -from fastapi.responses import RedirectResponse import uvicorn - from fastapi import FastAPI -from starlette.middleware.sessions import SessionMiddleware from fastapi.middleware.cors import CORSMiddleware -from auth.cors_middleware import AuthCORSMiddleWare -from i18n.translations import I18nConfig, I18nMiddleware - -from auth.middleware import AuthMiddleWare -from petercat_utils import get_env_variable - +from fastapi.responses import RedirectResponse +from starlette.middleware.sessions import SessionMiddleware # Import fastapi routers from auth import router as auth_router +from auth.cors_middleware import AuthCORSMiddleWare +from auth.middleware import AuthMiddleWare +from aws import router as aws_router from bot import router as bot_router from chat import router as chat_router +from env import ENVIRONMENT, API_URL, WEB_URL +from github_app import router as github_app_router +from i18n.translations import I18nConfig, I18nMiddleware +from petercat_utils import get_env_variable from rag import router as rag_router from task import router as task_router -from github_app import router as github_app_router -from aws import router as aws_router from user import router as user_router AUTH0_DOMAIN = get_env_variable("AUTH0_DOMAIN") API_AUDIENCE = get_env_variable("API_IDENTIFIER") CLIENT_ID = get_env_variable("AUTH0_CLIENT_ID") -API_URL = get_env_variable("API_URL") -WEB_URL = get_env_variable("WEB_URL") -ENVRIMENT = get_env_variable("PETERCAT_ENV", "development") CALLBACK_URL = f"{API_URL}/api/auth/callback" is_dev = bool(get_env_variable("IS_DEV")) @@ -77,7 +72,7 @@ def home_page(): @app.get("/api/health_checker") def health_checker(): return { - "ENVRIMENT": ENVRIMENT, + "ENVIRONMENT": ENVIRONMENT, "API_URL": API_URL, "WEB_URL": WEB_URL, "CALLBACK_URL": CALLBACK_URL, diff --git a/server/tests/test_main.py b/server/tests/test_main.py index 3eb22630..e8d343e0 100644 --- a/server/tests/test_main.py +++ b/server/tests/test_main.py @@ -1,18 +1,16 @@ from fastapi.testclient import TestClient + +from env import ENVIRONMENT, WEB_URL, API_URL from petercat_utils import get_env_variable from main import app -API_URL = get_env_variable("API_URL") -WEB_URL = get_env_variable("WEB_URL") -ENVRIMENT = get_env_variable("PETERCAT_ENV", "development") - client = TestClient(app) def test_health_checker(): response = client.get("/api/health_checker") assert response.status_code == 200 assert response.json() == { - 'ENVRIMENT': ENVRIMENT, + 'ENVIRONMENT': ENVIRONMENT, 'API_URL': API_URL, 'CALLBACK_URL': f'{API_URL}/api/auth/callback', 'WEB_URL': WEB_URL, From 5690cf7ff2fcd85a96739e4950f6114062ff69c8 Mon Sep 17 00:00:00 2001 From: MadratJerry Date: Mon, 9 Dec 2024 07:44:06 +0800 Subject: [PATCH 2/3] fix: fix dup retrieval --- petercat_utils/rag_helper/retrieval.py | 27 ++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/petercat_utils/rag_helper/retrieval.py b/petercat_utils/rag_helper/retrieval.py index 5a3207dd..1cb889e1 100644 --- a/petercat_utils/rag_helper/retrieval.py +++ b/petercat_utils/rag_helper/retrieval.py @@ -1,7 +1,6 @@ import json from typing import Any, Dict - from langchain_community.vectorstores import SupabaseVectorStore from langchain_openai import OpenAIEmbeddings @@ -9,7 +8,6 @@ from ..data_class import GitDocConfig, RAGGitDocConfig, S3Config from ..db.client.supabase import get_client - TABLE_NAME = "rag_docs" QUERY_NAME = "match_embedding_docs" CHUNK_SIZE = 2000 @@ -118,15 +116,16 @@ def add_knowledge_by_doc(config: RAGGitDocConfig): supabase = get_client() is_doc_added_query = ( supabase.table(TABLE_NAME) - .select("id, repo_name, commit_id, file_path") + .select("id") .eq("repo_name", config.repo_name) .eq("commit_id", loader.commit_id) .eq("file_path", config.file_path) + .limit(1) .execute() ) if not is_doc_added_query.data: is_doc_equal_query = ( - supabase.table(TABLE_NAME).select("*").eq("file_sha", loader.file_sha) + supabase.table(TABLE_NAME).select("id").eq("file_sha", loader.file_sha).limit(1) ).execute() if not is_doc_equal_query.data: # If there is no file with the same file_sha, perform embedding. @@ -139,6 +138,18 @@ def add_knowledge_by_doc(config: RAGGitDocConfig): ) return store else: + # Prioritize obtaining the minimal set of records to avoid overlapping with the original records. + minimum_repeat_result = supabase.rpc('count_rag_docs_by_sha', {'file_sha_input': loader.file_sha}).execute() + target_filter = minimum_repeat_result.data[0] + # Copy the minimal set + insert_docs = ( + supabase.table(TABLE_NAME) + .select("*") + .eq("repo_name", target_filter['repo_name']) + .eq("file_path", target_filter['file_path']) + .eq("file_sha", target_filter['file_sha']) + .execute() + ) new_commit_list = [ { **{k: v for k, v in item.items() if k != "id"}, @@ -146,7 +157,7 @@ def add_knowledge_by_doc(config: RAGGitDocConfig): "commit_id": loader.commit_id, "file_path": config.file_path, } - for item in is_doc_equal_query.data + for item in insert_docs.data ] insert_result = supabase.table(TABLE_NAME).insert(new_commit_list).execute() return insert_result @@ -169,9 +180,9 @@ def reload_knowledge(config: RAGGitDocConfig): def search_knowledge( - query: str, - repo_name: str, - meta_filter: Dict[str, Any] = {}, + query: str, + repo_name: str, + meta_filter: Dict[str, Any] = {}, ): retriever = init_retriever( {"filter": {"metadata": meta_filter, "repo_name": repo_name}} From 5872ecc4c18331b3df7cea24a9920715b846371c Mon Sep 17 00:00:00 2001 From: MadratJerry Date: Sun, 15 Dec 2024 00:42:51 +0800 Subject: [PATCH 3/3] feat: update petercat-utils --- pyproject.toml | 2 +- subscriber/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ef116755..6790e442 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "petercat_utils" -version = "0.1.39" +version = "0.1.40" description = "" authors = ["raoha.rh "] readme = "README.md" diff --git a/subscriber/requirements.txt b/subscriber/requirements.txt index 23c995c0..3ff51c19 100644 --- a/subscriber/requirements.txt +++ b/subscriber/requirements.txt @@ -1 +1 @@ -petercat_utils>=0.1.39 +petercat_utils>=0.1.40