-
Notifications
You must be signed in to change notification settings - Fork 1
/
knn.py
117 lines (82 loc) · 3.06 KB
/
knn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import time
import torch
from pykeops.torch import Vi, Vj
def KNN_KeOps(K, metric="euclidean"):
    """Build a brute-force K-nearest-neighbour estimator backed by KeOps.

    Args:
        K: Number of neighbours requested for every query point.
        metric: Name of the distance. Only ``"euclidean"`` (the squared
            Euclidean distance) is implemented.

    Returns:
        A ``fit(x_train)`` callable. ``fit`` returns ``(f, elapsed)`` where
        ``elapsed`` is the setup time in seconds and ``f(x_test)`` returns
        ``(indices, elapsed)``: the output of the KeOps ``argKmin`` query
        converted to a NumPy array, and the query time in seconds.

    Raises:
        ValueError: If ``metric`` is not a supported metric name.
    """
    # Fail fast: an unknown metric used to leave the symbolic distance
    # undefined and only surfaced later as an UnboundLocalError in ``fit``.
    if metric != "euclidean":
        raise ValueError(
            f"Unsupported metric {metric!r}; only 'euclidean' is implemented"
        )

    def _now(*tensors):
        # CUDA kernels run asynchronously: wait for them before reading the
        # clock, otherwise only the launch overhead would be measured.
        if any(getattr(t, "is_cuda", False) for t in tensors):
            torch.cuda.synchronize()
        return time.perf_counter()

    def fit(x_train):
        # Setup the K-NN estimator:
        start = _now(x_train)

        # Encoding as KeOps LazyTensors:
        D = x_train.shape[1]
        X_i = Vi(0, D)  # Purely symbolic "i" variable, without any data array
        X_j = Vj(1, D)  # Purely symbolic "j" variable, without any data array

        # Symbolic (squared Euclidean) distance matrix:
        D_ij = ((X_i - X_j) ** 2).sum(-1)

        # K-NN query operator:
        KNN_fun = D_ij.argKmin(K, dim=1)

        # N.B.: The "training" time here should be negligible.
        elapsed = _now(x_train) - start

        def f(x_test):
            start = _now(x_test, x_train)

            # Actual K-NN query: x_test feeds the "i" variable (slot 0) and
            # x_train feeds the "j" variable (slot 1).
            indices = KNN_fun(x_test, x_train)

            elapsed = _now(indices) - start

            # The host transfer is deliberately excluded from the timing.
            indices = indices.cpu().numpy()
            return indices, elapsed

        return f, elapsed

    return fit
def KNN_torch_fun(x_train, x_train_norm, x_test, K):
    """Return, for each test point, the indices of its K closest train points.

    Squared Euclidean distances are expanded as
    ``|x|^2 + |y|^2 - 2 <x, y>`` so that the bulk of the work is a single
    matrix product.

    Args:
        x_train: Training points.
        x_train_norm: Squared norms of the training points, i.e.
            ``(x_train ** 2).sum(-1)`` as precomputed by the caller.
        x_test: Query points.
        K: Number of neighbours to keep per query point.

    Returns:
        The ``indices`` field of ``topk`` over dimension 1, smallest
        distances first.
    """
    sq_test = (x_test ** 2).sum(-1)

    # Cross term: one dense matmul (cuBLAS does the heavy lifting on GPU).
    cross = 2 * x_test @ x_train.t()

    # Broadcast the column of test norms against the row of train norms.
    sq_dists = sq_test.view(-1, 1) + x_train_norm.view(1, -1) - cross

    # largest=False -> keep the K smallest distances.
    nearest = sq_dists.topk(K, dim=1, largest=False)
    return nearest.indices
def KNN_torch(K):
    """Build a brute-force K-nearest-neighbour estimator on plain PyTorch.

    Args:
        K: Number of neighbours requested for every query point.

    Returns:
        A ``fit(x_train)`` callable. ``fit`` returns ``(f, elapsed)`` where
        ``elapsed`` is the setup time in seconds and ``f(x_test)`` returns
        ``(indices, elapsed)``: the tensor of neighbour indices produced by
        ``KNN_torch_fun`` and the query time in seconds.
    """

    def _now(*tensors):
        # CUDA kernels run asynchronously: wait for them before reading the
        # clock, otherwise only the launch overhead would be measured.
        if any(getattr(t, "is_cuda", False) for t in tensors):
            torch.cuda.synchronize()
        return time.perf_counter()

    def fit(x_train):
        # Setup the K-NN estimator:
        start = _now(x_train)

        # The "training" time here should be negligible: only the squared
        # norms of the training points are cached for later queries.
        x_train_norm = (x_train ** 2).sum(-1)

        elapsed = _now(x_train_norm) - start

        def f(x_test):
            start = _now(x_test, x_train)

            # Actual K-NN query:
            indices = KNN_torch_fun(x_train, x_train_norm, x_test, K)

            elapsed = _now(indices) - start
            return indices, elapsed

        return f, elapsed

    return fit
def get_knn(samples, context_cloud, n_neighbors, type='torch'):
    """Return the nearest neighbours of ``samples`` inside ``context_cloud``.

    Args:
        samples: Query points.
        context_cloud: Reference points that are searched.
        n_neighbors: Number of neighbours to return per query point.
        type: Backend name, ``'torch'`` (``KNN_torch``) or ``'KeOps'``
            (``KNN_KeOps``). The name shadows the ``type`` builtin but is
            kept because callers may pass it by keyword.

    Returns:
        The neighbour indices produced by the selected backend. Note that
        the container differs: ``KNN_torch`` yields a torch tensor while
        ``KNN_KeOps`` yields a NumPy array.

    Raises:
        ValueError: If ``type`` is not a known backend name.
    """
    if type == 'torch':
        knn_factory = KNN_torch
    elif type == 'KeOps':
        knn_factory = KNN_KeOps
    else:
        # ValueError is still an Exception, so existing handlers keep working.
        raise ValueError(
            f"Invalid knn func: {type!r} (expected 'torch' or 'KeOps')"
        )

    fit = knn_factory(n_neighbors)
    # Both stages also report their elapsed time, which is not needed here.
    query, _ = fit(context_cloud)
    index, _ = query(samples)
    return index
if __name__ == '__main__':
    # Small benchmark: time both K-NN backends on random 3-D point clouds.
    # The data is pinned to the CPU.
    device = 'cpu'

    context_cloud = torch.randn((100000, 3)).to(device).contiguous()
    context = context_cloud[:, :3].contiguous()
    x = torch.randn((1000, 3)).to(device).contiguous()
    k = 1000

    for knn_func in [KNN_torch, KNN_KeOps]:
        print(str(knn_func))
        knn = knn_func(k)

        fitted_knn, elapsed_fit = knn(context)
        print(f'Fitting: {elapsed_fit}')

        index, elapsed_query = fitted_knn(x)
        print(f'Query: {elapsed_query}')

        elapsed_total = elapsed_fit + elapsed_query
        print(f'Total: {elapsed_total}')

        # KNN_KeOps already returns a NumPy array, which has no ``.cpu()``;
        # only tensors (KNN_torch) need to be moved back to the host.
        if torch.is_tensor(index):
            index = index.cpu()

    context_cloud = context_cloud.cpu()