Skip to content

Commit

Permalink
Transaction on populate_all_common (#957)
Browse files Browse the repository at this point in the history
* WIP: transaction on populate_all_common

* ✅ : Seperate rollback and raise err options
  • Loading branch information
CBroz1 authored May 10, 2024
1 parent 2f6634b commit 042fd1c
Show file tree
Hide file tree
Showing 13 changed files with 441 additions and 94 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

- Create class `SpyglassGroupPart` to aid delete propagations #899
- Fix bug report template #955
- Add rollback option to `populate_all_common` #957
- Add long-distance restrictions via `<<` and `>>` operators. #943
- Fix relative pathing for `mkdocstring-python=>1.9.1`. #967, #968

Expand Down
20 changes: 17 additions & 3 deletions notebooks/01_Insert_Data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1082,8 +1082,22 @@
"- neural activity (extracellular recording of multiple brain areas)\n",
"- etc.\n",
"\n",
"_Note:_ this may take time as Spyglass creates the copy. You may see a prompt\n",
"about inserting device information.\n"
"_Notes:_ this may take time as Spyglass creates the copy. You may see a prompt\n",
"about inserting device information.\n",
"\n",
"By default, the session insert process is error permissive. It will log an\n",
"error and continue attempts across various tables. You have two options you can\n",
"toggle to adjust this.\n",
"\n",
"- `rollback_on_fail`: Default False. If True, errors will still be logged for\n",
" all tables and, if any are registered, the `Nwbfile` entry will be deleted.\n",
" This is helpful for knowing why your file failed, and making it easy to retry.\n",
"- `raise_err`: Default False. If True, errors will not be logged and will\n",
" instead be raised. This is useful for debugging and exploring the error stack.\n",
" The end result may be that some tables may still have entries from this file\n",
" that will need to be manually deleted after a failed attempt. 'transactions'\n",
" are used where possible to rollback sibling tables, but child table errors\n",
" will still leave entries from parent tables.\n"
]
},
{
Expand Down Expand Up @@ -1146,7 +1160,7 @@
}
],
"source": [
"sgi.insert_sessions(nwb_file_name)"
"sgi.insert_sessions(nwb_file_name, rollback_on_fail=False, raise_error=False)"
]
},
{
Expand Down
18 changes: 16 additions & 2 deletions notebooks/py_scripts/01_Insert_Data.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,25 @@
# - neural activity (extracellular recording of multiple brain areas)
# - etc.
#
# _Note:_ this may take time as Spyglass creates the copy. You may see a prompt
# _Notes:_ this may take time as Spyglass creates the copy. You may see a prompt
# about inserting device information.
#
# By default, the session insert process is error permissive. It will log an
# error and continue attempts across various tables. You have two options you can
# toggle to adjust this.
#
# - `rollback_on_fail`: Default False. If True, errors will still be logged for
# all tables and, if any are registered, the `Nwbfile` entry will be deleted.
# This is helpful for knowing why your file failed, and making it easy to retry.
# - `raise_err`: Default False. If True, errors will not be logged and will
# instead be raised. This is useful for debugging and exploring the error stack.
# The end result may be that some tables may still have entries from this file
# that will need to be manually deleted after a failed attempt. 'transactions'
# are used where possible to rollback sibling tables, but child table errors
# will still leave entries from parent tables.
#

sgi.insert_sessions(nwb_file_name)
sgi.insert_sessions(nwb_file_name, rollback_on_fail=False, raise_error=False)

# ## Inspecting the data
#
Expand Down
111 changes: 111 additions & 0 deletions notebooks/py_scripts/50_MUA_Detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.16.0
# kernelspec:
# display_name: spyglass
# language: python
# name: python3
# ---

# +
import datajoint as dj
from pathlib import Path

dj.config.load(
Path("../dj_local_conf.json").absolute()
) # load config for database connection info

from spyglass.mua.v1.mua import MuaEventsV1, MuaEventsParameters

# -

MuaEventsParameters()

MuaEventsV1()

# +
from spyglass.position import PositionOutput

nwb_copy_file_name = "mediumnwb20230802_.nwb"

trodes_s_key = {
"nwb_file_name": nwb_copy_file_name,
"interval_list_name": "pos 0 valid times",
"trodes_pos_params_name": "single_led_upsampled",
}

pos_merge_id = (PositionOutput.TrodesPosV1 & trodes_s_key).fetch1("merge_id")
pos_merge_id

# +
from spyglass.spikesorting.analysis.v1.group import (
SortedSpikesGroup,
)

sorted_spikes_group_key = {
"nwb_file_name": nwb_copy_file_name,
"sorted_spikes_group_name": "test_group",
"unit_filter_params_name": "default_exclusion",
}

SortedSpikesGroup & sorted_spikes_group_key

# +
mua_key = {
"mua_param_name": "default",
**sorted_spikes_group_key,
"pos_merge_id": pos_merge_id,
"detection_interval": "pos 0 valid times",
}

MuaEventsV1().populate(mua_key)
MuaEventsV1 & mua_key
# -

mua_times = (MuaEventsV1 & mua_key).fetch1_dataframe()
mua_times

# +
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 1, sharex=True, figsize=(15, 4))
speed = MuaEventsV1.get_speed(mua_key).to_numpy()
time = speed.index.to_numpy()
multiunit_firing_rate = MuaEventsV1.get_firing_rate(mua_key, time)

time_slice = slice(
np.searchsorted(time, mua_times.loc[10].start_time) - 1_000,
np.searchsorted(time, mua_times.loc[10].start_time) + 5_000,
)

axes[0].plot(
time[time_slice],
multiunit_firing_rate[time_slice],
color="black",
)
axes[0].set_ylabel("firing rate (Hz)")
axes[0].set_title("multiunit")
axes[1].fill_between(time[time_slice], speed[time_slice], color="lightgrey")
axes[1].set_ylabel("speed (cm/s)")
axes[1].set_xlabel("time (s)")

for id, mua_time in mua_times.loc[
np.logical_and(
mua_times["start_time"] > time[time_slice].min(),
mua_times["end_time"] < time[time_slice].max(),
)
].iterrows():
axes[0].axvspan(
mua_time["start_time"], mua_time["end_time"], color="red", alpha=0.5
)
# -

(MuaEventsV1 & mua_key).create_figurl(
zscore_mua=True,
)
28 changes: 20 additions & 8 deletions src/spyglass/common/common_behav.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,8 @@ class SpatialSeries(SpyglassMixin, dj.Part):
name=null: varchar(32) # name of spatial series
"""

def populate(self, keys=None):
"""Insert position source data from NWB file.
WARNING: populate method on Manual table is not protected by transaction
protections like other DataJoint tables.
"""
def _no_transaction_make(self, keys=None):
"""Insert position source data from NWB file."""
if not isinstance(keys, list):
keys = [keys]
if isinstance(keys[0], (dj.Table, dj.expression.QueryExpression)):
Expand Down Expand Up @@ -227,6 +223,12 @@ def _get_column_names(rp, pos_id):
return column_names

def make(self, key):
self._no_transaction_make(key)

def _no_transaction_make(self, key):
"""Make without transaction
Allows populate_all_common to work within a single transaction."""
nwb_file_name = key["nwb_file_name"]
interval_list_name = key["interval_list_name"]

Expand All @@ -238,7 +240,7 @@ def make(self, key):
PositionSource.get_epoch_num(interval_list_name)
]

self.insert1(key)
self.insert1(key, allow_direct_insert=True)
self.PosObject.insert(
[
dict(
Expand Down Expand Up @@ -294,6 +296,12 @@ class StateScriptFile(SpyglassMixin, dj.Imported):
_nwb_table = Nwbfile

def make(self, key):
self._no_transaction_make(key)

def _no_transaction_make(self, key):
"""Make without transaction
Allows populate_all_common to work within a single transaction."""
"""Add a new row to the StateScriptFile table."""
nwb_file_name = key["nwb_file_name"]
nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name)
Expand All @@ -309,6 +317,7 @@ def make(self, key):
)
return # See #849

script_inserts = []
for associated_file_obj in associated_files.data_interfaces.values():
if not isinstance(
associated_file_obj, ndx_franklab_novela.AssociatedFiles
Expand Down Expand Up @@ -337,10 +346,13 @@ def make(self, key):
# find the file associated with this epoch
if str(key["epoch"]) in epoch_list:
key["file_object_id"] = associated_file_obj.object_id
self.insert1(key)
script_inserts.append(key.copy())
else:
logger.info("not a statescript file")

if script_inserts:
self.insert(script_inserts, allow_direct_insert=True)


@schema
class VideoFile(SpyglassMixin, dj.Imported):
Expand Down
15 changes: 14 additions & 1 deletion src/spyglass/common/common_dio.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ class DIOEvents(SpyglassMixin, dj.Imported):
_nwb_table = Nwbfile

def make(self, key):
self._no_transaction_make(key)

def _no_transaction_make(self, key):
"""Make without transaction
Allows populate_all_common to work within a single transaction."""
nwb_file_name = key["nwb_file_name"]
nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name)
nwbf = get_nwb_file(nwb_file_abspath)
Expand All @@ -45,10 +51,17 @@ def make(self, key):
key["interval_list_name"] = (
Raw() & {"nwb_file_name": nwb_file_name}
).fetch1("interval_list_name")

dio_inserts = []
for event_series in behav_events.time_series.values():
key["dio_event_name"] = event_series.name
key["dio_object_id"] = event_series.object_id
self.insert1(key, skip_duplicates=True)
dio_inserts.append(key.copy())
self.insert(
dio_inserts,
skip_duplicates=True,
allow_direct_insert=True,
)

def plot_all_dio_events(self, return_fig=False):
"""Plot all DIO events in the session.
Expand Down
Loading

0 comments on commit 042fd1c

Please sign in to comment.