-
Notifications
You must be signed in to change notification settings - Fork 58
/
test.py
74 lines (59 loc) · 3.29 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import faiss
import torch
import logging
import numpy as np
from tqdm import tqdm
from typing import Tuple
from argparse import Namespace
from torch.utils.data.dataset import Subset
from torch.utils.data import DataLoader, Dataset
import visualizations
# Values of N at which recall@N is computed and reported: R@1, R@5, R@10, R@20.
# The largest value is also the number of nearest neighbours retrieved per query in test().
RECALL_VALUES = [1, 5, 10, 20]
def test(args: Namespace, eval_ds: Dataset, model: torch.nn.Module,
         num_preds_to_save: int = 0) -> Tuple[np.ndarray, str]:
    """Compute descriptors of the given dataset and compute the recalls.

    Args:
        args: run configuration. This function reads ``num_workers``,
            ``infer_batch_size``, ``device`` and ``fc_output_dim``, plus
            ``output_folder`` and ``save_only_wrong_preds`` when predictions are saved.
        eval_ds: evaluation dataset. It must expose ``database_num``,
            ``queries_num`` and ``get_positives()``, and yield ``(image, index)``
            pairs, with the ``database_num`` database items placed before the queries.
        model: network mapping a batch of images to descriptors of size
            ``args.fc_output_dim``. It is switched to eval mode and left in that mode.
        num_preds_to_save: number of top predictions per query handed to
            ``visualizations.save_preds``; ``0`` disables saving. Only
            ``max(RECALL_VALUES)`` neighbours are retrieved, so larger values
            are effectively capped at that number.

    Returns:
        A tuple ``(recalls, recalls_str)``: ``recalls`` holds one percentage per
        entry of ``RECALL_VALUES``; ``recalls_str`` is the same data formatted
        as ``"R@1: xx.x, R@5: xx.x, ..."``.
    """
    model = model.eval()

    # A single buffer holds database descriptors first, then queries descriptors;
    # rows are addressed by the dataset index returned alongside each image.
    all_descriptors = np.empty((len(eval_ds), args.fc_output_dim), dtype="float32")

    with torch.no_grad():
        logging.debug("Extracting database descriptors for evaluation/testing")
        database_indices = list(range(eval_ds.database_num))
        _extract_descriptors(args, model, eval_ds, database_indices,
                             args.infer_batch_size, all_descriptors)

        logging.debug("Extracting queries descriptors for evaluation/testing using batch size 1")
        queries_infer_batch_size = 1
        queries_indices = list(range(eval_ds.database_num,
                                     eval_ds.database_num + eval_ds.queries_num))
        _extract_descriptors(args, model, eval_ds, queries_indices,
                             queries_infer_batch_size, all_descriptors)

    queries_descriptors = all_descriptors[eval_ds.database_num:]
    database_descriptors = all_descriptors[:eval_ds.database_num]

    # Use a kNN to find predictions
    faiss_index = faiss.IndexFlatL2(args.fc_output_dim)
    faiss_index.add(database_descriptors)
    # Drop these names before the search; queries_descriptors still references the buffer.
    del database_descriptors, all_descriptors

    logging.debug("Calculating recalls")
    _, predictions = faiss_index.search(queries_descriptors, max(RECALL_VALUES))

    # For each query, check if the predictions are correct
    positives_per_query = eval_ds.get_positives()
    recalls = np.zeros(len(RECALL_VALUES))
    for query_index, preds in enumerate(predictions):
        for i, n in enumerate(RECALL_VALUES):
            # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
            if np.any(np.isin(preds[:n], positives_per_query[query_index])):
                # A hit within the top-n is also a hit for every larger threshold.
                recalls[i:] += 1
                break
    # Divide by queries_num and multiply by 100, so the recalls are in percentages
    recalls = recalls / eval_ds.queries_num * 100
    recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(RECALL_VALUES, recalls)])

    # Save visualizations of predictions
    if num_preds_to_save != 0:
        # For each query save num_preds_to_save predictions
        visualizations.save_preds(predictions[:, :num_preds_to_save], eval_ds,
                                  args.output_folder, args.save_only_wrong_preds)

    return recalls, recalls_str


def _extract_descriptors(args: Namespace, model: torch.nn.Module, eval_ds: Dataset,
                         indices: list, batch_size: int, all_descriptors: np.ndarray) -> None:
    """Run ``model`` on the items of ``eval_ds`` selected by ``indices`` and store each descriptor in its row of ``all_descriptors``."""
    subset_ds = Subset(eval_ds, indices)
    dataloader = DataLoader(dataset=subset_ds, num_workers=args.num_workers,
                            batch_size=batch_size, pin_memory=(args.device == "cuda"))
    for images, batch_indices in tqdm(dataloader, ncols=100):
        descriptors = model(images.to(args.device))
        # The dataset yields each item's own index, which selects the destination rows.
        all_descriptors[batch_indices.numpy(), :] = descriptors.cpu().numpy()