Skip to content

Commit

Permalink
update mteb eval
Browse files Browse the repository at this point in the history
  • Loading branch information
545999961 committed Nov 15, 2024
1 parent 5951aa6 commit 61ab7e0
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 20 deletions.
3 changes: 1 addition & 2 deletions FlagEmbedding/abc/inference/AbsEmbedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,7 @@ def encode(
return embeddings

def __del__(self):
    # Destructor: best-effort shutdown of the multi-process encoding pool.
    # Diff hunk — removed by this commit (old inline cleanup):
    if self.pool is not None:
        self.stop_multi_process_pool(self.pool)
    # Diff hunk — added by this commit: delegate to a shared helper which
    # presumably stops the pool and resets self.pool — confirm in AbsEmbedder.
    self.stop_self_pool()

@abstractmethod
def encode_single_device(
Expand Down
3 changes: 1 addition & 2 deletions FlagEmbedding/abc/inference/AbsReranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,7 @@ def compute_score(
return scores

def __del__(self):
    # Destructor: best-effort shutdown of the multi-process reranking pool.
    # Diff hunk — removed by this commit (old inline cleanup):
    if self.pool is not None:
        self.stop_multi_process_pool(self.pool)
    # Diff hunk — added by this commit: delegate to a shared helper which
    # presumably stops the pool and resets self.pool — confirm in AbsReranker.
    self.stop_self_pool()

@abstractmethod
def compute_score_single_gpu(
Expand Down
9 changes: 3 additions & 6 deletions FlagEmbedding/evaluation/mteb/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,13 @@ def run(self):
task_types=task_types
)
output_folder = self.eval_args.output_dir
new_tasks = []
for task in tasks:
if task.languages is not None:
if len(task.languages) == len([e for e in languages if e in task.languages]):
new_tasks.append(task)

for task in new_tasks:
for task in tasks:
task_name = task.metadata.name
task_type = task.metadata.type

self.retriever.stop_pool()

if self.eval_args.use_special_instructions:
try:
instruction = get_task_def_by_task_name_and_type(task_name, task_type)
Expand Down
7 changes: 7 additions & 0 deletions FlagEmbedding/evaluation/mteb/searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ def get_instruction(self):
def set_normalize_embeddings(self, normalize_embeddings: bool = True):
self.embedder.normalize_embeddings = normalize_embeddings

def stop_pool(self):
    """Stop the embedder's multi-process worker pool(s).

    Always stops the main pool; additionally stops the query-specific pool
    when the underlying embedder defines one (e.g. the ICL decoder-only
    embedder). Embedders without ``stop_self_query_pool`` are tolerated.
    """
    self.embedder.stop_self_pool()
    try:
        # Only some embedders maintain a separate query pool; a missing
        # method here is expected and ignored.
        self.embedder.stop_self_query_pool()
    except Exception:
        # NOTE(review): original used a bare `except:`; narrowed to
        # Exception so KeyboardInterrupt/SystemExit are not swallowed.
        # The anticipated failure is AttributeError — confirm no other
        # error types need suppressing here.
        pass

def encode_queries(self, queries: List[str], **kwargs):
emb = self.embedder.encode_queries(queries)
if isinstance(emb, dict):
Expand Down
23 changes: 13 additions & 10 deletions FlagEmbedding/inference/embedder/decoder_only/icl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import queue
from multiprocessing import Queue

import gc
import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer
Expand Down Expand Up @@ -121,10 +122,8 @@ def __init__(
self.query_pool = None

def __del__(self):
    # Destructor: best-effort shutdown of both multi-process pools (the
    # general encoding pool and the query-specific pool).
    # Diff hunk — removed by this commit (old inline cleanup of both pools):
    if self.pool is not None:
        self.stop_multi_process_pool(self.pool)
    if self.query_pool is not None:
        self.stop_multi_process_pool(self.query_pool)
    # Diff hunk — added by this commit: delegate to the dedicated cleanup
    # helpers (stop_self_query_pool is defined later in this same file).
    self.stop_self_pool()
    self.stop_self_query_pool()

def set_examples(self, examples_for_task: Optional[List[dict]] = None):
"""Set the prefix to the provided examples.
Expand Down Expand Up @@ -175,6 +174,14 @@ def get_detailed_example(instruction_format: str, instruction: str, query: str,
"""
return instruction_format.format(instruction, query, response)

def stop_self_query_pool(self):
    """Shut down the dedicated query-encoding multi-process pool, if any.

    Counterpart of ``stop_self_pool`` for the query pool created lazily in
    ``encode_queries``. After stopping the pool, the model is moved to CPU
    and cached CUDA memory is released so corpus encoding (or another
    process) can reuse the GPU.
    """
    if self.query_pool is not None:
        self.stop_multi_process_pool(self.query_pool)
        # Clear the handle so a later encode_queries() lazily recreates it.
        # NOTE(review): indentation reconstructed from the diff view — the
        # placement of this reset inside the `if` mirrors the analogous
        # `self.pool = None` pattern visible in encode_queries; confirm
        # against the repository.
        self.query_pool = None
    # Release GPU resources held by the model: move weights to CPU, drop
    # Python-level references via gc, then free cached CUDA blocks.
    self.model.to('cpu')
    gc.collect()
    torch.cuda.empty_cache()

def encode_queries(
self,
queries: Union[List[str], str],
Expand Down Expand Up @@ -209,9 +216,7 @@ def encode_queries(
**kwargs
)

if self.pool is not None:
self.stop_multi_process_pool(self.pool)
self.pool = None
self.stop_self_pool()
if self.query_pool is None:
self.query_pool = self.start_multi_process_pool(ICLLLMEmbedder._encode_queries_multi_process_worker)
embeddings = self.encode_multi_process(
Expand Down Expand Up @@ -244,9 +249,7 @@ def encode_corpus(
Returns:
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
"""
if self.query_pool is not None:
self.stop_multi_process_pool(self.query_pool)
self.query_pool = None
self.stop_self_query_pool()
return super().encode_corpus(
corpus,
batch_size=batch_size,
Expand Down

0 comments on commit 61ab7e0

Please sign in to comment.