forked from cvg/Hierarchical-Localization
-
Notifications
You must be signed in to change notification settings - Fork 1
/
adalam.py
55 lines (49 loc) · 1.92 KB
/
adalam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
from ..utils.base_model import BaseModel
from kornia.feature.adalam import AdalamFilter
from kornia.utils.helpers import get_cuda_device_if_available
class AdaLAM(BaseModel):
# See https://kornia.readthedocs.io/en/latest/_modules/kornia/feature/adalam/adalam.html.
default_conf = {
'area_ratio': 100,
'search_expansion': 4,
'ransac_iters': 128,
'min_inliers': 6,
'min_confidence': 200,
'orientation_difference_threshold': 30,
'scale_rate_threshold': 1.5,
'detected_scale_rate_threshold': 5,
'refit': True,
'force_seed_mnn': True,
'device': get_cuda_device_if_available()
}
required_inputs = [
'image0', 'image1',
'descriptors0', 'descriptors1',
'keypoints0', 'keypoints1',
'scales0', 'scales1',
'oris0', 'oris1']
def _init(self, conf):
self.adalam = AdalamFilter(conf)
def _forward(self, data):
assert data['keypoints0'].size(0) == 1
if data['keypoints0'].size(1) < 2 or data['keypoints1'].size(1) < 2:
matches = torch.zeros(
(0, 2), dtype=torch.int64,
device=data['keypoints0'].device)
else:
matches = self.adalam.match_and_filter(
data['keypoints0'][0], data['keypoints1'][0],
data['descriptors0'][0].T, data['descriptors1'][0].T,
data['image0'].shape[2:], data['image1'].shape[2:],
data['oris0'][0], data['oris1'][0],
data['scales0'][0], data['scales1'][0]
)
matches_new = torch.full(
(data['keypoints0'].size(1),), -1,
dtype=torch.int64, device=data['keypoints0'].device)
matches_new[matches[:, 0]] = matches[:, 1]
return {
'matches0': matches_new.unsqueeze(0),
'matching_scores0': torch.zeros(matches_new.size(0)).unsqueeze(0)
}