-
Notifications
You must be signed in to change notification settings - Fork 46
/
chamfer_python.py
44 lines (35 loc) · 1.36 KB
/
chamfer_python.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
def pairwise_dist(x, y):
    """Return squared Euclidean distances between the rows of two 2-D tensors.

    Uses the expansion ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>.
    The expansion can produce tiny negative values from floating-point
    cancellation.

    Args:
        x: tensor of shape (N, D).
        y: tensor of shape (M, D). N and M may differ.

    Returns:
        Tensor of shape (N, M) with P[i, j] = ||x[i] - y[j]||^2.
    """
    # Row-wise squared norms. The previous version materialised the full
    # N x N and M x M Gram matrices only to read their diagonals, and its
    # (N, N) + (M, M) addition failed whenever N != M.
    xx = (x * x).sum(dim=1)  # (N,)
    yy = (y * y).sum(dim=1)  # (M,)
    zz = torch.mm(x, y.t())  # (N, M) inner products
    return xx.unsqueeze(1) + yy.unsqueeze(0) - 2 * zz
def NN_loss(x, y, dim=0):
    """Return the mean nearest-neighbour squared distance between two point sets.

    Args:
        x: points of shape (N, D).
        y: points of shape (M, D).
        dim: dimension of the (N, M) distance matrix to reduce over.
            With dim=0 (default), each point of y is matched to its nearest
            point in x. With dim=1, each point of x is matched to its
            nearest point in y.

    Returns:
        Scalar tensor: the mean of the per-point minimum squared distances.
    """
    dist = pairwise_dist(x, y)
    # Only the minimum values matter here; the argmin indices are not needed.
    return dist.min(dim=dim).values.mean()
def batched_pairwise_dist(a, b):
    """Return batched squared Euclidean distances between two point-cloud batches.

    Both inputs are promoted to float64 before the computation. The result
    uses ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>, so entries can be marginally
    negative from floating-point cancellation.

    Args:
        a: tensor of shape (B, N, D).
        b: tensor of shape (B, M, D).

    Returns:
        float64 tensor of shape (B, N, M) with P[k, i, j] = ||a[k, i] - b[k, j]||^2.
    """
    x = a.double()
    y = b.double()
    # Tuple unpacking doubles as a rank check: non-3-D inputs raise ValueError.
    batch, n_x, _ = x.size()
    _, n_y, _ = y.size()

    sq_norm_x = torch.pow(x, 2).sum(2)  # (B, N)
    sq_norm_y = torch.pow(y, 2).sum(2)  # (B, M)
    inner = torch.bmm(x, y.transpose(2, 1))  # (B, N, M)

    # Spread ||x_i||^2 along the columns and ||y_j||^2 along the rows.
    norm_rows = sq_norm_x.unsqueeze(2).expand(batch, n_x, n_y)
    norm_cols = sq_norm_y.unsqueeze(1).expand(batch, n_x, n_y)
    return norm_rows + norm_cols - 2 * inner
def distChamfer(a, b):
    """
    Compute bidirectional nearest-neighbour (Chamfer) distances between point clouds.

    :param a: Pointclouds Batch x num_points_a x dim
    :param b: Pointclouds Batch x num_points_b x dim
    :return: tuple of four tensors:
        -squared distance from each point of a to its closest point on b (float, Batch x num_points_a)
        -squared distance from each point of b to its closest point on a (float, Batch x num_points_b)
        -idx of closest point on b of points from a (int, Batch x num_points_a)
        -idx of closest point on a of points from b (int, Batch x num_points_b)
    Works for pointcloud of any dimension
    """
    P = batched_pairwise_dist(a, b)
    # torch.min over a dim yields (values, indices) in one pass; the original
    # recomputed each reduction twice to pick values and indices separately.
    dist_a_to_b, idx_a_to_b = torch.min(P, 2)
    dist_b_to_a, idx_b_to_a = torch.min(P, 1)
    return dist_a_to_b.float(), dist_b_to_a.float(), idx_a_to_b.int(), idx_b_to_a.int()