Skip to content

Commit

Permalink
Improve models cache progressbar (#406)
Browse files Browse the repository at this point in the history
* chore: Remove typing hints only needed for Python versions below 3.9

* chore: Removed Optional from cache_dir as it cannot be undefined

* improve: Turned off the Hugging Face progress bar when the model is already cached
  • Loading branch information
hh-space-invader authored Nov 21, 2024
1 parent e9dc3b1 commit adfc03e
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 37 deletions.
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ pip install fastembed-gpu

```python
from fastembed import TextEmbedding
from typing import List


# Example list of documents
documents: List[str] = [
documents: list[str] = [
"This is built to be faster and lighter than other embedding libraries e.g. Transformers, Sentence-Transformers, etc.",
"fastembed is supported by and maintained by Qdrant.",
]
Expand Down Expand Up @@ -139,11 +139,10 @@ embeddings = list(model.embed(images))

### 🔄 Rerankers
```python
from typing import List
from fastembed.rerank.cross_encoder import TextCrossEncoder

query = "Who is maintaining Qdrant?"
documents: List[str] = [
documents: list[str] = [
"This is built to be faster and lighter than other embedding libraries e.g. Transformers, Sentence-Transformers, etc.",
"fastembed is supported by and maintained by Qdrant.",
]
Expand Down
10 changes: 4 additions & 6 deletions docs/examples/Hindi_Tamil_RAG_with_Navarasa7B.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-30T00:45:24.814968Z",
Expand All @@ -58,8 +58,6 @@
},
"outputs": [],
"source": [
"from typing import List\n",
"\n",
"import numpy as np\n",
"from datasets import load_dataset\n",
"from peft import AutoPeftModelForCausalLM\n",
Expand All @@ -72,11 +70,11 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hf_token = <YOUR_HF_TOKEN_HERE> # Get your token from https://huggingface.co/settings/token, needed for Gemma weights"
"hf_token = \"<YOUR_HF_TOKEN_HERE>\" # Get your token from https://huggingface.co/settings/token, needed for Gemma weights"
]
},
{
Expand Down Expand Up @@ -246,7 +244,7 @@
},
"outputs": [],
"source": [
"context_embeddings: List[np.ndarray] = list(\n",
"context_embeddings: list[np.ndarray] = list(\n",
" embedding_model.embed(contexts)\n",
") # Note the list() call - this is a generator"
]
Expand Down
23 changes: 14 additions & 9 deletions docs/examples/SPLADE_with_FastEmbed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-30T00:49:20.516644Z",
Expand All @@ -56,8 +56,7 @@
},
"outputs": [],
"source": [
"from fastembed import SparseTextEmbedding, SparseEmbedding\n",
"from typing import List"
"from fastembed import SparseTextEmbedding, SparseEmbedding"
]
},
{
Expand Down Expand Up @@ -134,7 +133,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-30T00:49:28.624109Z",
Expand All @@ -143,7 +142,7 @@
},
"outputs": [],
"source": [
"documents: List[str] = [\n",
"documents: list[str] = [\n",
" \"Chandrayaan-3 is India's third lunar mission\",\n",
" \"It aimed to land a rover on the Moon's surface - joining the US, China and Russia\",\n",
" \"The mission is a follow-up to Chandrayaan-2, which had partial success\",\n",
Expand All @@ -157,7 +156,7 @@
" \"Chandrayaan-3 was launched from the Satish Dhawan Space Centre in Sriharikota\",\n",
" \"Chandrayaan-3 was launched earlier in the year 2023\",\n",
"]\n",
"sparse_embeddings_list: List[SparseEmbedding] = list(\n",
"sparse_embeddings_list: list[SparseEmbedding] = list(\n",
" model.embed(documents, batch_size=6)\n",
") # batch_size is optional, notice the generator"
]
Expand Down Expand Up @@ -235,7 +234,9 @@
"source": [
"# Let's print the first 5 features and their weights for better understanding.\n",
"for i in range(5):\n",
" print(f\"Token at index {sparse_embeddings_list[0].indices[i]} has weight {sparse_embeddings_list[0].values[i]}\")"
" print(\n",
" f\"Token at index {sparse_embeddings_list[0].indices[i]} has weight {sparse_embeddings_list[0].values[i]}\"\n",
" )"
]
},
{
Expand All @@ -261,7 +262,9 @@
"import json\n",
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(SparseTextEmbedding.list_supported_models()[0][\"sources\"][\"hf\"])"
"tokenizer = AutoTokenizer.from_pretrained(\n",
" SparseTextEmbedding.list_supported_models()[0][\"sources\"][\"hf\"]\n",
")"
]
},
{
Expand Down Expand Up @@ -326,7 +329,9 @@
" token_weight_dict[token] = weight\n",
"\n",
" # Sort the dictionary by weights\n",
" token_weight_dict = dict(sorted(token_weight_dict.items(), key=lambda item: item[1], reverse=True))\n",
" token_weight_dict = dict(\n",
" sorted(token_weight_dict.items(), key=lambda item: item[1], reverse=True)\n",
" )\n",
" return token_weight_dict\n",
"\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ pip install fastembed
```python
from fastembed import TextEmbedding

documents: List[str] = [
documents: list[str] = [
"passage: Hello, World!",
"query: Hello, World!",
"passage: This is an example passage.",
"fastembed is supported by and maintained by Qdrant."
]
embedding_model = TextEmbedding()
embeddings: List[np.ndarray] = embedding_model.embed(documents)
embeddings: list[np.ndarray] = embedding_model.embed(documents)
```

## Usage with Qdrant
Expand Down
16 changes: 14 additions & 2 deletions fastembed/common/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

import requests
from huggingface_hub import snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError
from huggingface_hub.utils import (
RepositoryNotFoundError,
disable_progress_bars,
enable_progress_bars,
)
from loguru import logger
from tqdm import tqdm

Expand Down Expand Up @@ -93,7 +97,7 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
def download_files_from_huggingface(
cls,
hf_source_repo: str,
cache_dir: Optional[str] = None,
cache_dir: str,
extra_patterns: Optional[list[str]] = None,
local_files_only: bool = False,
**kwargs,
Expand All @@ -119,6 +123,12 @@ def download_files_from_huggingface(
if extra_patterns is not None:
allow_patterns.extend(extra_patterns)

snapshot_dir = Path(cache_dir) / f"models--{hf_source_repo.replace('/', '--')}"
is_cached = snapshot_dir.exists()

if is_cached:
disable_progress_bars()

return snapshot_download(
repo_id=hf_source_repo,
allow_patterns=allow_patterns,
Expand Down Expand Up @@ -265,6 +275,8 @@ def download_model(
f"Could not download model from HuggingFace: {e} "
"Falling back to other sources."
)
finally:
enable_progress_bars()
if url_source or local_files_only:
try:
return cls.retrieve_model_gcs(
Expand Down
12 changes: 6 additions & 6 deletions fastembed/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os
import sys
import re
import tempfile
from itertools import islice
import unicodedata
from pathlib import Path
from itertools import islice
from typing import Generator, Iterable, Optional, Union
import unicodedata
import sys

import numpy as np
import re
from typing import Set


def normalize(input_array, p=2, dim=1, eps=1e-12) -> np.ndarray:
Expand Down Expand Up @@ -45,7 +45,7 @@ def define_cache_dir(cache_dir: Optional[str] = None) -> Path:
return cache_path


def get_all_punctuation() -> Set[str]:
def get_all_punctuation() -> set[str]:
return set(
chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")
)
Expand Down
4 changes: 2 additions & 2 deletions fastembed/image/onnx_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
Defaults to False.
device_ids (Optional[List[int]], optional): The list of device ids to use for data parallel processing in
device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
Expand Down Expand Up @@ -134,7 +134,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
Lists the supported models.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the model information.
list[dict[str, Any]]: A list of dictionaries containing the model information.
"""
return supported_onnx_models

Expand Down
12 changes: 6 additions & 6 deletions tests/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

# %%
import time
from typing import Callable, List, Tuple
from typing import Callable

import matplotlib.pyplot as plt
import torch.nn.functional as F
Expand All @@ -23,7 +23,7 @@
# data is a list of strings, each string is a document.

# %%
documents: List[str] = [
documents: list[str] = [
"Chandrayaan-3 is India's third lunar mission",
"It aimed to land a rover on the Moon's surface - joining the US, China and Russia",
"The mission is a follow-up to Chandrayaan-2, which had partial success",
Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(self, model_id: str):
self.model = AutoModel.from_pretrained(model_id)
self.tokenizer = AutoTokenizer.from_pretrained(model_id)

def embed(self, texts: List[str]):
def embed(self, texts: list[str]):
encoded_input = self.tokenizer(
texts, max_length=512, padding=True, truncation=True, return_tensors="pt"
)
Expand Down Expand Up @@ -88,7 +88,7 @@ def embed(self, texts: List[str]):
# %%
def calculate_time_stats(
embed_func: Callable, documents: list, k: int
) -> Tuple[float, float, float]:
) -> tuple[float, float, float]:
times = []
for _ in range(k):
# Timing the embed_func call
Expand All @@ -111,8 +111,8 @@ def calculate_time_stats(

# %%
def plot_character_per_second_comparison(
hf_stats: Tuple[float, float, float],
fst_stats: Tuple[float, float, float],
hf_stats: tuple[float, float, float],
fst_stats: tuple[float, float, float],
documents: list,
):
# Calculating total characters in documents
Expand Down

0 comments on commit adfc03e

Please sign in to comment.