Skip to content

Commit

Permalink
Improve memory management in clustering_qr.kmeans_plusplus
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertoDF committed Sep 4, 2024
1 parent b2f5ded commit 5169a39
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions kilosort/clustering_qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,13 @@ def cluster(Xd, iclust = None, kn = None, nskip = 20, n_neigh = 10, nclust = 200

def kmeans_plusplus(Xg, niter = 200, seed = 1, device=torch.device('cuda')):
#Xg = torch.from_numpy(Xd).to(dev)
vtot = (Xg**2).sum(1)
Xg_squared = Xg ** 2

vtot = Xg_squared.sum(1)

del Xg_squared
gc.collect()
torch.cuda.empty_cache()

n1 = vtot.shape[0]
if n1 > 2**24:
Expand Down Expand Up @@ -199,7 +205,10 @@ def kmeans_plusplus(Xg, niter = 200, seed = 1, device=torch.device('cuda')):
isamp = torch.multinomial(v2, ntry)

Xc = Xg[isamp]
vexp = 2 * Xg @ Xc.T - (Xc**2).sum(1)
Xc_squared_sum = (Xc ** 2).sum(1)
vexp = Xg @ Xc.T
vexp.mul_(2)
vexp = vexp - Xc_squared_sum

dexp = vexp - vexp0.unsqueeze(1)
dexp = torch.relu(dexp)
Expand Down

0 comments on commit 5169a39

Please sign in to comment.