Skip to content

Commit

Permalink
Merge pull request #599 from MouseLand/jacob/typecast_spike_removal
Browse files Browse the repository at this point in the history
Jacob/typecast spike removal
  • Loading branch information
jacobpennington authored Mar 2, 2024
2 parents 73430b6 + bec8e6f commit 58d19d0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
4 changes: 2 additions & 2 deletions kilosort/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
np.save((results_dir / 'whitening_mat_inv.npy'), whitening_mat_inv)

# spike properties
spike_times = st[:,0] + imin # shift by minimum sample index
spike_templates = st[:,1]
spike_times = st[:,0].astype('int64') + imin # shift by minimum sample index
spike_templates = st[:,1].astype('int32')
spike_clusters = clu
xs, ys = compute_spike_positions(st, tF, ops)
spike_positions = np.vstack([xs, ys]).T
Expand Down
12 changes: 9 additions & 3 deletions kilosort/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@
import torch


@njit
@njit("(int64[:], int32[:], int32)")
def remove_duplicates(spike_times, spike_clusters, dt=15):
'''Removes same-cluster spikes that occur within `dt` samples.'''
keep = np.zeros_like(spike_times, bool_)
cluster_t0 = {}
for (i,t), c in zip(enumerate(spike_times), spike_clusters):
t0 = cluster_t0.get(c, t-dt)
for i in range(spike_times.size):
t = spike_times[i]
c = spike_clusters[i]
if c in cluster_t0:
t0 = cluster_t0[c]
else:
t0 = t - dt

if t >= (t0 + dt):
# Separate spike, reset t0 and keep spike
cluster_t0[c] = t
Expand Down

0 comments on commit 58d19d0

Please sign in to comment.