import os
import re

import nltk
import ollama
import pandas as pd

# LanceDB modules for the embedding API
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

# Download the punkt tokenizer data up front (not used directly in this script)
nltk.download("punkt")

# Text splitter: builds overlapping chunks measured in whitespace/newline splits
def recursive_text_splitter(text, max_chunk_length=1000, overlap=100):
    """
    Helper function for chunking text into overlapping pieces.

    The text is split on newlines and spaces; each chunk covers up to
    `max_chunk_length` splits and shares `overlap` splits with the previous
    chunk, so the limits count splits rather than characters.
    """
    result = []
    current_chunk_count = 0
    # Split on newline or space, keeping each separator attached to the piece after it
    _splits = re.split(r"([\n ])", text)
    splits = [_splits[0]] + [
        _splits[i] + _splits[i + 1] for i in range(1, len(_splits) - 1, 2)
    ]
    while current_chunk_count < len(splits):
        if current_chunk_count != 0:
            chunk = "".join(
                splits[
                    current_chunk_count
                    - overlap : current_chunk_count
                    + max_chunk_length
                ]
            )
        else:
            chunk = "".join(splits[:max_chunk_length])
        if chunk:
            result.append(chunk)
        current_chunk_count += max_chunk_length
    return result
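

# Illustrative usage of the splitter (example values chosen here, not the ones
# used in the main block below). With the implementation above,
#   recursive_text_splitter("a b c d e f", max_chunk_length=3, overlap=1)
# returns ["a b c", " c d e f"]: the second chunk re-uses the last split of the
# first chunk, so neighbouring chunks share a little context.
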
# Model definitions using the LanceDB embedding API
model1 = get_registry().get("openai").create()
model2 = get_registry().get("ollama").create(name="llama3")
model3 = get_registry().get("ollama").create(name="mistral")
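
# Assumed setup (not checked by this script): the "openai" embedding function needs
# the `openai` package and an OPENAI_API_KEY in the environment, and the "ollama"
# embedding functions expect a local Ollama server with the `llama3` and `mistral`
# models already pulled (e.g. `ollama pull llama3`).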


# Define schemas for the embedding spaces with the embedding API
class TextModel1(LanceModel):
    text: str = model1.SourceField()
    vector: Vector(model1.ndims()) = model1.VectorField()


class TextModel2(LanceModel):
    text: str = model2.SourceField()
    vector: Vector(model2.ndims()) = model2.VectorField()


class TextModel3(LanceModel):
    text: str = model3.SourceField()
    vector: Vector(model3.ndims()) = model3.VectorField()


# Embedding spaces
def LanceDBEmbeddingSpace(df):
    """
    Create 3 embedding spaces with OpenAI, Llama3 and Mistral embeddings
    """
    db = lancedb.connect("/tmp/lancedb")
    print("Embedding spaces creation started \U0001F6A7.....")
    table1 = db.create_table(
        "embed_space1",
        schema=TextModel1,
        mode="overwrite",
    )
    table2 = db.create_table(
        "embed_space2",
        schema=TextModel2,
        mode="overwrite",
    )
    table3 = db.create_table(
        "embed_space3",
        schema=TextModel3,
        mode="overwrite",
    )
    table1.add(df)
    table2.add(df)
    table3.add(df)
    print("3 Embedding spaces created \U0001f44d")
    return table1, table2, table3
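

# Because `text` is declared as a SourceField and `vector` as a VectorField in the
# schemas above, LanceDB's embedding registry computes the vectors automatically:
# `table.add(df)` embeds each row's text, and a later `table.search(question)` embeds
# the query string with the same model, so no explicit encoding step is needed here.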

if __name__ == "__main__":
    filename = input(
        "Enter the filepath (.txt) you want to query (default file: lease.txt): "
    )
    if filename == "":
        filename = "lease.txt"
    elif not os.path.exists(filename):
        print("Given file", filename, "doesn't exist \U0000274c")
        exit()

    # Read the document
    with open(filename, "r") as file:
        text_data = file.read()

    # Split the text with the text splitter defined above
    chunks = recursive_text_splitter(text_data, max_chunk_length=100, overlap=10)
    df = pd.DataFrame({"text": chunks})
    table1, table2, table3 = LanceDBEmbeddingSpace(df)

    # Query loop
    while True:
        question = input("Enter Query: ")
        if question in ["q", "exit", "quit"]:
            break
        # Query search
        print("Query Search started ......")
        result1 = table1.search(question).limit(3).to_list()
        result2 = table2.search(question).limit(3).to_list()
        result3 = table3.search(question).limit(3).to_list()
        context = (
            [r["text"] for r in result1]
            + [r["text"] for r in result2]
            + [r["text"] for r in result3]
        )
        print("Answer generation started ....")
        # Context prompt
        base_prompt = """You are an AI assistant. Your task is to understand the user question and provide an answer using the provided contexts. Every answer you generate should include citations in this pattern: "Answer [position].", for example: "Earth is round [1][2]." if it's relevant.
Your answers are correct, high-quality, and written by a domain expert. If the provided context does not contain the answer, simply state, "The provided context does not have the answer."
User question: {}
Contexts:
{}
"""
        # LLM call via the local Ollama server
        prompt = base_prompt.format(question, context)
        response = ollama.chat(
            model="llama3",
            messages=[
                {
                    "role": "system",
                    "content": prompt,
                },
            ],
        )
        print("Answer: ", response["message"]["content"])