Skip to content

Commit

Permalink
[fix] Ensure that the embeddings from hard negative mining are normalized (#2944)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen committed Sep 19, 2024
1 parent a201c6d commit 7290448
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions sentence_transformers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,8 +714,12 @@ def mine_hard_negatives(
except Exception:
pass

- corpus_embeddings = model.encode(corpus, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True)
- query_embeddings = model.encode(queries, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True)
+ corpus_embeddings = model.encode(
+ corpus, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True
+ )
+ query_embeddings = model.encode(
+ queries, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True
+ )
index.add(corpus_embeddings)

scores_list = []
Expand All @@ -731,8 +735,12 @@ def mine_hard_negatives(

else:
# Embed the corpus and the queries
- corpus_embeddings = model.encode(corpus, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True)
- query_embeddings = model.encode(queries, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True)
+ corpus_embeddings = model.encode(
+ corpus, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True
+ )
+ query_embeddings = model.encode(
+ queries, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True
+ )
scores = model.similarity(query_embeddings, corpus_embeddings).to(device)

# Keep only the range_max + max_positives highest scores. We offset by 1 to potentially include the positive pair
Expand Down

0 comments on commit 7290448

Please sign in to comment.