From 7290448809cb73f08f63c955550815775434beb4 Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Thu, 19 Sep 2024 14:59:28 +0200 Subject: [PATCH] [`fix`] Ensure that the embeddings from hard negative mining are normalized (#2944) --- sentence_transformers/util.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py index 59e7bf0c4..5b8baa5f7 100644 --- a/sentence_transformers/util.py +++ b/sentence_transformers/util.py @@ -714,8 +714,12 @@ def mine_hard_negatives( except Exception: pass - corpus_embeddings = model.encode(corpus, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True) - query_embeddings = model.encode(queries, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True) + corpus_embeddings = model.encode( + corpus, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True + ) + query_embeddings = model.encode( + queries, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True + ) index.add(corpus_embeddings) scores_list = [] @@ -731,8 +735,12 @@ def mine_hard_negatives( else: # Embed the corpus and the queries - corpus_embeddings = model.encode(corpus, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True) - query_embeddings = model.encode(queries, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True) + corpus_embeddings = model.encode( + corpus, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True + ) + query_embeddings = model.encode( + queries, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True + ) scores = model.similarity(query_embeddings, corpus_embeddings).to(device) # Keep only the range_max + max_positives highest scores. We offset by 1 to potentially include the positive pair