diff --git a/README.md b/README.md
index 94983bc0..61c1dd32 100644
--- a/README.md
+++ b/README.md
@@ -28,10 +28,10 @@ pip install fastembed-gpu
 
 ```python
 from fastembed import TextEmbedding
-from typing import List
+
 
 # Example list of documents
-documents: List[str] = [
+documents: list[str] = [
     "This is built to be faster and lighter than other embedding libraries e.g. Transformers, Sentence-Transformers, etc.",
     "fastembed is supported by and maintained by Qdrant.",
 ]
@@ -139,11 +139,10 @@ embeddings = list(model.embed(images))
 ### 🔄 Rerankers
 
 ```python
-from typing import List
 from fastembed.rerank.cross_encoder import TextCrossEncoder
 
 query = "Who is maintaining Qdrant?"
-documents: List[str] = [
+documents: list[str] = [
     "This is built to be faster and lighter than other embedding libraries e.g. Transformers, Sentence-Transformers, etc.",
     "fastembed is supported by and maintained by Qdrant.",
 ]
diff --git a/docs/examples/Hindi_Tamil_RAG_with_Navarasa7B.ipynb b/docs/examples/Hindi_Tamil_RAG_with_Navarasa7B.ipynb
index e18dbc15..ef3cc2d7 100644
--- a/docs/examples/Hindi_Tamil_RAG_with_Navarasa7B.ipynb
+++ b/docs/examples/Hindi_Tamil_RAG_with_Navarasa7B.ipynb
@@ -44,7 +44,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": null,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2024-03-30T00:45:24.814968Z",
@@ -58,8 +58,6 @@
    },
    "outputs": [],
    "source": [
-    "from typing import List\n",
-    "\n",
     "import numpy as np\n",
     "from datasets import load_dataset\n",
     "from peft import AutoPeftModelForCausalLM\n",
@@ -72,11 +70,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "hf_token = # Get your token from https://huggingface.co/settings/token, needed for Gemma weights"
+    "hf_token = \"\" # Get your token from https://huggingface.co/settings/token, needed for Gemma weights"
    ]
   },
   {
@@ -246,7 +244,7 @@
    },
    "outputs": [],
    "source": [
-    "context_embeddings: List[np.ndarray] = list(\n",
+    "context_embeddings: list[np.ndarray] = list(\n",
     "    embedding_model.embed(contexts)\n",
     ") # Note the list() call - this is a generator"
    ]
diff --git a/docs/examples/SPLADE_with_FastEmbed.ipynb b/docs/examples/SPLADE_with_FastEmbed.ipynb
index 457b4056..c36baf87 100644
--- a/docs/examples/SPLADE_with_FastEmbed.ipynb
+++ b/docs/examples/SPLADE_with_FastEmbed.ipynb
@@ -47,7 +47,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2024-03-30T00:49:20.516644Z",
@@ -56,8 +56,7 @@
    },
    "outputs": [],
    "source": [
-    "from fastembed import SparseTextEmbedding, SparseEmbedding\n",
-    "from typing import List"
+    "from fastembed import SparseTextEmbedding, SparseEmbedding"
    ]
   },
   {
@@ -134,7 +133,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2024-03-30T00:49:28.624109Z",
@@ -143,7 +142,7 @@
    },
    "outputs": [],
    "source": [
-    "documents: List[str] = [\n",
+    "documents: list[str] = [\n",
     "    \"Chandrayaan-3 is India's third lunar mission\",\n",
     "    \"It aimed to land a rover on the Moon's surface - joining the US, China and Russia\",\n",
     "    \"The mission is a follow-up to Chandrayaan-2, which had partial success\",\n",
@@ -157,7 +156,7 @@
     "    \"Chandrayaan-3 was launched from the Satish Dhawan Space Centre in Sriharikota\",\n",
     "    \"Chandrayaan-3 was launched earlier in the year 2023\",\n",
     "]\n",
-    "sparse_embeddings_list: List[SparseEmbedding] = list(\n",
+    "sparse_embeddings_list: list[SparseEmbedding] = list(\n",
     "    model.embed(documents, batch_size=6)\n",
     ") # batch_size is optional, notice the generator"
    ]
@@ -235,7 +234,9 @@
    "source": [
     "# Let's print the first 5 features and their weights for better understanding.\n",
     "for i in range(5):\n",
-    "    print(f\"Token at index {sparse_embeddings_list[0].indices[i]} has weight {sparse_embeddings_list[0].values[i]}\")"
+    "    print(\n",
+    "        f\"Token at index {sparse_embeddings_list[0].indices[i]} has weight {sparse_embeddings_list[0].values[i]}\"\n",
+    "    )"
    ]
   },
   {
@@ -261,7 +262,9 @@
     "import json\n",
     "from transformers import AutoTokenizer\n",
     "\n",
-    "tokenizer = AutoTokenizer.from_pretrained(SparseTextEmbedding.list_supported_models()[0][\"sources\"][\"hf\"])"
+    "tokenizer = AutoTokenizer.from_pretrained(\n",
+    "    SparseTextEmbedding.list_supported_models()[0][\"sources\"][\"hf\"]\n",
+    ")"
    ]
   },
   {
@@ -326,7 +329,9 @@
     "        token_weight_dict[token] = weight\n",
     "\n",
     "    # Sort the dictionary by weights\n",
-    "    token_weight_dict = dict(sorted(token_weight_dict.items(), key=lambda item: item[1], reverse=True))\n",
+    "    token_weight_dict = dict(\n",
+    "        sorted(token_weight_dict.items(), key=lambda item: item[1], reverse=True)\n",
+    "    )\n",
     "    return token_weight_dict\n",
     "\n",
     "\n",
diff --git a/docs/index.md b/docs/index.md
index e4b8f0c7..15ea646a 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -26,14 +26,14 @@ pip install fastembed
 
 ```python
 from fastembed import TextEmbedding
-documents: List[str] = [
+documents: list[str] = [
     "passage: Hello, World!",
     "query: Hello, World!",
     "passage: This is an example passage.",
     "fastembed is supported by and maintained by Qdrant."
 ]
 embedding_model = TextEmbedding()
-embeddings: List[np.ndarray] = embedding_model.embed(documents)
+embeddings: list[np.ndarray] = embedding_model.embed(documents)
 ```
 
 ## Usage with Qdrant
diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py
index 5ce95a49..33f513c8 100644
--- a/fastembed/common/model_management.py
+++ b/fastembed/common/model_management.py
@@ -7,7 +7,11 @@
 
 import requests
 from huggingface_hub import snapshot_download
-from huggingface_hub.utils import RepositoryNotFoundError
+from huggingface_hub.utils import (
+    RepositoryNotFoundError,
+    disable_progress_bars,
+    enable_progress_bars,
+)
 from loguru import logger
 from tqdm import tqdm
@@ -93,7 +97,7 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
     def download_files_from_huggingface(
         cls,
         hf_source_repo: str,
-        cache_dir: Optional[str] = None,
+        cache_dir: str,
         extra_patterns: Optional[list[str]] = None,
         local_files_only: bool = False,
         **kwargs,
@@ -119,6 +123,12 @@
         if extra_patterns is not None:
             allow_patterns.extend(extra_patterns)
 
+        snapshot_dir = Path(cache_dir) / f"models--{hf_source_repo.replace('/', '--')}"
+        is_cached = snapshot_dir.exists()
+
+        if is_cached:
+            disable_progress_bars()
+
         return snapshot_download(
             repo_id=hf_source_repo,
             allow_patterns=allow_patterns,
@@ -265,6 +275,8 @@ def download_model(
                 f"Could not download model from HuggingFace: {e} "
                 "Falling back to other sources."
             )
+        finally:
+            enable_progress_bars()
         if url_source or local_files_only:
             try:
                 return cls.retrieve_model_gcs(
diff --git a/fastembed/common/utils.py b/fastembed/common/utils.py
index cd50dfdf..3a8bb7ff 100644
--- a/fastembed/common/utils.py
+++ b/fastembed/common/utils.py
@@ -1,13 +1,13 @@
 import os
+import sys
+import re
 import tempfile
-from itertools import islice
+import unicodedata
 from pathlib import Path
+from itertools import islice
 from typing import Generator, Iterable, Optional, Union
-import unicodedata
-import sys
+
 import numpy as np
-import re
-from typing import Set
 
 
 def normalize(input_array, p=2, dim=1, eps=1e-12) -> np.ndarray:
@@ -45,7 +45,7 @@ def define_cache_dir(cache_dir: Optional[str] = None) -> Path:
     return cache_path
 
 
-def get_all_punctuation() -> Set[str]:
+def get_all_punctuation() -> set[str]:
     return set(
         chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")
     )
diff --git a/fastembed/image/onnx_embedding.py b/fastembed/image/onnx_embedding.py
index c4cbf348..5647c2ff 100644
--- a/fastembed/image/onnx_embedding.py
+++ b/fastembed/image/onnx_embedding.py
@@ -80,7 +80,7 @@ def __init__(
                 Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
             cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
                 Defaults to False.
-            device_ids (Optional[List[int]], optional): The list of device ids to use for data parallel processing in
+            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
                 workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
             lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
                 Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
@@ -134,7 +134,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
         Lists the supported models.
 
         Returns:
-            List[Dict[str, Any]]: A list of dictionaries containing the model information.
+            list[dict[str, Any]]: A list of dictionaries containing the model information.
         """
         return supported_onnx_models
 
diff --git a/tests/profiling.py b/tests/profiling.py
index 75db178e..1edf3b04 100644
--- a/tests/profiling.py
+++ b/tests/profiling.py
@@ -9,7 +9,7 @@
 
 # %%
 import time
-from typing import Callable, List, Tuple
+from typing import Callable
 
 import matplotlib.pyplot as plt
 import torch.nn.functional as F
@@ -23,7 +23,7 @@
 # data is a list of strings, each string is a document.
 
 # %%
-documents: List[str] = [
+documents: list[str] = [
     "Chandrayaan-3 is India's third lunar mission",
     "It aimed to land a rover on the Moon's surface - joining the US, China and Russia",
     "The mission is a follow-up to Chandrayaan-2, which had partial success",
@@ -56,7 +56,7 @@ def __init__(self, model_id: str):
         self.model = AutoModel.from_pretrained(model_id)
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-    def embed(self, texts: List[str]):
+    def embed(self, texts: list[str]):
         encoded_input = self.tokenizer(
             texts, max_length=512, padding=True, truncation=True, return_tensors="pt"
         )
@@ -88,7 +88,7 @@ def embed(self, texts: List[str]):
 # %%
 def calculate_time_stats(
     embed_func: Callable, documents: list, k: int
-) -> Tuple[float, float, float]:
+) -> tuple[float, float, float]:
     times = []
     for _ in range(k):
         # Timing the embed_func call
@@ -111,8 +111,8 @@ def calculate_time_stats(
 
 # %%
 def plot_character_per_second_comparison(
-    hf_stats: Tuple[float, float, float],
-    fst_stats: Tuple[float, float, float],
+    hf_stats: tuple[float, float, float],
+    fst_stats: tuple[float, float, float],
     documents: list,
 ):
     # Calculating total characters in documents