
[Fix] Retriever tokenization function in atlas.py needs correction #20

Open
silencio94 opened this issue Sep 24, 2023 · 1 comment

silencio94 commented Sep 24, 2023

When the code runs, the maximum passage length passed to the tokenizer becomes the smaller of two unrelated variables: self.opt.text_maxlength and gpu_embedder_batch_size. By default, gpu_embedder_batch_size is set to 512, so if you run the code with the default options, most BERT-style dual encoders will work without issues (see line 74).

However, if you reduce gpu_embedder_batch_size to conserve GPU memory, passages are silently truncated to that length, and unexpected results can occur without any warning (see the sketch after the excerpt below).

atlas/src/atlas.py, lines 61 to 89 at commit f8bec5c:

@torch.no_grad()
def build_index(self, index, passages, gpu_embedder_batch_size, logger=None):
    n_batch = math.ceil(len(passages) / gpu_embedder_batch_size)
    retrieverfp16 = self._get_fp16_retriever_copy()
    total = 0
    for i in range(n_batch):
        batch = passages[i * gpu_embedder_batch_size : (i + 1) * gpu_embedder_batch_size]
        batch = [self.opt.retriever_format.format(**example) for example in batch]
        batch_enc = self.retriever_tokenizer(
            batch,
            padding="longest",
            return_tensors="pt",
            max_length=min(self.opt.text_maxlength, gpu_embedder_batch_size),  # <-- line 74
            truncation=True,
        )
        embeddings = retrieverfp16(**_to_cuda(batch_enc), is_passages=True)
        index.embeddings[:, total : total + len(embeddings)] = embeddings.T
        total += len(embeddings)
        if i % 500 == 0 and i > 0:
            logger.info(f"Number of passages encoded: {total}")
    dist_utils.barrier()
    logger.info(f"{total} passages encoded on process: {dist_utils.get_rank()}")
    if not index.is_index_trained():
        logger.info(f"Building faiss indices")
        index.train_index()
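
To make the failure mode concrete, here is a minimal sketch of the silent truncation. It is not taken from atlas.py; it uses a plain HuggingFace BERT tokenizer and illustrative variable names that stand in for the options above:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

text_maxlength = 512          # plays the role of self.opt.text_maxlength
gpu_embedder_batch_size = 64  # reduced to save GPU memory

long_passage = "word " * 400  # a passage well over 64 tokens

batch_enc = tokenizer(
    [long_passage],
    padding="longest",
    return_tensors="pt",
    max_length=min(text_maxlength, gpu_embedder_batch_size),  # evaluates to 64
    truncation=True,
)

# Every passage is silently cut to 64 tokens, even though the encoder
# could handle up to 512, and no warning is emitted.
print(batch_enc["input_ids"].shape)  # torch.Size([1, 64])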

So, it is recommended to modify line 74 as follows (as done in other parts of the code):

min(self.opt.text_maxlength, BERT_MAX_SEQ_LENGTH),
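
For clarity, the tokenizer call would then read as follows. This is only a sketch, assuming BERT_MAX_SEQ_LENGTH is the constant already used elsewhere in the codebase (512 for BERT-style encoders):

batch_enc = self.retriever_tokenizer(
    batch,
    padding="longest",
    return_tensors="pt",
    max_length=min(self.opt.text_maxlength, BERT_MAX_SEQ_LENGTH),  # no longer tied to the batch size
    truncation=True,
)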
@littlewine
Contributor

Indeed, I saw this in the code and was wondering what the logic behind it was, if any! It would be useful if the authors could clarify!
