Cleanup open source RAG notebook #1179

Merged: 1 commit, Oct 27, 2023
123 changes: 42 additions & 81 deletions notebooks/llms/langchain/readthedocs_rag_milvus_client.ipynb
@@ -21,7 +21,7 @@
"<img src=\"../../../images/rag_image.png\" width=\"80%\"/>\n",
"</div>\n",
"\n",
"Many RAG demos use OpenAI for the Embedding Model and ChatGPT for the Generative AI model. In this notebook, we will demo a fully open source RAG stack - open source embedding model hosted on HuggingFace, Milvus, and an open source LLM.\n",
"Many RAG demos use OpenAI for the Embedding Model and ChatGPT for the Generative AI model. In this notebook, we will demo a fully open source RAG stack - open source embedding model available on HuggingFace, Milvus, and an open source LLM.\n",
"\n",
"Let's get started!"
]
@@ -34,7 +34,7 @@
"outputs": [],
"source": [
"# For colab install these libraries in this order:\n",
"# !pip install milvus, pymilvus, langchain, torch, transformers, python-dotenv\n",
"# !pip install milvus, pymilvus, langchain, torch, transformers, python-dotenv, accelerate\n",
"\n",
"# Import common libraries.\n",
"import time\n",
@@ -98,7 +98,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Milvus server startup time: 8.419662952423096 sec\n",
"Milvus server startup time: 6.501969814300537 sec\n",
"v2.2-testing-20230824-68-ga34a9d606-lite\n"
]
}
@@ -141,7 +141,7 @@
"metadata": {},
"source": [
"## Load the Embedding Model checkpoint and use it to create vector embeddings\n",
"**Embedding model:** We will use the open-source [sentence transformers](https://www.sbert.net/docs/pretrained_models.html) hosted on HuggingFace to encode the documentation text. We will save the embeddings to a pandas dataframe and then into the milvus database.\n",
"**Embedding model:** We will use the open-source [sentence transformers](https://www.sbert.net/docs/pretrained_models.html) available on HuggingFace to encode the documentation text. We will download the model from HuggingFace and run it locally. We'll save the model's generated embeddings to a pandas dataframe and then into the Milvus database.\n",
"\n",
"💡 Note: To keep your tokens private, best practice is to use an env variable. <br>\n",
"In Jupyter, you need a .env file (in the same dir as the notebooks) containing lines like this:\n",
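The env-variable practice this note describes can be sketched with only the standard library (python-dotenv's `load_dotenv()` would fill `os.environ` from the `.env` file the same way; the stand-in token value below is purely illustrative):

```python
import os

# Sketch only: seed a stand-in value so the example is self-contained.
# In the notebook, load_dotenv() reads the real token from the .env file.
os.environ.setdefault("HUGGINGFACEHUB_API_TOKEN", "hf_example_token")

def get_hub_token() -> str:
    """Read the HuggingFace token from the environment instead of hardcoding it."""
    token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        raise RuntimeError("Put HUGGINGFACEHUB_API_TOKEN=... in your .env file")
    return token
```

This keeps the secret out of the notebook itself, which matters when the notebook is committed to a public repo, as here.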
@@ -159,10 +159,6 @@
"output_type": "stream",
"text": [
"device: cpu\n",
"Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.\n",
"Token is valid (permission: write).\n",
"Your token has been saved to /Users/christybergman/.cache/huggingface/token\n",
"Login successful\n",
"<class 'sentence_transformers.SentenceTransformer.SentenceTransformer'>\n",
"SentenceTransformer(\n",
" (0): Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel \n",
@@ -192,9 +188,9 @@
"_ = load_dotenv(find_dotenv())\n",
"from huggingface_hub import login\n",
"\n",
"# Login to huggingface_hub\n",
"hub_token = os.getenv(\"HUGGINGFACEHUB_API_TOKEN\")\n",
"login(token=hub_token)\n",
"# # Login to huggingface_hub\n",
"# hub_token = os.getenv(\"HUGGINGFACEHUB_API_TOKEN\")\n",
"# login(token=hub_token)\n",
"\n",
"# Load the model from huggingface model hub.\n",
"model_name = \"BAAI/bge-base-en-v1.5\"\n",
@@ -236,17 +232,17 @@
"\n",
"model_kwargs = {\"device\": DEVICE}\n",
"encode_kwargs = {'normalize_embeddings': True}\n",
"lc_retriever = HuggingFaceEmbeddings(\n",
"lc_encoder = HuggingFaceEmbeddings(\n",
" model_name=model_name,\n",
" model_kwargs=model_kwargs,\n",
" encode_kwargs=encode_kwargs\n",
")\n",
"type(lc_retriever)"
"type(lc_encoder)"
]
},
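One detail in the cell above: `encode_kwargs = {'normalize_embeddings': True}` makes every embedding unit-length, so inner-product scores coincide with cosine similarity. A dependency-free sketch of that normalization (the helper name is ours, not part of LangChain or sentence-transformers):

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit L2 norm, mirroring what
    encode_kwargs={'normalize_embeddings': True} does to each embedding."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

# After normalization the vector has length 1, so the dot product of two
# normalized vectors equals their cosine similarity.
v = l2_normalize([3.0, 4.0])  # → [0.6, 0.8]
```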
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 6,
"id": "6861beb7",
"metadata": {},
"outputs": [
@@ -267,9 +263,7 @@
"docs = loader.load()\n",
"\n",
"num_documents = len(docs)\n",
"print(f\"loaded {num_documents} documents\")\n",
"# print(f\"type: {type(docs)}, len: {len(docs)}, type: {type(docs[0])}\")\n",
"# docs[0]"
"print(f\"loaded {num_documents} documents\")"
]
},
{
@@ -288,40 +282,15 @@
},
{
"cell_type": "code",
"execution_count": 10,
"id": "a751a9bd",
"metadata": {},
"outputs": [],
"source": [
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n",
"def recursive_splitter_wrapper(text, chunk_size):\n",
"\n",
" # Default chunk overlap is 10% chunk_size.\n",
" chunk_overlap = np.round(chunk_size * 0.10, 0)\n",
"\n",
" # Use langchain's convenient recursive chunking method.\n",
" text_splitter = RecursiveCharacterTextSplitter(\n",
" chunk_size=chunk_size,\n",
" chunk_overlap=chunk_overlap,\n",
" length_function=len,\n",
" )\n",
" \n",
" chunks = text_splitter.split_text(text)\n",
" return [chunk for chunk in chunks if chunk]\n"
]
},
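The removed wrapper above computed `chunk_overlap` as 10% of `chunk_size` before delegating to LangChain's `RecursiveCharacterTextSplitter`; the same arithmetic is inlined in the cell that follows in the diff. A dependency-free stand-in that keeps only the size/overlap arithmetic (character windows only; it skips the splitter's separator-aware recursion, and the function name is ours):

```python
def simple_chunk(text: str, chunk_size: int, overlap_pct: float = 0.10):
    """Split text into fixed-size character windows with 10% overlap,
    a simplified stand-in for RecursiveCharacterTextSplitter."""
    overlap = int(round(chunk_size * overlap_pct))
    step = chunk_size - overlap  # each window starts 90% of a chunk later
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), step)]
    return [c for c in chunks if c]
```

Overlapping chunks reduce the chance that a sentence relevant to a query is cut in half at a chunk boundary.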
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 7,
"id": "a53595fa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"chunking time: 0.0027921199798583984\n",
"chunking time: 0.0025839805603027344\n",
"type: list of <class 'langchain.schema.document.Document'>, len: 197\n",
"\n",
"Looking at a sample chunk...\n",
@@ -334,6 +303,8 @@
}
],
"source": [
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n",
"# Use the embedding model parameters to calculate chunk_size and overlap.\n",
"chunk_size = MAX_SEQ_LENGTH - HF_EOS_TOKEN_LENGTH\n",
"# Default chunk overlap is 10% chunk_size.\n",
@@ -364,7 +335,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 8,
"id": "512130a3",
"metadata": {},
"outputs": [
@@ -414,7 +385,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 9,
"id": "b51ff139",
"metadata": {},
"outputs": [
@@ -423,7 +394,7 @@
"output_type": "stream",
"text": [
"Start inserting entities\n",
"Langchain Milvus insert time for 197 vectors: 11.718086004257202 seconds\n",
"Langchain Milvus insert time for 197 vectors: 11.850640058517456 seconds\n",
"type: <class 'langchain.vectorstores.milvus.Milvus'>\n"
]
}
@@ -439,7 +410,7 @@
"\n",
"vector_store = Milvus.from_documents(\n",
" chunks,\n",
" embedding=lc_retriever,\n",
" embedding=lc_encoder,\n",
" connection_args={\"host\": MILVUS_HOST, \n",
" \"port\": MILVUS_PORT},\n",
")\n",
@@ -485,7 +456,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 10,
"id": "5e7f41f4",
"metadata": {},
"outputs": [
@@ -524,46 +495,36 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"id": "89642119",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Milvus query time: 0.2895548343658447\n",
"source: {'source': 'https://pymilvus.readthedocs.io/en/latest/genindex.html'}\n"
]
}
],
"outputs": [],
"source": [
"# RETRIEVAL USING MILVUS WITH LANGCHAIN\n",
"\n",
"start_time = time.time()\n",
"# Default search.\n",
"# docs = vector_store.similarity_search(question, k=7)\n",
"\n",
"# MMR search.\n",
"# docs = vector_store.max_marginal_relevance_search(question, k=7, fetch_k=100)\n",
"\n",
"# Search with metadata. TODO: Add better filtering query!\n",
"METADATA_URL = \"https://pymilvus.readthedocs.io/en/latest/_modules/milvus/client/stub.html\"\n",
"SEARCH_PARAMS = dict({\n",
" \"expr\": \"text = METADATA_URL\",\n",
" })\n",
"\n",
"start_time = time.time()\n",
"\n",
"# Default search.\n",
"docs = vector_store.similarity_search(\n",
" question,\n",
" question, \n",
" k=100,\n",
" filter={\"source\": METADATA_URL},\n",
" param=SEARCH_PARAMS,\n",
" verbose=True,\n",
")\n",
" )\n",
"\n",
"end_time = time.time()\n",
"print(f\"Milvus query time: {end_time - start_time}\")\n",
"\n",
"# View the retrieval result.\n",
"print(f\"source: {docs[0].metadata}\")\n",
"# print([doc.page_content for doc in docs])\n",
"\n",
"# # View the retrieval result.\n",
"# for d in docs:\n",
"# print(d.metadata)\n"
"# print(d.metadata)\n",
"# # print(d.page_content[:100])\n"
]
},
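The retrieval cell above restricts results to documents whose `source` metadata matches a URL. Setting Milvus aside, the filtering idea reduces to ranking candidates and keeping only metadata matches; a toy sketch (the helper, its tuple layout, and the scores are ours, not LangChain's API):

```python
def filtered_search(docs, scores, source_url, k=3):
    """docs: list of (page_content, metadata) pairs; scores: parallel list of
    similarity scores (higher is better). Keep only docs whose metadata
    'source' matches source_url, then return the top-k by score."""
    hits = [(score, doc) for score, doc in zip(scores, docs)
            if doc[1].get("source") == source_url]
    hits.sort(key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in hits[:k]]
```

In a real vector store the filter is pushed down into the search itself rather than applied after the fact, which is why the notebook passes `filter=` to `similarity_search` instead of post-filtering.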
{
@@ -578,7 +539,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 12,
"id": "d3dfa33a",
"metadata": {},
"outputs": [
@@ -627,7 +588,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 13,
"id": "3e7fa0b6",
"metadata": {},
"outputs": [
@@ -668,7 +629,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 14,
"id": "a68e87b1",
"metadata": {},
"outputs": [
@@ -701,7 +662,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 15,
"id": "d0e81e68",
"metadata": {},
"outputs": [],
@@ -713,7 +674,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 16,
"id": "c777937e",
"metadata": {},
"outputs": [
@@ -728,7 +689,7 @@
"IPython version : 8.15.0\n",
"\n",
"torch : 2.0.1\n",
"transformers: 4.33.2\n",
"transformers: 4.34.1\n",
"milvus : 2.3.0\n",
"pymilvus : 2.3.0\n",
"langchain : 0.0.301\n",