utils.default_device_type(): detect either cuda or mps #718

Open · wants to merge 1 commit into base: main

ingest.py: 13 changes (7 additions, 6 deletions)

@@ -1,5 +1,7 @@
 import logging
 import os
+import sys
+from utils import default_device_type
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
 
 import click
@@ -38,7 +40,7 @@ def load_single_document(file_path: str) -> Document:
         return loader.load()[0]
     except Exception as ex:
         file_log('%s loading error: \n%s' % (file_path, ex))
-        return None
+        return None
 
 def load_document_batch(filepaths):
     logging.info("Loading document batch")
@@ -93,7 +95,7 @@ def load_documents(source_dir: str) -> list[Document]:
                 docs.extend(contents)
             except Exception as ex:
                 file_log('Exception: %s' % (ex))
-
+
     return docs
 
 
@@ -109,11 +111,10 @@ def split_documents(documents: list[Document]) -> tuple[list[Document], list[Document]]:
             text_docs.append(doc)
     return text_docs, python_docs
 
-
 @click.command()
 @click.option(
     "--device_type",
-    default="cuda" if torch.cuda.is_available() else "cpu",
+    default=default_device_type(),
     type=click.Choice(
         [
             "cpu",
@@ -137,7 +138,7 @@ def split_documents(documents: list[Document]) -> tuple[list[Document], list[Document]]:
             "mtia",
         ],
     ),
-    help="Device to run on. (Default is cuda)",
+    help=f"Device to run on. (Default is {default_device_type()})",
 )
 def main(device_type):
     # Load documents and split in chunks
@@ -171,7 +172,7 @@ def main(device_type):
         persist_directory=PERSIST_DIRECTORY,
         client_settings=CHROMA_SETTINGS,
     )
-
+
 
 if __name__ == "__main__":
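Both CLIs get the same one-line change: the inline "cuda" if torch.cuda.is_available() else "cpu" expression is replaced by the shared helper, so the fallback order lives in a single place. A minimal sketch of the resulting pattern, stripped down to just the option in question (the echo body is illustrative, not part of this PR):

import click

from utils import default_device_type

@click.command()
@click.option(
    "--device_type",
    default=default_device_type(),  # computed once, when the decorator runs at import time
    help=f"Device to run on. (Default is {default_device_type()})",
)
def main(device_type):
    # click falls back to the computed default whenever --device_type is omitted
    click.echo(f"Running on {device_type}")

if __name__ == "__main__":
    main()

An explicit flag such as --device_type cpu still overrides the detected default; only the no-flag case changes.
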
run_localGPT.py: 9 changes (5 additions, 4 deletions)

@@ -3,6 +3,7 @@
 import click
 import torch
 import utils
+from utils import default_device_type
 from langchain.chains import RetrievalQA
 from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.llms import HuggingFacePipeline
@@ -161,11 +162,11 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
     return qa
 
 
-# chose device typ to run on as well as to show source documents.
+# choose device type to run on as well as to show source documents.
 @click.command()
 @click.option(
     "--device_type",
-    default="cuda" if torch.cuda.is_available() else "cpu",
+    default=default_device_type(),
     type=click.Choice(
         [
             "cpu",
@@ -189,7 +190,7 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
             "mtia",
         ],
     ),
-    help="Device to run on. (Default is cuda)",
+    help=f"Device to run on. (Default is {default_device_type()})",
 )
 @click.option(
     "--show_sources",
@@ -269,7 +270,7 @@ def main(device_type, show_sources, use_history, model_type, save_qa):
         print("\n> " + document.metadata["source"] + ":")
         print(document.page_content)
         print("----------------------------------SOURCE DOCUMENTS---------------------------")
-
+
     # Log the Q&A to CSV only if save_qa is True
     if save_qa:
         utils.log_to_csv(query, answer)
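Downstream, run_localGPT.py passes device_type into the LangChain loaders, so an auto-detected "mps" reaches torch with no further changes. A rough sketch of that hand-off, assuming the usual HuggingFaceInstructEmbeddings pattern (the model name is illustrative, not taken from this diff):

from langchain.embeddings import HuggingFaceInstructEmbeddings

from utils import default_device_type

embeddings = HuggingFaceInstructEmbeddings(
    model_name="hkunlp/instructor-large",            # illustrative embedding model
    model_kwargs={"device": default_device_type()},  # resolves to "cuda", "mps", or "cpu"
)
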
utils.py: 10 changes (9 additions, 1 deletion)

@@ -1,5 +1,7 @@
 import os
 import csv
+import torch
+import sys
 from datetime import datetime
 
 def log_to_csv(question, answer):
@@ -22,4 +24,10 @@ def log_to_csv(question, answer):
     with open(log_path, mode='a', newline='', encoding='utf-8') as file:
         writer = csv.writer(file)
         timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-        writer.writerow([timestamp, question, answer])
+        writer.writerow([timestamp, question, answer])
+
+def has_mps():
+    return sys.platform == "darwin" and torch.backends.mps.is_available()
+
+def default_device_type():
+    return "cuda" if torch.cuda.is_available() else "mps" if has_mps() else "cpu"