Update/function call models #34

Closed
63 changes: 33 additions & 30 deletions whisperplus/pipelines/chatbot.py
@@ -8,27 +8,19 @@
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import LanceDB
 
-# Configuration and Constants
-MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
-LLM_MODEL_NAME = 'TheBloke/Mistral-7B-v0.1-GGUF'
-LLM_MODEL_FILE = 'mistral-7b-v0.1.Q4_K_M.gguf'
-LLM_MODEL_TYPE = "mistral"
-TEXT_FILE_PATH = "transcript.text"
-DATABASE_PATH = '/tmp/lancedb'
-
 
 class ChatWithVideo:
+    DATABASE_PATH = '/tmp/lancedb'  # Default database path
 
     @staticmethod
-    def load_llm_model():
+    def load_llm_model(model_name, model_file, model_type):
         try:
-            print("Starting to download the Mistral model...")
-            llm_model = CTransformers(
-                model=LLM_MODEL_NAME, model_file=LLM_MODEL_FILE, model_type=LLM_MODEL_TYPE)
-            print("Mistral model successfully loaded.")
+            print(f"Starting to download the {model_name} model...")
+            llm_model = CTransformers(model=model_name, model_file=model_file, model_type=model_type)
+            print(f"{model_name} model successfully loaded.")
             return llm_model
         except Exception as e:
-            print(f"Error loading the Mistral model: {e}")
+            print(f"Error loading the {model_name} model: {e}")
             return None
 
     @staticmethod
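
Review note: with the loader parametrized, any CTransformers-compatible GGUF model can be swapped in at the call site instead of the hard-coded Mistral build. A minimal sketch under that assumption; the Llama 2 repo id, file name, and model type below are illustrative placeholders, not part of this PR:

from whisperplus.pipelines.chatbot import ChatWithVideo

# Illustrative placeholders: any GGUF repo id / file / model type trio
# accepted by langchain's CTransformers wrapper should work here.
llm = ChatWithVideo.load_llm_model(
    model_name='TheBloke/Llama-2-7B-GGUF',
    model_file='llama-2-7b.Q4_K_M.gguf',
    model_type='llama')
if llm is None:
    print('Model failed to load; see the error printed above.')
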
@@ -54,7 +46,6 @@ def setup_database():
             print(f"Error setting up the database: {e}")
             return None
 
-    # embedding model
     @staticmethod
     def prepare_embeddings(model_name):
         try:
@@ -82,21 +73,20 @@ def prepare_documents(docs):
             return None
 
     @staticmethod
-    def run_query(query):
+    def run_query(query, text_file_path, model_name, llm_model_name, llm_model_file, llm_model_type):
         if not query:
             print("No query provided.")
             return "No query provided."
 
         print(f"Running query: {query}")
-        docs = ChatWithVideo.load_text_file(TEXT_FILE_PATH)
+        docs = ChatWithVideo.load_text_file(text_file_path)
         if not docs:
             return "Failed to load documents."
 
         documents = ChatWithVideo.prepare_documents(docs)
         if not documents:
             return "Failed to prepare documents."
 
-        embeddings = ChatWithVideo.prepare_embeddings(MODEL_NAME)
+        embeddings = ChatWithVideo.prepare_embeddings(model_name)
         if not embeddings:
             return "Failed to prepare embeddings."
@@ -115,21 +105,17 @@ def run_query(query):
                 mode="overwrite")
             docsearch = LanceDB.from_documents(documents, embeddings, connection=table)
 
-            llm = ChatWithVideo.load_llm_model()
-
+            llm = ChatWithVideo.load_llm_model(llm_model_name, llm_model_file, llm_model_type)
             if not llm:
                 return "Failed to load LLM model."
 
-            template = """Use the following pieces of context to answer the question at the end.
-            If you don't know the answer, just say that you don't know, don't try to make up an answer.
-            Use three sentences maximum and keep the answer as concise as possible.
-            Always say "thanks for asking!" at the end of the answer.
-            {context}
-            Question: {question}
-            Helpful Answer:"""
-
+            template = """Use the following pieces of context to answer the question at the end...
+            {context}
+            Question: {question}
+            Helpful Answer:"""
             QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)
-            print("prompt loaded")
+            print("Prompt loaded")
+
             qa = RetrievalQA.from_chain_type(
                 llm,
                 chain_type='stuff',
@@ -140,3 +126,20 @@ def run_query(query):
         except Exception as e:
             print(f"Error running query: {e}")
             return f"Error: {e}"
+
+
+#####
+
+# from whisperplus.pipelines.chatbot import ChatWithVideo
+
+# # Model configuration
+# model_name = 'sentence-transformers/all-MiniLM-L6-v2'
+# llm_model_name = 'TheBloke/Mistral-7B-v0.1-GGUF'
+# llm_model_file = 'mistral-7b-v0.1.Q4_K_M.gguf'
+# llm_model_type = "mistral"
+# text_file_path = "transcript.text"
+
+# # Run the query
+# query = "What is mistral?"
+# result = ChatWithVideo.run_query(query, text_file_path, model_name, llm_model_name, llm_model_file, llm_model_type)
+# print("Result:", result)
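
Review note: since the module-level constants are removed, every call site must now pass the full configuration explicitly. One way downstream code could keep the old zero-argument ergonomics is a thin wrapper with defaults; the ask helper and DEFAULTS dict below are hypothetical sketches, not part of this PR:

from whisperplus.pipelines.chatbot import ChatWithVideo

# Hypothetical defaults mirroring the constants this PR removed.
DEFAULTS = {
    'model_name': 'sentence-transformers/all-MiniLM-L6-v2',
    'llm_model_name': 'TheBloke/Mistral-7B-v0.1-GGUF',
    'llm_model_file': 'mistral-7b-v0.1.Q4_K_M.gguf',
    'llm_model_type': 'mistral',
    'text_file_path': 'transcript.text',
}

def ask(query, **overrides):
    """Run a query with the old default configuration, overriding any field as needed."""
    cfg = {**DEFAULTS, **overrides}
    return ChatWithVideo.run_query(query, cfg['text_file_path'], cfg['model_name'],
                                   cfg['llm_model_name'], cfg['llm_model_file'],
                                   cfg['llm_model_type'])

print(ask('What is mistral?'))
print(ask('Summarize the video.', text_file_path='other_transcript.text'))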