Draft compare Shijie v old #5

Draft · wants to merge 15 commits into base: old from Draft
This draft adds the following notebooks (large notebook diffs are not rendered by default):

2,488 changes: 2,488 additions & 0 deletions notebooks/00_intro.ipynb
500 changes: 500 additions & 0 deletions notebooks/03_lfp.ipynb
3,458 changes: 3,458 additions & 0 deletions notebooks/07_linearization.ipynb
3,384 changes: 3,384 additions & 0 deletions notebooks/0_intro_DPG.ipynb
5,047 changes: 5,047 additions & 0 deletions notebooks/1D_Clusterless_Decoding.ipynb
11,802 changes: 11,802 additions & 0 deletions notebooks/1_spikesorting-anagh-mike.ipynb
11,172 changes: 11,172 additions & 0 deletions notebooks/1_spikesorting-shijie.ipynb
2,247 changes: 2,247 additions & 0 deletions notebooks/4_position_info-Copy1.ipynb
3,575 changes: 3,575 additions & 0 deletions notebooks/A.DataTableIntro.ipynb
2,365 changes: 2,365 additions & 0 deletions notebooks/Gyro_speed.ipynb
3,194 changes: 3,194 additions & 0 deletions notebooks/I.DecodingPosPreprocessing.ipynb
2,834 changes: 2,834 additions & 0 deletions notebooks/II.DecodingEphysPreprocessing.ipynb
21,983 changes: 21,983 additions & 0 deletions notebooks/II.copy.ipynb
6,025 changes: 6,025 additions & 0 deletions notebooks/III.LFP.ipynb
5,533 changes: 5,533 additions & 0 deletions notebooks/IV.Ripple_Detection-batch.ipynb
7,507 changes: 7,507 additions & 0 deletions notebooks/IV.Ripple_Detection.ipynb
738 changes: 738 additions & 0 deletions notebooks/V.statescript.ipynb
9,077 changes: 9,077 additions & 0 deletions notebooks/VII.Decoding.ipynb
37,367 changes: 37,367 additions & 0 deletions notebooks/VII.Decoding_batch.ipynb
7,138 changes: 7,138 additions & 0 deletions notebooks/VII.Decoding_batch2.ipynb
366 changes: 366 additions & 0 deletions notebooks/VIII.DecodeExamples.ipynb
5,960 changes: 5,960 additions & 0 deletions notebooks/X.analysis.ipynb
34,070 changes: 34,070 additions & 0 deletions notebooks/XI.1.Replay Composition by day.ipynb
298 changes: 298 additions & 0 deletions notebooks/XI.1p.plot_replay_composition_by_day.ipynb
9,675 changes: 9,675 additions & 0 deletions notebooks/XI.2.Remote Replay Composition by day.ipynb
343 changes: 343 additions & 0 deletions notebooks/XI.2p.plot_remote_replay_composition_by_day.ipynb
12,624 changes: 12,624 additions & 0 deletions notebooks/XI.3.Replay Composition by session.ipynb
268 changes: 268 additions & 0 deletions notebooks/XI.3p.plot_replay_composition_by_session.ipynb
1,086 changes: 1,086 additions & 0 deletions notebooks/XI.Replay Composition.ipynb
4,579 changes: 4,579 additions & 0 deletions notebooks/XII.Replay pairwise comparison.ipynb
1,907 changes: 1,907 additions & 0 deletions notebooks/importData_shijie.ipynb
3,556 changes: 3,556 additions & 0 deletions notebooks/linearization_example.ipynb
1,579 changes: 1,579 additions & 0 deletions notebooks/make_1Dtrack.ipynb

2 changes: 1 addition & 1 deletion src/spyglass/common/__init__.py

@@ -44,7 +44,7 @@
     AnalysisNwbfile,
     AnalysisNwbfileKachery,
     Nwbfile,
-    NwbfileKachery,
+    #NwbfileKachery,
 )
 from spyglass.common.common_position import (
     IntervalLinearizationSelection,

5 changes: 5 additions & 0 deletions src/spyglass/common/common_ephys.py

@@ -1,3 +1,4 @@
+from dataclasses import replace
 import warnings

 import datajoint as dj

@@ -723,6 +724,10 @@ def make(self, key):

         # load in the timestamps
         timestamps = np.asarray(lfp_object.timestamps)
+
+        # leave NaN timestamps out. Patchwork by Shijie: find out in LFP processing why NaN is introduced
+        timestamps = timestamps[~np.isnan(timestamps)]
+
         # get the indices of the first timestamp and the last timestamp that are within the valid times
         included_indices = interval_list_contains_ind(
             lfp_band_valid_times, timestamps
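
For context on the patch above: comparisons with NaN are always False, and a NaN in the middle of the timestamp array breaks the sorted-order assumption that binary-search-based interval lookups (such as interval_list_contains_ind) rely on. A minimal NumPy sketch with hypothetical values, not code from this PR:

import numpy as np

# Hypothetical LFP timestamps where processing introduced a NaN mid-stream.
timestamps = np.array([0.000, 0.001, np.nan, 0.002, 0.003])

# Any comparison with NaN is False, so NaN samples silently drop out of masks.
print(timestamps >= 0.001)  # [False  True False  True  True]

# An interspersed NaN violates sorted order, so binary search misplaces values:
print(np.searchsorted(timestamps, 0.002))  # 2, but 0.002 sits at index 3

# Dropping NaNs up front, as the patch does, restores a clean sorted time axis.
clean = timestamps[~np.isnan(timestamps)]
print(np.searchsorted(clean, 0.002))  # 2, correct for the cleaned array
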
32 changes: 16 additions & 16 deletions src/spyglass/common/common_nwbfile.py

@@ -642,22 +642,22 @@ def nightly_cleanup():
     # also check to see whether there are directories in the spikesorting folder with this


-@schema
-class NwbfileKachery(SpyglassMixin, dj.Computed):
-    definition = """
-    -> Nwbfile
-    ---
-    nwb_file_uri: varchar(200) # the uri the NWB file for kachery
-    """
-
-    def make(self, key):
-        import kachery_client as kc
-
-        logger.info(f'Linking {key["nwb_file_name"]} and storing in kachery...')
-        key["nwb_file_uri"] = kc.link_file(
-            Nwbfile().get_abs_path(key["nwb_file_name"])
-        )
-        self.insert1(key)
+# @schema
+# class NwbfileKachery(SpyglassMixin, dj.Computed):
+#     definition = """
+#     -> Nwbfile
+#     ---
+#     nwb_file_uri: varchar(200) # the uri the NWB file for kachery
+#     """
+
+#     def make(self, key):
+#         import kachery_client as kc
+
+#         logger.info(f'Linking {key["nwb_file_name"]} and storing in kachery...')
+#         key["nwb_file_uri"] = kc.link_file(
+#             Nwbfile().get_abs_path(key["nwb_file_name"])
+#         )
+#         self.insert1(key)


 @schema

41 changes: 11 additions & 30 deletions src/spyglass/decoding/v0/clusterless.py

@@ -1,6 +1,5 @@
 """Pipeline for decoding the animal's mental position and some category of interest
 from unclustered spikes and spike waveform features. See [1] for details.
-
 References
 ----------
 [1] Denovellis, E. L. et al. Hippocampal replay of experience at real-world

@@ -21,7 +20,7 @@
 import spikeinterface as si
 import xarray as xr

-from spyglass.settings import waveform_dir
+from spyglass.settings import waveforms_dir
 from spyglass.utils import logger

 try:

@@ -84,7 +83,6 @@ class MarkParameters(SpyglassMixin, dj.Manual):

     def insert_default(self):
         """Insert the default parameter set
-
         Examples
         --------
         {'peak_sign': 'neg', 'threshold' : 100}

@@ -101,11 +99,9 @@ def supported_mark_type(mark_type):
         """Checks whether the requested mark type is supported.

         Currently only 'amplitude" is supported.
-
         Parameters
         ----------
         mark_type : str
-
         """
         supported_types = ["amplitude"]
         return mark_type in supported_types

@@ -157,7 +153,7 @@ def make(self, key):
             f'{key["curation_id"]}_clusterless_waveforms'
         )
         waveform_extractor_path = str(
-            Path(waveform_dir) / Path(waveform_extractor_name)
+            Path(waveforms_dir) / Path(waveform_extractor_name)
         )
         if os.path.exists(waveform_extractor_path):
             shutil.rmtree(waveform_extractor_path)

@@ -243,15 +239,9 @@ def _convert_to_dataframe(nwb_data) -> pd.DataFrame:
         )

     @staticmethod
-    def _get_peak_amplitude(
-        waveform: np.array,
-        peak_sign: str = "neg",
-        estimate_peak_time: bool = False,
-    ) -> np.array:
-        """Returns the amplitudes of all channels at the time of the peak.
-
-        Amplitude across channels.
-
+    def _get_peak_amplitude(waveform, peak_sign="neg", estimate_peak_time=False):
+        """Returns the amplitudes of all channels at the time of the peak
+        amplitude across channels.
         Parameters
         ----------
         waveform : np.array

@@ -261,12 +251,9 @@ def _get_peak_amplitude(
         estimate_peak_time : bool, optional
             Find the peak times for each spike because some spikesorters do not
             align the spike time (at index n_time // 2) to the peak
-
         Returns
         -------
-        peak_amplitudes : np.array
-            array-like, shape (n_spikes, n_channels)
-
+        peak_amplitudes : array-like, shape (n_spikes, n_channels)
         """
         if estimate_peak_time:
             if peak_sign == "neg":

@@ -289,22 +276,17 @@ def _threshold(
         timestamps: np.array, marks: np.array, mark_param_dict: dict
     ):
         """Filter the marks by an amplitude threshold
-
         Parameters
         ----------
         timestamps : np.array
             array-like, shape (n_time,)
         marks : np.array
             array-like, shape (n_time, n_channels)
         mark_param_dict : dict
-
         Returns
         -------
-        filtered_timestamps : np.array
-            array-like, shape (n_filtered_time,)
-        filtered_marks : np.array
-            array-like, shape (n_filtered_time, n_channels)
-
+        filtered_timestamps : array-like, shape (n_filtered_time,)
+        filtered_marks : array-like, shape (n_filtered_time, n_channels)
         """
         if mark_param_dict["peak_sign"] == "neg":
             include = np.min(marks, axis=1) <= -1 * mark_param_dict["threshold"]

@@ -409,7 +391,6 @@ def plot_all_marks(

     Plots 2D slices of each of the spike features against each other
     for all electrodes.
-
     Parameters
     ----------
     marks_indicators : xr.DataArray, shape (n_time, n_electrodes, n_features)

@@ -628,7 +609,7 @@ def get_decoding_data_for_epoch(
     interval_list_name: str,
     position_info_param_name: str = "default_decoding",
     additional_mark_keys: dict = {},
-) -> tuple[pd.DataFrame, xr.DataArray, list[slice]]:
+):
     """Collects necessary data for decoding.

     Parameters

@@ -694,10 +675,10 @@

 def get_data_for_multiple_epochs(
     nwb_file_name: str,
-    epoch_names: list[str],
+    epoch_names: list,
     position_info_param_name="default_decoding",
     additional_mark_keys: dict = {},
-) -> tuple[pd.DataFrame, xr.DataArray, dict[str, list[slice]], np.ndarray]:
+):
     """Collects necessary data for decoding multiple environments

     Parameters
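
The thresholding rule visible in _threshold above is easy to sanity-check in isolation. A self-contained sketch with made-up amplitudes: for peak_sign='neg', a spike is kept when its most negative channel reaches -threshold.

import numpy as np

# Made-up marks: 5 spikes x 4 channels of peak amplitudes (arbitrary units).
timestamps = np.array([0.01, 0.02, 0.03, 0.04, 0.05])
marks = np.array([
    [-120.0,  -80.0,  -40.0,  -10.0],  # crosses -100 on channel 0 -> keep
    [ -60.0,  -90.0,  -30.0,  -20.0],  # never reaches -100 -> drop
    [ -20.0, -110.0,  -70.0,  -50.0],  # crosses -100 on channel 1 -> keep
    [ -10.0,   -5.0,   -8.0,   -2.0],  # noise -> drop
    [-300.0, -250.0, -200.0, -150.0],  # large spike -> keep
])
mark_param_dict = {"peak_sign": "neg", "threshold": 100}

# Same rule as _threshold applies for peak_sign == "neg":
include = np.min(marks, axis=1) <= -1 * mark_param_dict["threshold"]

filtered_timestamps = timestamps[include]
filtered_marks = marks[include]
print(filtered_timestamps)  # [0.01 0.03 0.05]
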
8 changes: 4 additions & 4 deletions src/spyglass/decoding/v0/core.py

@@ -85,7 +85,7 @@ def get_valid_ephys_position_times_from_interval(
     )


-def get_epoch_interval_names(nwb_file_name: str) -> list[str]:
+def get_epoch_interval_names(nwb_file_name: str) -> list:
     """Find the interval names that are epochs.

     Parameters

@@ -112,7 +112,7 @@ def get_epoch_interval_names(nwb_file_name: str) -> list:

 def get_valid_ephys_position_times_by_epoch(
     nwb_file_name: str,
-) -> dict[str, np.ndarray]:
+) -> dict:
     """Get the valid ephys position times for each epoch.

     Parameters

@@ -150,8 +150,8 @@ def convert_valid_times_to_slice(valid_times: np.ndarray) -> list[slice]:


 def create_model_for_multiple_epochs(
-    epoch_names: list[str], env_kwargs: dict
-) -> tuple[list[ObservationModel], list[Environment], list[list[object]]]:
+    epoch_names: list, env_kwargs: dict
+):
     """Creates the observation model, environment, and continuous transition types for multiple epochs for decoding

     Parameters
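
Here and in clusterless.py and sorted_spikes.py, parameterized built-in annotations (list[str], dict[str, np.ndarray], tuple[...]) are replaced with bare list, dict, and unannotated returns. A plausible reason, stated as an assumption since the PR does not say: subscripting built-in types at definition time raises TypeError on Python < 3.9. A sketch of the failure and two standard workarounds that would keep the richer annotations:

# On Python 3.8, merely defining this function raises:
#   TypeError: 'type' object is not subscriptable
# def get_epoch_interval_names(nwb_file_name: str) -> list[str]: ...

# Workaround 1: postpone annotation evaluation (PEP 563).
from __future__ import annotations

# Workaround 2: use typing generics, which work on older interpreters.
from typing import List

def get_epoch_interval_names(nwb_file_name: str) -> List[str]:
    """Hypothetical stand-in body; the real function queries interval names."""
    return []
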
6 changes: 2 additions & 4 deletions src/spyglass/decoding/v0/dj_decoder_conversion.py

@@ -74,9 +74,7 @@ def _to_dict(transition: object) -> dict:
     return parameters


-def _convert_transitions_to_dict(
-    transitions: list[list[object]],
-) -> list[list[dict]]:
+def _convert_transitions_to_dict(transitions: list) -> list:
     """Converts a list of lists of transition classes into a list of lists of dictionaries"""
     return [
         [_to_dict(transition) for transition in transition_rows]

@@ -203,4 +201,4 @@ def convert_classes_to_dict(key: dict) -> dict:
     except KeyError:
         pass

-    return key
+    return key

6 changes: 3 additions & 3 deletions src/spyglass/decoding/v0/sorted_spikes.py

@@ -262,7 +262,7 @@ def fetch1(self, *args, **kwargs):


 def get_spike_indicator(
-    key: dict, time_range: tuple[float, float], sampling_rate: float = 500.0
+    key: dict, time_range: tuple, sampling_rate: float = 500.0
 ) -> pd.DataFrame:
     """For a given key, returns a dataframe with the spike indicator for each unit

@@ -313,7 +313,7 @@ def get_decoding_data_for_epoch(
     interval_list_name: str,
     position_info_param_name: str = "default",
     additional_spike_keys: dict = {},
-) -> tuple[pd.DataFrame, pd.DataFrame, list[slice]]:
+):
     """Collects the data needed for decoding

     Parameters

@@ -393,7 +393,7 @@ def get_data_for_multiple_epochs(
     epoch_names: list,
     position_info_param_name: str = "decoding",
     additional_spike_keys: dict = {},
-) -> tuple[pd.DataFrame, pd.DataFrame, list[slice], np.ndarray, np.ndarray]:
+):
     """Collects the data needed for decoding for multiple epochs

     Parameters
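
For readers unfamiliar with get_spike_indicator's output: a sketch of one common way to build a per-unit spike indicator on a regular time grid, assuming the indicator is a binned spike count. Made-up spike times stand in for the real database fetch:

import numpy as np
import pandas as pd

# Hypothetical spike times (seconds) for two units; not the real fetch logic.
spike_times = {
    "unit_0": np.array([0.003, 0.011, 0.012]),
    "unit_1": np.array([0.007, 0.015]),
}
time_range = (0.0, 0.02)
sampling_rate = 500.0

# Regular time grid at the requested sampling rate.
n_samples = int((time_range[1] - time_range[0]) * sampling_rate) + 1
time = np.linspace(time_range[0], time_range[1], n_samples)

# Count spikes per time bin for each unit (bin edges midway between samples).
bin_edges = np.concatenate([[time[0]], (time[1:] + time[:-1]) / 2, [time[-1]]])
indicator = pd.DataFrame(
    {unit: np.histogram(times, bins=bin_edges)[0]
     for unit, times in spike_times.items()},
    index=pd.Index(time, name="time"),
)
print(indicator.sum())  # unit_0: 3 spikes, unit_1: 2 spikes
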
8 changes: 2 additions & 6 deletions src/spyglass/settings.py

@@ -503,10 +503,6 @@ def waveforms_dir(self) -> str:
     def temp_dir(self) -> str:
         return self.config.get(self.dir_to_var("temp"))

-    @property
-    def waveform_dir(self) -> str:
-        return self.config.get(self.dir_to_var("waveform"))
-
     @property
     def video_dir(self) -> str:
         return self.config.get(self.dir_to_var("video"))

@@ -553,7 +549,7 @@ def dlc_output_dir(self) -> str:
 temp_dir = None
 analysis_dir = None
 sorting_dir = None
-waveform_dir = None
+waveforms_dir = None
 video_dir = None
 dlc_project_dir = None
 dlc_video_dir = None

@@ -566,7 +562,7 @@
 temp_dir = sg_config.temp_dir
 analysis_dir = sg_config.analysis_dir
 sorting_dir = sg_config.sorting_dir
-waveform_dir = sg_config.waveform_dir
+waveforms_dir = sg_config.waveforms_dir
 video_dir = sg_config.video_dir
 debug_mode = sg_config.debug_mode
 test_mode = sg_config.test_mode
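
The rename settles on the existing waveforms_dir property and drops the duplicate waveform_dir. If external code still imports the old name, a module-level deprecation alias could ease migration; a sketch of that pattern, not something this PR adds:

import warnings

waveforms_dir = "/path/to/waveforms"  # placeholder; the real value comes from sg_config

def __getattr__(name):  # PEP 562: module-level attribute hook (Python >= 3.7)
    if name == "waveform_dir":
        warnings.warn(
            "waveform_dir is deprecated; use waveforms_dir",
            DeprecationWarning,
            stacklevel=2,
        )
        return waveforms_dir
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
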
1 change: 1 addition & 0 deletions src/spyglass/spikesorting/v0/spikesorting_artifact.py

@@ -299,6 +299,7 @@ def _get_artifact_times(
         [(0, len(valid_timestamps) - 1)], artifact_intervals_new
     )

+
     # convert back to seconds
     artifact_removed_valid_times = []
     for i in artifact_removed_valid_times_ind:

4 changes: 2 additions & 2 deletions src/spyglass/spikesorting/v0/spikesorting_curation.py

@@ -15,7 +15,7 @@

 from spyglass.common.common_interval import IntervalList
 from spyglass.common.common_nwbfile import AnalysisNwbfile
-from spyglass.settings import waveform_dir
+from spyglass.settings import waveforms_dir
 from spyglass.spikesorting.v0.merged_sorting_extractor import (
     MergedSortingExtractor,
 )

@@ -331,7 +331,7 @@ def make(self, key):

         waveform_extractor_name = self._get_waveform_extractor_name(key)
         key["waveform_extractor_path"] = str(
-            Path(waveform_dir) / Path(waveform_extractor_name)
+            Path(waveforms_dir) / Path(waveform_extractor_name)
         )
         if os.path.exists(key["waveform_extractor_path"]):
             shutil.rmtree(key["waveform_extractor_path"])

4 changes: 3 additions & 1 deletion src/spyglass/spikesorting/v0/spikesorting_sorting.py

@@ -203,7 +203,9 @@ def make(self, key: dict):
         # need to remove tempdir and whiten from sorter_params
         sorter_params.pop("tempdir", None)
         sorter_params.pop("whiten", None)
-        sorter_params.pop("outputs", None)
+        sorter_params.pop("n_shifts", None)
+        sorter_params.pop("outputs", None)
+        sorter_params.pop("localization_dict", None)

         # Detect peaks for clusterless decoding
         detected_spikes = detect_peaks(recording, **sorter_params)
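
The pops strip sorter-only keys from sorter_params before the remainder is forwarded to detect_peaks as keyword arguments; a key the callee does not accept would raise TypeError. A minimal sketch of the pattern with a stand-in function (the real spikeinterface detect_peaks has a different signature):

# Hypothetical sorter parameters mixing sorter-only and peak-detection keys.
sorter_params = {
    "detect_threshold": 5.0,
    "n_shifts": 2,              # sorter-only: unknown to detect_peaks
    "localization_dict": {},    # sorter-only: unknown to detect_peaks
}

def detect_peaks(recording, detect_threshold=5.0):
    """Stand-in for peak detection; illustrates the keyword-argument contract."""
    return f"detecting peaks on {recording} at threshold {detect_threshold}"

# Without the pops, **sorter_params would raise:
#   TypeError: detect_peaks() got an unexpected keyword argument 'n_shifts'
for unsupported in ("tempdir", "whiten", "n_shifts", "outputs", "localization_dict"):
    sorter_params.pop(unsupported, None)  # default avoids KeyError if absent

print(detect_peaks("my_recording", **sorter_params))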