diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py index 59e7bf0c4..5b8baa5f7 100644 --- a/sentence_transformers/util.py +++ b/sentence_transformers/util.py @@ -714,8 +714,12 @@ def mine_hard_negatives( except Exception: pass - corpus_embeddings = model.encode(corpus, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True) - query_embeddings = model.encode(queries, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True) + corpus_embeddings = model.encode( + corpus, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True + ) + query_embeddings = model.encode( + queries, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True + ) index.add(corpus_embeddings) scores_list = [] @@ -731,8 +735,12 @@ def mine_hard_negatives( else: # Embed the corpus and the queries - corpus_embeddings = model.encode(corpus, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True) - query_embeddings = model.encode(queries, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True) + corpus_embeddings = model.encode( + corpus, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True + ) + query_embeddings = model.encode( + queries, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True + ) scores = model.similarity(query_embeddings, corpus_embeddings).to(device) # Keep only the range_max + max_positives highest scores. We offset by 1 to potentially include the positive pair