Skip to content

Commit

Permalink
Fix doc indexing and use accelerator (#30)
Browse files Browse the repository at this point in the history
Signed-off-by: Aivin V. Solatorio <[email protected]>
  • Loading branch information
avsolatorio authored Feb 26, 2024
1 parent 312ae33 commit 5c284d6
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
14 changes: 12 additions & 2 deletions llm4data/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,22 @@
from qdrant_client.http import models
from pydantic.main import ModelMetaclass
from dataclasses import dataclass, asdict

import torch

# Make the model atomically available
# NOTE(review): module-level dict, so a single instance is shared by every
# importer of this module. What it is keyed by and where it is populated is
# not visible from here — confirm against the call sites.
LOADED_MODELS: dict = {}


def get_device() -> str:
    """Return the torch device string to run on.

    Preference order: the first CUDA device, then the MPS backend, then CPU.

    Returns:
        ``"cuda:0"`` when ``torch.cuda.device_count()`` reports at least one
        device, ``"mps:0"`` when the MPS backend reports itself available,
        otherwise ``"cpu"``.
    """
    if torch.cuda.device_count() > 0:
        return "cuda:0"

    # `torch.backends.mps` does not exist in older torch releases; look it up
    # defensively so those installs fall back to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps:0"

    return "cpu"


@dataclass
class BaseEmbeddingModel:
model_size = {
Expand Down Expand Up @@ -101,7 +111,7 @@ def _create_embeddings(self):
raise ValueError("`config.kwargs` must be a dict")

self.embeddings = getattr(langchain_embeddings, self.embedding_cls)(
**self.kwargs
**{"model_kwargs": {"device": get_device()}, **self.kwargs}
)

if self.max_tokens is None and self.embeddings:
Expand Down
2 changes: 1 addition & 1 deletion llm4data/scripts/indexing/docs/docs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional, Union
from pathlib import Path
from langchain.docstore.document import Document
from langchain.document_loaders import PyMuPDFLoader
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import (
NLTKTextSplitter,
CharacterTextSplitter,
Expand Down
2 changes: 1 addition & 1 deletion llm4data/scripts/indexing/docs/load_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,5 @@ def main(path: Union[str, Path], strict: bool = False):


if __name__ == "__main__":
# python -m llm4data.scripts.indexing.docs.load_docs --path=data/knowledge/docs/prwp/pdf --strict
# python -m llm4data.scripts.indexing.docs.load_docs --path=data/sources/docs/prwp/pdf --strict
fire.Fire(main)

0 comments on commit 5c284d6

Please sign in to comment.