Merge pull request #88 from ant-xuexiao/feat_git_rag
feat: add knowledge from s3
xingwanying authored Apr 16, 2024
2 parents 23497f4 + 747af83 commit 867f7f5
Showing 12 changed files with 85 additions and 57 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -1,11 +1,13 @@
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
__pycache__/
.pyc
*.pyc
# dependencies
/node_modules
/.pnp
.pnp.js

server/temp/

# testing
/coverage

@@ -39,3 +41,4 @@ next-env.d.ts
.yarn
/server/.aws-sam/*
.aws-sam/*

5 changes: 1 addition & 4 deletions app/api/chat/retrieval/route.ts
@@ -7,10 +7,7 @@ import { ChatOpenAI } from 'langchain/chat_models/openai';
import { PromptTemplate } from 'langchain/prompts';
import { SupabaseVectorStore } from 'langchain/vectorstores/supabase';
import { Document } from 'langchain/document';
import {
  RunnableSequence,
  RunnablePassthrough,
} from 'langchain/schema/runnable';
import { RunnableSequence } from 'langchain/schema/runnable';
import {
  BytesOutputParser,
  StringOutputParser,
10 changes: 10 additions & 0 deletions server/Dockerfile.aws.lambda
@@ -8,4 +8,14 @@ COPY . ${LAMBDA_TASK_ROOT}
COPY requirements.txt .
RUN pip3 install -r requirements.txt --target "${LAMBDA_TASK_ROOT}" -U --no-cache-dir

# Setup NLTK again in system path to execute nltk.downloader
RUN pip install nltk
# Setup directory for NLTK_DATA
RUN mkdir -p /opt/nltk_data

# Download NLTK_DATA to build directory
RUN python -W ignore -m nltk.downloader punkt -d /opt/nltk_data
RUN python -W ignore -m nltk.downloader stopwords -d /opt/nltk_data
RUN python -W ignore -m nltk.downloader averaged_perceptron_tagger -d /opt/nltk_data

CMD ["python", "main.py"]
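
The three corpora are downloaded into /opt/nltk_data at build time because the unstructured parsers behind the new S3 loader need them at runtime, and a Lambda container cannot reliably fetch them on demand. A minimal sketch of how application code could point NLTK at that baked-in directory (the append call is an assumption for illustration, not part of this commit):

    import nltk

    # Make the corpora baked into /opt/nltk_data visible to NLTK at runtime;
    # setting the NLTK_DATA environment variable in the Dockerfile would work too.
    nltk.data.path.append("/opt/nltk_data")

    # Raises LookupError if the punkt tokenizer is missing from the image.
    nltk.data.find("tokenizers/punkt")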
7 changes: 6 additions & 1 deletion server/data_class.py
@@ -1,3 +1,4 @@
from typing import Optional
from pydantic import BaseModel


@@ -17,4 +18,8 @@ class ChatData(BaseModel):
class ExecuteMessage(BaseModel):
    type: str
    repo: str
    path: str

class S3Config(BaseModel):
    s3_bucket: str
    file_path: Optional[str] = None
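
For orientation, a minimal sketch of constructing the new model; the bucket name and prefix are illustrative, not from the commit:

    from data_class import S3Config

    # file_path is optional and acts as a key prefix within the bucket.
    config = S3Config(s3_bucket="my-docs-bucket", file_path="antd/")
    print(config.model_dump())  # {'s3_bucket': 'my-docs-bucket', 'file_path': 'antd/'}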
16 changes: 4 additions & 12 deletions server/main.py
@@ -1,5 +1,5 @@
import os
from rag import retrieval

import uvicorn
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
@@ -11,7 +11,7 @@
from data_class import ChatData

# Import fastapi routers
from routers import health_checker, github
from routers import health_checker, github, rag

open_api_key = get_env_variable("OPENAI_API_KEY")
is_dev = bool(get_env_variable("IS_DEV"))
@@ -33,25 +33,17 @@

app.include_router(health_checker.router)
app.include_router(github.router)

app.include_router(rag.router)

@app.post("/api/chat/stream", response_class=StreamingResponse)
def run_agent_chat(input_data: ChatData):
result = stream.agent_chat(input_data, open_api_key)
return StreamingResponse(result, media_type="text/event-stream")

@app.post("/api/rag/add_knowledge")
def add_knowledge():
data=retrieval.add_knowledge()
return data

@app.post("/api/rag/search_knowledge")
def search_knowledge(query: str):
data=retrieval.search_knowledge(query)
return data

if __name__ == "__main__":
if is_dev:
uvicorn.run("main:app", host="0.0.0.0", port=int(os.environ.get("PORT", "8080")), reload=True)
else:
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))
60 changes: 31 additions & 29 deletions server/rag/retrieval.py
@@ -1,17 +1,22 @@
import os
import json
from langchain_community.document_loaders import TextLoader
import boto3
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.vectorstores import SupabaseVectorStore
from db.supabase.client import get_client
from data_class import S3Config
from uilts.env import get_env_variable
from langchain_community.document_loaders import S3DirectoryLoader


supabase_url = get_env_variable("SUPABASE_URL")
supabase_key = get_env_variable("SUPABASE_SERVICE_KEY")


table_name="antd_knowledge"
query_name="match_antd_knowledge"
chunk_size=500
chunk_size=2000


def convert_document_to_dict(document):
    return {
@@ -32,36 +37,33 @@ def init_retriever():

    return db.as_retriever()

def add_knowledge():
    current_dir = os.path.dirname(os.path.abspath(__file__))
    target_file_path = os.path.join(current_dir, "../docs/test.md")
    loader = TextLoader(target_file_path)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(documents)
    embeddings = OpenAIEmbeddings()

def add_knowledge(config: S3Config):
    try:
        SupabaseVectorStore.from_documents(
            docs,
            embeddings,
            client=supabase,
            table_name=table_name,
            query_name=query_name,
            chunk_size=chunk_size,
        )
        return json.dumps({
            "success": True,
            "message": "Knowledge added successfully!"
        })
        loader = S3DirectoryLoader(config.s3_bucket, prefix=config.file_path)
        documents = loader.load()
        text_splitter = CharacterTextSplitter(chunk_size=2000, chunk_overlap=0)
        docs = text_splitter.split_documents(documents)
        embeddings = OpenAIEmbeddings()
        SupabaseVectorStore.from_documents(
            docs,
            embeddings,
            client=get_client(),
            table_name=table_name,
            query_name=query_name,
            chunk_size=chunk_size,
        )
        return json.dumps({
            "success": True,
            "message": "Knowledge added successfully!",
            "docs_len": len(documents)
        })
    except Exception as e:
        return json.dumps({
            "success": False,
            "message": str(e)
        })



def search_knowledge(query: str):
    retriever = init_retriever()
    docs = retriever.get_relevant_documents(query)
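
For context on the rewrite above: S3DirectoryLoader fetches every object under the given bucket and prefix and parses it with the unstructured library, which is why unstructured[md] is added to requirements.txt below. A minimal standalone sketch, with an illustrative bucket and the standard boto3 credential chain assumed:

    from langchain_community.document_loaders import S3DirectoryLoader

    # Illustrative values; boto3 resolves credentials from env vars,
    # ~/.aws/credentials, or an attached IAM role.
    loader = S3DirectoryLoader("my-docs-bucket", prefix="antd/")
    documents = loader.load()  # one Document per successfully parsed object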
3 changes: 2 additions & 1 deletion server/requirements.txt
@@ -15,4 +15,5 @@ load_dotenv
supabase
boto3>=1.34.84
pyjwt>=2.4.0
pydantic>=2.7.0
unstructured[md]
2 changes: 1 addition & 1 deletion server/routers/github.py
@@ -18,7 +18,7 @@

router = APIRouter(
prefix="/api/github",
tags=["health_checkers"],
tags=["github"],
responses={404: {"description": "Not found"}},
)

2 changes: 1 addition & 1 deletion server/routers/health_checker.py
@@ -8,4 +8,4 @@

@router.get("/health_checker")
def health_checker():
return {"Hello": "World"}
return { "Hello": "World" }
20 changes: 20 additions & 0 deletions server/routers/rag.py
@@ -0,0 +1,20 @@
from fastapi import APIRouter
from rag import retrieval
from data_class import S3Config

router = APIRouter(
prefix="/api",
tags=["rag"],
responses={404: {"description": "Not found"}},
)


@router.post("/rag/add_knowledge")
def add_knowledge(config: S3Config):
    data=retrieval.add_knowledge(config)
    return data

@router.post("/rag/search_knowledge")
def search_knowledge(query: str):
    data=retrieval.search_knowledge(query)
    return data
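
A usage sketch for the two new endpoints; host, port, and payload values are illustrative and assume the FastAPI app from main.py is running locally:

    import requests

    # The JSON body maps onto S3Config.
    resp = requests.post(
        "http://localhost:8080/api/rag/add_knowledge",
        json={"s3_bucket": "my-docs-bucket", "file_path": "antd/"},
    )
    print(resp.json())

    # query is a scalar parameter, so FastAPI reads it from the query string.
    resp = requests.post(
        "http://localhost:8080/api/rag/search_knowledge",
        params={"query": "how do I customize a theme?"},
    )
    print(resp.json())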
7 changes: 3 additions & 4 deletions server/tools/issue.py
@@ -2,7 +2,6 @@
from typing import Optional
from github import Github
from langchain.tools import tool
from uilts.env import get_env_variable

DEFAULT_REPO_NAME = "ant-design/ant-design"

@@ -84,15 +83,15 @@ def search_issues(
    :param state: The state of the issue, e.g: open, closed, all
    """
    try:
        search_query = f'{keyword} in:title,body,comments repo:{repo_name}'
        search_query = f"{keyword} in:title,body,comments repo:{repo_name}"
        # Retrieve a list of open issues from the repository
        issues = g.search_issues(query=search_query, sort=sort, order=order)[:max_num]
        print(f"issues: {issues}")

        issues_list = [
            {
                'issue_name': f'Issue #{issue.number} - {issue.title}',
                'issue_url': issue.html_url
                "issue_name": f"Issue #{issue.number} - {issue.title}",
                "issue_url": issue.html_url
            }
            for issue in issues
        ]
5 changes: 2 additions & 3 deletions server/tools/sourcecode.py
@@ -2,8 +2,6 @@
from github import Github
from github.ContentFile import ContentFile
from langchain.tools import tool
from uilts.env import get_env_variable


DEFAULT_REPO_NAME = "ant-design/ant-design"

@@ -29,7 +27,8 @@ def search_code(

        # Perform the search for code files containing the keyword
        code_files = g.search_code(query=query)[:max_num]
        return code_files

        return code_files
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
