
feat: add knowledge from s3 #88

Merged 16 commits on Apr 16, 2024
5 changes: 4 additions & 1 deletion .gitignore
@@ -1,11 +1,13 @@
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
__pycache__/
-.pyc
+*.pyc
# dependencies
/node_modules
/.pnp
.pnp.js

+server/temp/

# testing
/coverage

@@ -39,3 +41,4 @@ next-env.d.ts
.yarn
/server/.aws-sam/*
.aws-sam/*

5 changes: 1 addition & 4 deletions app/api/chat/retrieval/route.ts
@@ -7,10 +7,7 @@ import { ChatOpenAI } from 'langchain/chat_models/openai';
import { PromptTemplate } from 'langchain/prompts';
import { SupabaseVectorStore } from 'langchain/vectorstores/supabase';
import { Document } from 'langchain/document';
-import {
-  RunnableSequence,
-  RunnablePassthrough,
-} from 'langchain/schema/runnable';
+import { RunnableSequence } from 'langchain/schema/runnable';
import {
  BytesOutputParser,
  StringOutputParser,
10 changes: 10 additions & 0 deletions server/Dockerfile.aws.lambda
@@ -8,4 +8,14 @@ COPY . ${LAMBDA_TASK_ROOT}
COPY requirements.txt .
RUN pip3 install -r requirements.txt --target "${LAMBDA_TASK_ROOT}" -U --no-cache-dir

# Install NLTK in the system path so the nltk.downloader module can be run
RUN pip install nltk
# Create the directory that will hold NLTK_DATA
RUN mkdir -p /opt/nltk_data

# Download the required NLTK datasets into the build directory
RUN python -W ignore -m nltk.downloader punkt -d /opt/nltk_data
RUN python -W ignore -m nltk.downloader stopwords -d /opt/nltk_data
RUN python -W ignore -m nltk.downloader averaged_perceptron_tagger -d /opt/nltk_data

CMD ["python", "main.py"]
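
The three datasets above (punkt, stopwords, averaged_perceptron_tagger) are what the unstructured-based loaders need at runtime. For the Lambda process to find them, /opt/nltk_data has to be on NLTK's search path — a minimal sketch of that wiring, assuming it is done in Python at startup rather than via an `ENV NLTK_DATA` instruction (the diff does not show which approach the server takes):

```python
# Hypothetical startup check; the diff does not show how the server points
# NLTK at /opt/nltk_data, so treat this wiring as an assumption.
import nltk

nltk.data.path.append("/opt/nltk_data")

# Raises LookupError if the punkt tokenizer was not baked into the image.
nltk.data.find("tokenizers/punkt")
```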
7 changes: 6 additions & 1 deletion server/data_class.py
@@ -1,3 +1,4 @@
+from typing import Optional
from pydantic import BaseModel


@@ -17,4 +18,8 @@ class ChatData(BaseModel):
class ExecuteMessage(BaseModel):
    type: str
    repo: str
-    path: str
+    path: str
+
+class S3Config(BaseModel):
+    s3_bucket: str
+    file_path: Optional[str] = None
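
S3Config gives the new endpoint a typed request body: the bucket is required, the prefix is optional. A quick sketch of constructing it — the bucket name and prefix here are made-up example values:

```python
# Example values only; "my-docs-bucket" and "docs/" are hypothetical.
from data_class import S3Config

config = S3Config(s3_bucket="my-docs-bucket", file_path="docs/")
print(config.model_dump())  # {'s3_bucket': 'my-docs-bucket', 'file_path': 'docs/'}

# file_path defaults to None, so a whole-bucket ingest can omit it.
whole_bucket = S3Config(s3_bucket="my-docs-bucket")
```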
16 changes: 4 additions & 12 deletions server/main.py
@@ -1,5 +1,5 @@
import os
-from rag import retrieval

import uvicorn
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
@@ -11,7 +11,7 @@
from data_class import ChatData

# Import fastapi routers
-from routers import health_checker, github
+from routers import health_checker, github, rag

open_api_key = get_env_variable("OPENAI_API_KEY")
is_dev = bool(get_env_variable("IS_DEV"))
@@ -33,25 +33,17 @@

app.include_router(health_checker.router)
app.include_router(github.router)
+app.include_router(rag.router)

@app.post("/api/chat/stream", response_class=StreamingResponse)
def run_agent_chat(input_data: ChatData):
    result = stream.agent_chat(input_data, open_api_key)
    return StreamingResponse(result, media_type="text/event-stream")

-@app.post("/api/rag/add_knowledge")
-def add_knowledge():
-    data=retrieval.add_knowledge()
-    return data
-
-@app.post("/api/rag/search_knowledge")
-def search_knowledge(query: str):
-    data=retrieval.search_knowledge(query)
-    return data
-
if __name__ == "__main__":
    if is_dev:
        uvicorn.run("main:app", host="0.0.0.0", port=int(os.environ.get("PORT", "8080")), reload=True)
    else:
-        uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))
+        uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))
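
With the RAG endpoints moved into their own router, main.py keeps only router wiring and the chat stream. For reference, a minimal sketch of consuming that stream with httpx, assuming a local server on port 8080; the payload shape is a guess, since ChatData's fields sit outside this diff:

```python
# Hypothetical client; the payload must match ChatData in data_class.py,
# whose fields are not shown in this diff.
import httpx

payload = {"messages": [{"role": "user", "content": "Hello"}]}  # assumed shape
with httpx.stream("POST", "http://localhost:8080/api/chat/stream", json=payload) as resp:
    for chunk in resp.iter_text():
        print(chunk, end="", flush=True)
```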
60 changes: 31 additions & 29 deletions server/rag/retrieval.py
@@ -1,17 +1,22 @@
import os
import json
from langchain_community.document_loaders import TextLoader
import boto3
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.vectorstores import SupabaseVectorStore
from db.supabase.client import get_client
from data_class import S3Config
from uilts.env import get_env_variable
from langchain_community.document_loaders import S3DirectoryLoader


supabase_url = get_env_variable("SUPABASE_URL")
supabase_key = get_env_variable("SUPABASE_SERVICE_KEY")


table_name="antd_knowledge"
query_name="match_antd_knowledge"
-chunk_size=500
+chunk_size=2000


def convert_document_to_dict(document):
    return {
@@ -32,36 +37,33 @@ def init_retriever():

    return db.as_retriever()

-def add_knowledge():
-    current_dir = os.path.dirname(os.path.abspath(__file__))
-    target_file_path = os.path.join(current_dir, "../docs/test.md")
-    loader = TextLoader(target_file_path)
-    documents = loader.load()
-    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-    docs = text_splitter.split_documents(documents)
-    embeddings = OpenAIEmbeddings()

+def add_knowledge(config: S3Config):
    try:
-        SupabaseVectorStore.from_documents(
-            docs,
-            embeddings,
-            client=supabase,
-            table_name=table_name,
-            query_name=query_name,
-            chunk_size=chunk_size,
-        )
-        return json.dumps({
-            "success": True,
-            "message": "Knowledge added successfully!"
-        })
+        loader = S3DirectoryLoader(config.s3_bucket, prefix=config.file_path)
+        documents = loader.load()
+        text_splitter = CharacterTextSplitter(chunk_size=2000, chunk_overlap=0)
+        docs = text_splitter.split_documents(documents)
+        embeddings = OpenAIEmbeddings()
+        SupabaseVectorStore.from_documents(
+            docs,
+            embeddings,
+            client=get_client(),
+            table_name=table_name,
+            query_name=query_name,
+            chunk_size=chunk_size,
+        )
+        return json.dumps({
+            "success": True,
+            "message": "Knowledge added successfully!",
+            "docs_len": len(documents)
+        })
    except Exception as e:
        return json.dumps({
            "success": False,
            "message": str(e)
        })



def search_knowledge(query: str):
    retriever = init_retriever()
    docs = retriever.get_relevant_documents(query)
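
Taken together, the new pipeline is: S3 objects → unstructured parsing → 2000-character chunks → OpenAI embeddings → the antd_knowledge table in Supabase. A minimal sketch of driving both halves from a Python shell, assuming AWS, OpenAI, and Supabase credentials are all present in the environment (bucket and query are example values):

```python
# Example values only; "my-docs-bucket" is hypothetical, and a real run needs
# AWS credentials plus OPENAI_API_KEY, SUPABASE_URL, and SUPABASE_SERVICE_KEY.
from data_class import S3Config
from rag import retrieval

print(retrieval.add_knowledge(S3Config(s3_bucket="my-docs-bucket", file_path="docs/")))
print(retrieval.search_knowledge("How does Button handle loading state?"))
```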
3 changes: 2 additions & 1 deletion server/requirements.txt
@@ -15,4 +15,5 @@ load_dotenv
supabase
boto3>=1.34.84
pyjwt>=2.4.0
-pydantic>=2.7.0
+pydantic>=2.7.0
+unstructured[md]
2 changes: 1 addition & 1 deletion server/routers/github.py
@@ -18,7 +18,7 @@

router = APIRouter(
    prefix="/api/github",
-    tags=["health_checkers"],
+    tags=["github"],
    responses={404: {"description": "Not found"}},
)

2 changes: 1 addition & 1 deletion server/routers/health_checker.py
@@ -8,4 +8,4 @@

@router.get("/health_checker")
def health_checker():
-    return {"Hello": "World"}
+    return { "Hello": "World" }
20 changes: 20 additions & 0 deletions server/routers/rag.py
@@ -0,0 +1,20 @@
from fastapi import APIRouter
from rag import retrieval
from data_class import S3Config

router = APIRouter(
    prefix="/api",
    tags=["rag"],
    responses={404: {"description": "Not found"}},
)


@router.post("/rag/add_knowledge")
def add_knowledge(config: S3Config):
    data=retrieval.add_knowledge(config)
    return data

@router.post("/rag/search_knowledge")
def search_knowledge(query: str):
    data=retrieval.search_knowledge(query)
    return data
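
A quick way to exercise the new routes is FastAPI's TestClient. A minimal smoke-test sketch, under the same credential assumptions as above:

```python
# Hypothetical smoke test; "my-docs-bucket" is an example value, and a real
# run needs AWS, OpenAI, and Supabase credentials in the environment.
from fastapi.testclient import TestClient

from main import app

client = TestClient(app)

resp = client.post(
    "/api/rag/add_knowledge",
    json={"s3_bucket": "my-docs-bucket", "file_path": "docs/"},
)
print(resp.status_code, resp.json())

# query is a bare str parameter, so FastAPI reads it from the query string.
resp = client.post("/api/rag/search_knowledge", params={"query": "Button API"})
print(resp.status_code, resp.json())
```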
7 changes: 3 additions & 4 deletions server/tools/issue.py
@@ -2,7 +2,6 @@
from typing import Optional
from github import Github
from langchain.tools import tool
-from uilts.env import get_env_variable

DEFAULT_REPO_NAME = "ant-design/ant-design"

@@ -84,15 +83,15 @@ def search_issues(
    :param state: The state of the issue, e.g: open, closed, all
    """
    try:
-        search_query = f'{keyword} in:title,body,comments repo:{repo_name}'
+        search_query = f"{keyword} in:title,body,comments repo:{repo_name}"
        # Retrieve a list of open issues from the repository
        issues = g.search_issues(query=search_query, sort=sort, order=order)[:max_num]
        print(f"issues: {issues}")

        issues_list = [
            {
-                'issue_name': f'Issue #{issue.number} - {issue.title}',
-                'issue_url': issue.html_url
+                "issue_name": f"Issue #{issue.number} - {issue.title}",
+                "issue_url": issue.html_url
            }
            for issue in issues
        ]
5 changes: 2 additions & 3 deletions server/tools/sourcecode.py
@@ -2,8 +2,6 @@
from github import Github
from github.ContentFile import ContentFile
from langchain.tools import tool
-from uilts.env import get_env_variable


DEFAULT_REPO_NAME = "ant-design/ant-design"

@@ -29,7 +27,8 @@ def search_code(

        # Perform the search for code files containing the keyword
        code_files = g.search_code(query=query)[:max_num]
-        return code_files
+
+        return code_files
    except Exception as e:
        print(f"An error occurred: {e}")
        return None