-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[
feat
] Add binary & scalar embedding quantization support to Senten…
…ce Transformers (#2549) * Add quantization support to 'encode' + semantic search helpers * Add some examples for embedding quantization * Fix example script * Also add faiss benchmark script; fix exact * Allow embedding pre-quantization to be e.g. float16 or float64 as well * oversampling -> rescore_multiplier; rerank -> rescore * Clarify that we only search for more samples if rescoring happens * oversampling factor -> rescore_multiplier * Be more explicit with warnings if rescore is True but not used * Update examples with new signatures * Simplify FAISS index creation for uint8 * Remove non-ST-style typing * Move quantization code to separate file * Add quantization to API reference * Add example of recommended search * Reformatted * Add documentation page about Embedding Quantization * Update hyperlink * Add and fix references * Add Combining Binary and Scalar Quantization to docs * #Demo -> #demo
- Loading branch information
Showing
13 changed files
with
1,028 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# quantization | ||
`sentence_transformers.quantization` defines different helpful functions to quantize. | ||
|
||
```eval_rst | ||
.. automodule:: sentence_transformers.quantization | ||
:members: quantize_embeddings, semantic_search_faiss, semantic_search_usearch | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# Embedding Quantization | ||
|
||
Embeddings may be challenging to scale up, which leads to expensive solutions and high latencies. Currently, many state-of-the-art models produce embeddings with 1024 dimensions, each of which is encoded in `float32`, i.e., they require 4 bytes per dimension. To perform retrieval over 50 million vectors, you would therefore need around 200GB of memory. This tends to require complex and costly solutions at scale. | ||
|
||
However, there is a new approach to counter this problem; it entails reducing the size of each of the individual values in the embedding: **Quantization**. Experiments on quantization have shown that we can maintain a large amount of performance while significantly speeding up computation and saving on memory, storage, and costs. | ||
|
||
To learn more about Embedding Quantization and their performance, please read the [blogpost](https://huggingface.co/blog/embedding-quantization) by Sentence Transformers and mixedbread.ai. | ||
|
||
## Binary Quantization | ||
|
||
Binary quantization refers to the conversion of the `float32` values in an embedding to 1-bit values, resulting in a 32x reduction in memory and storage usage. To quantize `float32` embeddings to binary, we simply threshold normalized embeddings at 0: if the value is larger than 0, we make it 1, otherwise we convert it to 0. We can use the Hamming Distance to efficiently perform retrieval with these binary embeddings. This is simply the number of positions at which the bits of two binary embeddings differ. The lower the Hamming Distance, the closer the embeddings, and thus the more relevant the document. A huge advantage of the Hamming Distance is that it can be easily calculated with 2 CPU cycles, allowing for blazingly fast performance. | ||
|
||
[Yamada et al. (2021)](https://arxiv.org/abs/2106.00882) introduced a rescore step, which they called *rerank*, to boost the performance. They proposed that the `float32` query embedding could be compared with the binary document embeddings using dot-product. In practice, we first retrieve `rescore_multiplier * top_k` results with the binary query embedding and the binary document embeddings -- i.e., the list of the first k results of the double-binary retrieval -- and then rescore that list of binary document embeddings with the `float32` query embedding. | ||
|
||
By applying this novel rescoring step, we are able to preserve up to ~96% of the total retrieval performance, while reducing the memory and disk space usage by 32x and improving the retrieval speed by up to 32x as well. | ||
|
||
### Binary Quantization in Sentence Transformers | ||
|
||
Quantizing an embedding with a dimensionality of 1024 to binary would result in 1024 bits. In practice, it is much more common to store bits as bytes instead, so when we quantize to binary embeddings, we pack the bits into bytes using `np.packbits`. | ||
|
||
As a result, in practice quantizing a `float32` embedding with a dimensionality of 1024 yields an `int8` or `uint8` embedding with a dimensionality of 128. See two approaches of how you can produce quantized embeddings using Sentence Transformers below: | ||
|
||
```python | ||
from sentence_transformers import SentenceTransformer | ||
from sentence_transformers.quantization import quantize_embeddings | ||
|
||
# 1. Load an embedding model | ||
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1") | ||
|
||
# 2a. Encode some text using "binary" quantization | ||
binary_embeddings = model.encode( | ||
["I am driving to the lake.", "It is a beautiful day."], | ||
precision="binary", | ||
) | ||
|
||
# 2b. or, encode some text without quantization & apply quantization afterwards | ||
embeddings = model.encode(["I am driving to the lake.", "It is a beautiful day."]) | ||
binary_embeddings = quantize_embeddings(embeddings, precision="binary") | ||
``` | ||
|
||
**References:** | ||
* <a href="https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1"><code>mixedbread-ai/mxbai-embed-large-v1</code></a> | ||
* <a href="../../../docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"><code>SentenceTransformer.encode</code></a> | ||
* <a href="../../../docs/package_reference/quantization.html#sentence_transformers.quantization.quantize_embeddings"><code>quantize_embeddings</code></a> | ||
|
||
Here you can see the differences between default `float32` embeddings and binary embeddings in terms of shape, size, and `numpy` dtype: | ||
|
||
```python | ||
>>> embeddings.shape | ||
(2, 1024) | ||
>>> embeddings.nbytes | ||
8192 | ||
>>> embeddings.dtype | ||
float32 | ||
>>> binary_embeddings.shape | ||
(2, 128) | ||
>>> binary_embeddings.nbytes | ||
256 | ||
>>> binary_embeddings.dtype | ||
int8 | ||
``` | ||
Note that you can also choose `"ubinary"` to quantize to binary using the unsigned `uint8` data format. This may be a requirement for your vector library/database. | ||
|
||
## Scalar (int8) Quantization | ||
|
||
To convert the `float32` embeddings into `int8`, we use a process called scalar quantization. This involves mapping the continuous range of `float32` values to the discrete set of `int8` values, which can represent 256 distinct levels (from -128 to 127) as shown in the image below. This is done by using a large calibration dataset of embeddings. We compute the range of these embeddings, i.e. the `min` and `max` of each of the embedding dimensions. From there, we calculate the steps (buckets) in which we categorize each value. | ||
|
||
To further boost the retrieval performance, you can optionally apply the same rescoring step as for the binary embeddings. It is important to note here that the calibration dataset has a large influence on the performance, since it defines the buckets. | ||
|
||
### Scalar Quantization in Sentence Transformers | ||
|
||
Quantizing an embedding with a dimensionality of 1024 to `int8` results in 1024 bytes. In practice, we can choose either `uint8` or `int8`. This choice is usually made depending on what your vector library/database supports. | ||
|
||
In practice, it is recommended to provide the scalar quantization with either: | ||
1. a large set of embeddings to quantize all at once, or | ||
2. `min` and `max` ranges for each of the embedding dimensions, or | ||
3. a large calibration dataset of embeddings from which the `min` and `max` ranges can be computed. | ||
|
||
If none of these are the case, you will be given a warning like this: | ||
|
||
``` | ||
Computing int8 quantization buckets based on 2 embeddings. int8 quantization is more stable with 'ranges' calculated from more embeddings or a 'calibration_embeddings' that can be used to calculate the buckets. | ||
``` | ||
|
||
See how you can produce scalar quantized embeddings using Sentence Transformers below: | ||
|
||
```python | ||
from sentence_transformers import SentenceTransformer | ||
from sentence_transformers.quantization import quantize_embeddings | ||
from datasets import load_dataset | ||
|
||
# 1. Load an embedding model | ||
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1") | ||
|
||
# 2. Prepare an example calibration dataset | ||
corpus = load_dataset("nq_open", split="train[:1000]")["question"] | ||
calibration_embeddings = model.encode(corpus) | ||
|
||
# 3. Encode some text without quantization & apply quantization afterwards | ||
embeddings = model.encode(["I am driving to the lake.", "It is a beautiful day."]) | ||
int8_embeddings = quantize_embeddings( | ||
embeddings, | ||
precision="int8", | ||
calibration_embeddings=calibration_embeddings, | ||
) | ||
``` | ||
|
||
**References:** | ||
* <a href="https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1"><code>mixedbread-ai/mxbai-embed-large-v1</code></a> | ||
* <a href="../../../docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"><code>SentenceTransformer.encode</code></a> | ||
* <a href="../../../docs/package_reference/quantization.html#sentence_transformers.quantization.quantize_embeddings"><code>quantize_embeddings</code></a> | ||
|
||
Here you can see the differences between default `float32` embeddings and `int8` scalar embeddings in terms of shape, size, and `numpy` dtype: | ||
|
||
```python | ||
>>> embeddings.shape | ||
(2, 1024) | ||
>>> embeddings.nbytes | ||
8192 | ||
>>> embeddings.dtype | ||
float32 | ||
>>> int8_embeddings.shape | ||
(2, 1024) | ||
>>> int8_embeddings.nbytes | ||
2048 | ||
>>> int8_embeddings.dtype | ||
int8 | ||
``` | ||
|
||
### Combining Binary and Scalar Quantization | ||
|
||
It is possible to combine binary and scalar quantization to get the best of both worlds: the extreme speed from binary embeddings and the great performance preservation of scalar embeddings with rescoring. See the [demo](#demo) below for a real-life implementation of this approach involving 41 million texts from Wikipedia. The pipeline for that setup is as follows: | ||
|
||
1. The query is embedded using the [`mixedbread-ai/mxbai-embed-large-v1`](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1) SentenceTransformer model. | ||
2. The query is quantized to binary using the <a href="../../../docs/package_reference/quantization.html#sentence_transformers.quantization.quantize_embeddings"><code>quantize_embeddings</code></a> function from the `sentence-transformers` library. | ||
3. A binary index (41M binary embeddings; 5.2GB of memory/disk space) is searched using the quantized query for the top 40 documents. | ||
4. The top 40 documents are loaded on the fly from an int8 index on disk (41M int8 embeddings; 0 bytes of memory, 47.5GB of disk space). | ||
5. The top 40 documents are rescored using the float32 query and the int8 embeddings to get the top 10 documents. | ||
6. The top 10 documents are sorted by score and displayed. | ||
|
||
Through this approach, we use 5.2GB of memory and 52GB of disk space for the indices. This is considerably less than normal retrieval, for which we would require 200GB of memory and 200GB of disk space. Especially as you scale up even further, this will result in notable reductions in both latency and costs. | ||
|
||
## Additional extensions | ||
|
||
Note that embedding quantization can be combined with other approaches to improve retrieval efficiency, such as [Matryoshka Embeddings](../../training/matryoshka/README.md). Additionally, the [Retrieve & Re-Rank](../retrieve_rerank/README.md) also works very well with quantized embeddings, i.e. you can still use a Cross-Encoder to rerank. | ||
|
||
## Demo | ||
|
||
The following demo showcases the retrieval efficiency using `exact` search through combining binary search with scalar (`int8`) rescoring. The solution requires 5GB of memory for the binary index and 50GB of disk space for the binary and scalar indices, considerably less than the 200GB of memory and disk space which would be required for regular `float32` retrieval. Additionally, retrieval is much faster. | ||
|
||
<iframe | ||
src="https://sentence-transformers-quantized-retrieval.hf.space" | ||
frameborder="0" | ||
width="100%" | ||
height="1000" | ||
></iframe> | ||
## Try it yourself | ||
|
||
The following scripts can be used to experiment with embedding quantization for retrieval & beyond. There are three categories: | ||
|
||
* **Recommended Retrieval**: | ||
* [semantic_search_recommended.py](semantic_search_recommended.py): This script combines binary search with scalar rescoring, much like the above demo, for cheap, efficient, and performant retrieval. | ||
* **Usage**: | ||
* [semantic_search_faiss.py](semantic_search_faiss.py): This script showcases regular usage of binary or scalar quantization, retrieval, and rescoring using FAISS, by using the <a href="../../../docs/package_reference/quantization.html#sentence_transformers.quantization.semantic_search_faiss"><code>semantic_search_faiss</code></a> utility function. | ||
* [semantic_search_usearch.py](semantic_search_usearch.py): This script showcases regular usage of binary or scalar quantization, retrieval, and rescoring using USearch, by using the <a href="../../../docs/package_reference/quantization.html#sentence_transformers.quantization.semantic_search_usearch"><code>semantic_search_usearch</code></a> utility function. | ||
* **Benchmarks**: | ||
* [semantic_search_faiss_benchmark.py](semantic_search_faiss_benchmark.py): This script includes a retrieval speed benchmark of `float32` retrieval, binary retrieval + rescoring, and scalar retrieval + rescoring, using FAISS. It uses the <a href="../../../docs/package_reference/quantization.html#sentence_transformers.quantization.semantic_search_faiss"><code>semantic_search_faiss</code></a> utility function. Our benchmarks especially show show speedups for `ubinary`. | ||
* [semantic_search_usearch_benchmark.py](semantic_search_usearch_benchmark.py): This script includes a retrieval speed benchmark of `float32` retrieval, binary retrieval + rescoring, and scalar retrieval + rescoring, using USearch. It uses the <a href="../../../docs/package_reference/quantization.html#sentence_transformers.quantization.semantic_search_usearch"><code>semantic_search_usearch</code></a> utility function. Our experiments show large speedups on newer hardware, particularly for `int8`. |
93 changes: 93 additions & 0 deletions
93
examples/applications/embedding-quantization/semantic_search_faiss.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import time | ||
from sentence_transformers import SentenceTransformer | ||
from sentence_transformers.quantization import quantize_embeddings, semantic_search_faiss | ||
from datasets import load_dataset | ||
|
||
# 1. Load the quora corpus with questions | ||
dataset = load_dataset("quora", split="train").map( | ||
lambda batch: {"text": [text for sample in batch["questions"] for text in sample["text"]]}, | ||
batched=True, | ||
remove_columns=["questions", "is_duplicate"], | ||
) | ||
max_corpus_size = 100_000 | ||
corpus = dataset["text"][:max_corpus_size] | ||
|
||
# 2. Come up with some queries | ||
queries = [ | ||
"How do I become a good programmer?", | ||
"How do I become a good data scientist?", | ||
] | ||
|
||
# 3. Load the model | ||
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1") | ||
|
||
# 4. Choose a target precision for the corpus embeddings | ||
corpus_precision = "ubinary" | ||
# Valid options are: "float32", "uint8", "int8", "ubinary", and "binary" | ||
# But FAISS only supports "float32", "uint8", and "ubinary" | ||
|
||
# 5. Encode the corpus | ||
full_corpus_embeddings = model.encode(corpus, normalize_embeddings=True, show_progress_bar=True) | ||
corpus_embeddings = quantize_embeddings(full_corpus_embeddings, precision=corpus_precision) | ||
# NOTE: We can also pass "precision=..." to the encode method to quantize the embeddings directly, | ||
# but we want to keep the full precision embeddings to act as a calibration dataset for quantizing | ||
# the query embeddings. This is important only if you are using uint8 or int8 precision | ||
|
||
# Initially, we don't have a FAISS index yet, we can use semantic_search_faiss to create it | ||
corpus_index = None | ||
while True: | ||
# 7. Encode the queries using the full precision | ||
start_time = time.time() | ||
query_embeddings = model.encode(queries, normalize_embeddings=True) | ||
print(f"Encoding time: {time.time() - start_time:.6f} seconds") | ||
|
||
# 8. Perform semantic search using FAISS | ||
results, search_time, corpus_index = semantic_search_faiss( | ||
query_embeddings, | ||
corpus_index=corpus_index, | ||
corpus_embeddings=corpus_embeddings if corpus_index is None else None, | ||
corpus_precision=corpus_precision, | ||
top_k=10, | ||
calibration_embeddings=full_corpus_embeddings, | ||
rescore=corpus_precision != "float32", | ||
rescore_multiplier=4, | ||
exact=True, | ||
output_index=True, | ||
) | ||
# This is a helper function to showcase how FAISS can be used with quantized embeddings. | ||
# You must either provide the `corpus_embeddings` or the `corpus_index` FAISS index, but not both. | ||
# In the first call we'll provide the `corpus_embeddings` and get the `corpus_index` back, which | ||
# we'll use in the next call. In practice, the index is stored in RAM or saved to disk, and not | ||
# recalculated for every query. | ||
|
||
# This function will 1) quantize the query embeddings to the same precision as the corpus embeddings, | ||
# 2) perform the semantic search using FAISS, 3) rescore the results using the full precision embeddings, | ||
# and 4) return the results and the search time (and perhaps the FAISS index). | ||
|
||
# `corpus_precision` must be the same as the precision used to quantize the corpus embeddings. | ||
# It is used to convert the query embeddings to the same precision as the corpus embeddings. | ||
# `top_k` determines how many results are returned for each query. | ||
# `rescore_multiplier` is a parameter for the rescoring step. Rather than searching for the top_k results, | ||
# we search for top_k * rescore_multiplier results and rescore the top_k results using the full precision embeddings. | ||
# So, higher values of rescore_multiplier will give better results, but will be slower. | ||
|
||
# `calibration_embeddings` is a set of embeddings used to calibrate the quantization of the query embeddings. | ||
# This is important only if you are using uint8 or int8 precision. In practice, this is used to calculate | ||
# the minimum and maximum values of each of the embedding dimensions, which are then used to determine the | ||
# quantization thresholds. | ||
|
||
# `rescore` determines whether to rescore the results using the full precision embeddings, if False & the | ||
# corpus is quantized, the results will be very poor. `exact` determines whether to use the exact search | ||
# or the approximate search method in FAISS. | ||
|
||
# 9. Output the results | ||
print("Precision:", corpus_precision) | ||
print(f"Search time: {search_time:.6f} seconds") | ||
for query, result in zip(queries, results): | ||
print(f"Query: {query}") | ||
for entry in result: | ||
print(f"(Score: {entry['score']:.4f}) {corpus[entry['corpus_id']]}") | ||
print("") | ||
|
||
# 10. Prompt for more queries | ||
queries = [input("Please enter a question: ")] |
48 changes: 48 additions & 0 deletions
48
examples/applications/embedding-quantization/semantic_search_faiss_benchmark.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from sentence_transformers import SentenceTransformer | ||
from sentence_transformers.quantization import quantize_embeddings, semantic_search_faiss | ||
from datasets import load_dataset | ||
|
||
# 1. Load the quora corpus with questions | ||
dataset = load_dataset("quora", split="train").map( | ||
lambda batch: {"text": [text for sample in batch["questions"] for text in sample["text"]]}, | ||
batched=True, | ||
remove_columns=["questions", "is_duplicate"], | ||
) | ||
max_corpus_size = 100_000 | ||
corpus = dataset["text"][:max_corpus_size] | ||
num_queries = 1_000 | ||
queries = corpus[:num_queries] | ||
|
||
# 2. Load the model | ||
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1") | ||
|
||
# 3. Encode the corpus | ||
full_corpus_embeddings = model.encode(corpus, normalize_embeddings=True, show_progress_bar=True) | ||
|
||
# 4. Encode the queries using the full precision | ||
query_embeddings = model.encode(queries, normalize_embeddings=True) | ||
|
||
for exact in (True, False): | ||
for corpus_precision in ("float32", "uint8", "ubinary"): | ||
corpus_embeddings = quantize_embeddings(full_corpus_embeddings, precision=corpus_precision) | ||
# NOTE: We can also pass "precision=..." to the encode method to quantize the embeddings directly, | ||
# but we want to keep the full precision embeddings to act as a calibration dataset for quantizing | ||
# the query embeddings. This is important only if you are using uint8 or int8 precision | ||
|
||
# 5. Perform semantic search using FAISS | ||
rescore_multiplier = 4 | ||
results, search_time = semantic_search_faiss( | ||
query_embeddings, | ||
corpus_embeddings=corpus_embeddings, | ||
corpus_precision=corpus_precision, | ||
top_k=10, | ||
calibration_embeddings=full_corpus_embeddings, | ||
rescore=corpus_precision != "float32", | ||
rescore_multiplier=rescore_multiplier, | ||
exact=exact, | ||
) | ||
|
||
print( | ||
f"{'Exact' if exact else 'Approximate'} search time using {corpus_precision} corpus: {search_time:.6f} seconds" | ||
+ (f" (rescore_multiplier: {rescore_multiplier})" if corpus_precision != "float32" else "") | ||
) |
Oops, something went wrong.