Commit d6396de

feat: merge

alichengyue committed Apr 16, 2024
2 parents aa5e39d + fd501c7
Showing 12 changed files with 88 additions and 56 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -1,11 +1,13 @@
 # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
 __pycache__/
-.pyc
+*.pyc
 # dependencies
 /node_modules
 /.pnp
 .pnp.js
 
+server/temp/
+
 # testing
 /coverage
@@ -39,3 +41,4 @@ next-env.d.ts
 .yarn
 /server/.aws-sam/*
 .aws-sam/*
+
5 changes: 1 addition & 4 deletions app/api/chat/retrieval/route.ts
@@ -7,10 +7,7 @@ import { ChatOpenAI } from 'langchain/chat_models/openai';
 import { PromptTemplate } from 'langchain/prompts';
 import { SupabaseVectorStore } from 'langchain/vectorstores/supabase';
 import { Document } from 'langchain/document';
-import {
-  RunnableSequence,
-  RunnablePassthrough,
-} from 'langchain/schema/runnable';
+import { RunnableSequence } from 'langchain/schema/runnable';
 import {
   BytesOutputParser,
   StringOutputParser,
11 changes: 10 additions & 1 deletion server/Dockerfile.aws.lambda
@@ -1,11 +1,20 @@
 FROM public.ecr.aws/docker/library/python:3.12.0-slim-bullseye
 
 # Copy aws-lambda-adapter for streaming responses
 COPY --from=public.ecr.aws/awsguru/aws-lambda-adapter:0.8.1 /lambda-adapter /opt/extensions/lambda-adapter
 
+# Copy nltk_lambda_layer for using nltk in lambda
+COPY --from=public.ecr.aws/m5s2b0d4/nltk_lambda_layer:latest /nltk_data /opt/nltk_data
+
+
 # Copy function code
 COPY . ${LAMBDA_TASK_ROOT}
 # from your project folder.
 COPY requirements.txt .
 RUN pip3 install -r requirements.txt --target "${LAMBDA_TASK_ROOT}" -U --no-cache-dir
 
-CMD ["python", "main.py"]
+# Set NLTK_DATA to load nltk_data
+ENV NLTK_DATA=/opt/nltk_data
+
+
+CMD ["python", "main.py"]
7 changes: 6 additions & 1 deletion server/data_class.py
@@ -1,3 +1,4 @@
+from typing import Optional
 from pydantic import BaseModel
 
@@ -17,4 +18,8 @@ class ChatData(BaseModel):
 class ExecuteMessage(BaseModel):
     type: str
     repo: str
-    path: str
+    path: str
+
+class S3Config(BaseModel):
+    s3_bucket: str
+    file_path: Optional[str] = None
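For illustration, a minimal sketch of how the new model behaves; the bucket name and prefix are placeholders. FastAPI validates the JSON body of the new rag endpoints against this model, and file_path may be omitted:

from data_class import S3Config

cfg = S3Config(s3_bucket="my-knowledge-bucket", file_path="docs/")
print(cfg.file_path)  # "docs/"

# file_path is Optional, so a bucket-only payload validates too.
print(S3Config(s3_bucket="my-knowledge-bucket").file_path)  # None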
17 changes: 7 additions & 10 deletions server/main.py
@@ -1,5 +1,5 @@
 import os
-from rag import retrieval
+
 import uvicorn
 from fastapi import FastAPI
 from fastapi.responses import StreamingResponse
@@ -11,7 +11,11 @@
 from data_class import ChatData
 
 # Import fastapi routers
+<<<<<<< HEAD
 from routers import bot, health_checker, github
+=======
+from routers import bot, health_checker, github, rag
+>>>>>>> new-branch-name
 
 open_api_key = get_env_variable("OPENAI_API_KEY")
 is_dev = bool(get_env_variable("IS_DEV"))
@@ -33,6 +37,7 @@
 
 app.include_router(health_checker.router)
 app.include_router(github.router)
+app.include_router(rag.router)
 app.include_router(bot.router)
 
 
@@ -41,18 +46,10 @@ def run_agent_chat(input_data: ChatData):
     result = stream.agent_chat(input_data, open_api_key)
     return StreamingResponse(result, media_type="text/event-stream")
 
-@app.post("/api/rag/add_knowledge")
-def add_knowledge():
-    data=retrieval.add_knowledge()
-    return data
-
-@app.post("/api/rag/search_knowledge")
-def search_knowledge(query: str):
-    data=retrieval.search_knowledge(query)
-    return data
-
 if __name__ == "__main__":
     if is_dev:
         uvicorn.run("main:app", host="0.0.0.0", port=int(os.environ.get("PORT", "8080")), reload=True)
     else:
-        uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))
+        uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))
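The net effect of this hunk is the standard router-extraction pattern: the two /api/rag endpoints leave main.py and move into a dedicated APIRouter that main.py only mounts. A minimal sketch of the pattern, with a hypothetical route for illustration:

from fastapi import APIRouter, FastAPI

# routers/example.py: endpoints live on a router, not on the app.
router = APIRouter(prefix="/api", tags=["example"])

@router.post("/example/echo")  # hypothetical route, not from this repo
def echo(message: str):
    return {"message": message}

# main.py: the app only mounts the router.
app = FastAPI()
app.include_router(router)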
60 changes: 31 additions & 29 deletions server/rag/retrieval.py
@@ -1,17 +1,22 @@
 import os
 import json
-from langchain_community.document_loaders import TextLoader
+import boto3
 from langchain_openai import OpenAIEmbeddings
 from langchain_text_splitters import CharacterTextSplitter
 from langchain_community.vectorstores import SupabaseVectorStore
 from db.supabase.client import get_client
+from data_class import S3Config
 from uilts.env import get_env_variable
+from langchain_community.document_loaders import S3DirectoryLoader
 
 
 supabase_url = get_env_variable("SUPABASE_URL")
 supabase_key = get_env_variable("SUPABASE_SERVICE_KEY")
 
 
 table_name="antd_knowledge"
 query_name="match_antd_knowledge"
-chunk_size=500
+chunk_size=2000
 
 
 def convert_document_to_dict(document):
     return {
@@ -32,36 +37,33 @@ def init_retriever():
 
     return db.as_retriever()
 
-def add_knowledge():
-    current_dir = os.path.dirname(os.path.abspath(__file__))
-    target_file_path = os.path.join(current_dir, "../docs/test.md")
-    loader = TextLoader(target_file_path)
-    documents = loader.load()
-    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-    docs = text_splitter.split_documents(documents)
-    embeddings = OpenAIEmbeddings()
-
+def add_knowledge(config: S3Config):
     try:
-        SupabaseVectorStore.from_documents(
-            docs,
-            embeddings,
-            client=supabase,
-            table_name=table_name,
-            query_name=query_name,
-            chunk_size=chunk_size,
-        )
-        return json.dumps({
-            "success": True,
-            "message": "Knowledge added successfully!"
-        })
+        loader = S3DirectoryLoader(config.s3_bucket, prefix=config.file_path)
+        documents = loader.load()
+        text_splitter = CharacterTextSplitter(chunk_size=2000, chunk_overlap=0)
+        docs = text_splitter.split_documents(documents)
+        embeddings = OpenAIEmbeddings()
+        SupabaseVectorStore.from_documents(
+            docs,
+            embeddings,
+            client=get_client(),
+            table_name=table_name,
+            query_name=query_name,
+            chunk_size=chunk_size,
+        )
+        return json.dumps({
+            "success": True,
+            "message": "Knowledge added successfully!",
+            "docs_len": len(documents)
+        })
     except Exception as e:
         return json.dumps({
            "success": False,
            "message": str(e)
         })
 
 
 
 def search_knowledge(query: str):
     retriever = init_retriever()
     docs = retriever.get_relevant_documents(query)
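The rewritten add_knowledge ingests a whole S3 bucket (optionally narrowed by a prefix) instead of a single local test file. A hedged sketch of calling it directly, assuming AWS credentials and the Supabase env vars are already configured; bucket and prefix are placeholders:

from data_class import S3Config
from rag import retrieval

# Hypothetical direct call; returns a JSON string per the code above.
result = retrieval.add_knowledge(
    S3Config(s3_bucket="my-knowledge-bucket", file_path="docs/")
)
print(result)  # e.g. '{"success": true, ..., "docs_len": 12}'

Note that S3DirectoryLoader delegates file parsing to the unstructured package, which is why requirements.txt below gains unstructured[md].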
3 changes: 2 additions & 1 deletion server/requirements.txt
@@ -15,4 +15,5 @@ load_dotenv
 supabase
 boto3>=1.34.84
 pyjwt>=2.4.0
-pydantic>=2.7.0
+pydantic>=2.7.0
+unstructured[md]
2 changes: 1 addition & 1 deletion server/routers/github.py
@@ -18,7 +18,7 @@
 
 router = APIRouter(
     prefix="/api/github",
-    tags=["health_checkers"],
+    tags=["github"],
     responses={404: {"description": "Not found"}},
 )
2 changes: 1 addition & 1 deletion server/routers/health_checker.py
@@ -8,4 +8,4 @@
 
 @router.get("/health_checker")
 def health_checker():
-    return {"Hello": "World"}
+    return { "Hello": "World" }
20 changes: 20 additions & 0 deletions server/routers/rag.py
@@ -0,0 +1,20 @@
+from fastapi import APIRouter
+from rag import retrieval
+from data_class import S3Config
+
+router = APIRouter(
+    prefix="/api",
+    tags=["rag"],
+    responses={404: {"description": "Not found"}},
+)
+
+
+@router.post("/rag/add_knowledge")
+def add_knowledge(config: S3Config):
+    data=retrieval.add_knowledge(config)
+    return data
+
+@router.post("/rag/search_knowledge")
+def search_knowledge(query: str):
+    data=retrieval.search_knowledge(query)
+    return data
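A client-side sketch of the two new routes; the host and port assume main.py's local default, and the bucket name is a placeholder. config arrives as the JSON body, while query, declared as a bare str, is read from the query string:

import requests

base = "http://localhost:8080"  # assumed local dev server

# Body is validated against S3Config.
resp = requests.post(
    f"{base}/api/rag/add_knowledge",
    json={"s3_bucket": "my-knowledge-bucket", "file_path": "docs/"},
)
print(resp.json())

# `query` is a query-string parameter, not a body field.
resp = requests.post(
    f"{base}/api/rag/search_knowledge",
    params={"query": "how to customize the antd theme"},
)
print(resp.json())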
7 changes: 3 additions & 4 deletions server/tools/issue.py
@@ -2,7 +2,6 @@
 from typing import Optional
 from github import Github
 from langchain.tools import tool
-from uilts.env import get_env_variable
 
 DEFAULT_REPO_NAME = "ant-design/ant-design"
 
@@ -84,15 +83,15 @@ def search_issues(
     :param state: The state of the issue, e.g: open, closed, all
     """
     try:
-        search_query = f'{keyword} in:title,body,comments repo:{repo_name}'
+        search_query = f"{keyword} in:title,body,comments repo:{repo_name}"
         # Retrieve a list of open issues from the repository
         issues = g.search_issues(query=search_query, sort=sort, order=order)[:max_num]
         print(f"issues: {issues}")
 
         issues_list = [
             {
-                'issue_name': f'Issue #{issue.number} - {issue.title}',
-                'issue_url': issue.html_url
+                "issue_name": f"Issue #{issue.number} - {issue.title}",
+                "issue_url": issue.html_url
             }
             for issue in issues
         ]
5 changes: 2 additions & 3 deletions server/tools/sourcecode.py
@@ -2,8 +2,6 @@
 from github import Github
 from github.ContentFile import ContentFile
 from langchain.tools import tool
-from uilts.env import get_env_variable
-
 
 DEFAULT_REPO_NAME = "ant-design/ant-design"
 
@@ -29,7 +27,8 @@
 
         # Perform the search for code files containing the keyword
         code_files = g.search_code(query=query)[:max_num]
-        return code_files
+
+        return code_files
     except Exception as e:
         print(f"An error occurred: {e}")
         return None
