diff --git a/CHANGELOG.md b/CHANGELOG.md index 231e328d6..bf8804795 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ - Create class `SpyglassGroupPart` to aid delete propagations #899 - Fix bug report template #955 +- Add rollback option to `populate_all_common` #957 - Add long-distance restrictions via `<<` and `>>` operators. #943 - Fix relative pathing for `mkdocstring-python=>1.9.1`. #967, #968 diff --git a/notebooks/01_Insert_Data.ipynb b/notebooks/01_Insert_Data.ipynb index e68623e72..2a2297642 100644 --- a/notebooks/01_Insert_Data.ipynb +++ b/notebooks/01_Insert_Data.ipynb @@ -1082,8 +1082,22 @@ "- neural activity (extracellular recording of multiple brain areas)\n", "- etc.\n", "\n", - "_Note:_ this may take time as Spyglass creates the copy. You may see a prompt\n", - "about inserting device information.\n" + "_Notes:_ this may take time as Spyglass creates the copy. You may see a prompt\n", + "about inserting device information.\n", + "\n", + "By default, the session insert process is error permissive. It will log an\n", + "error and continue attempts across various tables. You have two options you can\n", + "toggle to adjust this.\n", + "\n", + "- `rollback_on_fail`: Default False. If True, errors will still be logged for\n", + " all tables and, if any are registered, the `Nwbfile` entry will be deleted.\n", + " This is helpful for knowing why your file failed, and making it easy to retry.\n", + "- `raise_err`: Default False. If True, errors will not be logged and will\n", + " instead be raised. This is useful for debugging and exploring the error stack.\n", + " The end result may be that some tables may still have entries from this file\n", + " that will need to be manually deleted after a failed attempt. 
'transactions'\n", + " are used where possible to rollback sibling tables, but child table errors\n", + " will still leave entries from parent tables.\n" ] }, { @@ -1146,7 +1160,7 @@ } ], "source": [ - "sgi.insert_sessions(nwb_file_name)" + "sgi.insert_sessions(nwb_file_name, rollback_on_fail=False, raise_error=False)" ] }, { diff --git a/notebooks/py_scripts/01_Insert_Data.py b/notebooks/py_scripts/01_Insert_Data.py index 975ed4ac5..870c6907a 100644 --- a/notebooks/py_scripts/01_Insert_Data.py +++ b/notebooks/py_scripts/01_Insert_Data.py @@ -198,11 +198,25 @@ # - neural activity (extracellular recording of multiple brain areas) # - etc. # -# _Note:_ this may take time as Spyglass creates the copy. You may see a prompt +# _Notes:_ this may take time as Spyglass creates the copy. You may see a prompt # about inserting device information. # +# By default, the session insert process is error permissive. It will log an +# error and continue attempts across various tables. You have two options you can +# toggle to adjust this. +# +# - `rollback_on_fail`: Default False. If True, errors will still be logged for +# all tables and, if any are registered, the `Nwbfile` entry will be deleted. +# This is helpful for knowing why your file failed, and making it easy to retry. +# - `raise_err`: Default False. If True, errors will not be logged and will +# instead be raised. This is useful for debugging and exploring the error stack. +# The end result may be that some tables may still have entries from this file +# that will need to be manually deleted after a failed attempt. 'transactions' +# are used where possible to rollback sibling tables, but child table errors +# will still leave entries from parent tables. 
+# -sgi.insert_sessions(nwb_file_name) +sgi.insert_sessions(nwb_file_name, rollback_on_fail=False, raise_error=False) # ## Inspecting the data # diff --git a/notebooks/py_scripts/50_MUA_Detection.py b/notebooks/py_scripts/50_MUA_Detection.py new file mode 100644 index 000000000..bc319ff82 --- /dev/null +++ b/notebooks/py_scripts/50_MUA_Detection.py @@ -0,0 +1,111 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.0 +# kernelspec: +# display_name: spyglass +# language: python +# name: python3 +# --- + +# + +import datajoint as dj +from pathlib import Path + +dj.config.load( + Path("../dj_local_conf.json").absolute() +) # load config for database connection info + +from spyglass.mua.v1.mua import MuaEventsV1, MuaEventsParameters + +# - + +MuaEventsParameters() + +MuaEventsV1() + +# + +from spyglass.position import PositionOutput + +nwb_copy_file_name = "mediumnwb20230802_.nwb" + +trodes_s_key = { + "nwb_file_name": nwb_copy_file_name, + "interval_list_name": "pos 0 valid times", + "trodes_pos_params_name": "single_led_upsampled", +} + +pos_merge_id = (PositionOutput.TrodesPosV1 & trodes_s_key).fetch1("merge_id") +pos_merge_id + +# + +from spyglass.spikesorting.analysis.v1.group import ( + SortedSpikesGroup, +) + +sorted_spikes_group_key = { + "nwb_file_name": nwb_copy_file_name, + "sorted_spikes_group_name": "test_group", + "unit_filter_params_name": "default_exclusion", +} + +SortedSpikesGroup & sorted_spikes_group_key + +# + +mua_key = { + "mua_param_name": "default", + **sorted_spikes_group_key, + "pos_merge_id": pos_merge_id, + "detection_interval": "pos 0 valid times", +} + +MuaEventsV1().populate(mua_key) +MuaEventsV1 & mua_key +# - + +mua_times = (MuaEventsV1 & mua_key).fetch1_dataframe() +mua_times + +# + +import matplotlib.pyplot as plt +import numpy as np + +fig, axes = plt.subplots(2, 1, sharex=True, figsize=(15, 4)) +speed = 
MuaEventsV1.get_speed(mua_key).to_numpy() +time = speed.index.to_numpy() +multiunit_firing_rate = MuaEventsV1.get_firing_rate(mua_key, time) + +time_slice = slice( + np.searchsorted(time, mua_times.loc[10].start_time) - 1_000, + np.searchsorted(time, mua_times.loc[10].start_time) + 5_000, +) + +axes[0].plot( + time[time_slice], + multiunit_firing_rate[time_slice], + color="black", +) +axes[0].set_ylabel("firing rate (Hz)") +axes[0].set_title("multiunit") +axes[1].fill_between(time[time_slice], speed[time_slice], color="lightgrey") +axes[1].set_ylabel("speed (cm/s)") +axes[1].set_xlabel("time (s)") + +for id, mua_time in mua_times.loc[ + np.logical_and( + mua_times["start_time"] > time[time_slice].min(), + mua_times["end_time"] < time[time_slice].max(), + ) +].iterrows(): + axes[0].axvspan( + mua_time["start_time"], mua_time["end_time"], color="red", alpha=0.5 + ) +# - + +(MuaEventsV1 & mua_key).create_figurl( + zscore_mua=True, +) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index bdb769e73..b7e8d953b 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -43,12 +43,8 @@ class SpatialSeries(SpyglassMixin, dj.Part): name=null: varchar(32) # name of spatial series """ - def populate(self, keys=None): - """Insert position source data from NWB file. - - WARNING: populate method on Manual table is not protected by transaction - protections like other DataJoint tables. 
- """ + def _no_transaction_make(self, keys=None): + """Insert position source data from NWB file.""" if not isinstance(keys, list): keys = [keys] if isinstance(keys[0], (dj.Table, dj.expression.QueryExpression)): @@ -227,6 +223,12 @@ def _get_column_names(rp, pos_id): return column_names def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" nwb_file_name = key["nwb_file_name"] interval_list_name = key["interval_list_name"] @@ -238,7 +240,7 @@ def make(self, key): PositionSource.get_epoch_num(interval_list_name) ] - self.insert1(key) + self.insert1(key, allow_direct_insert=True) self.PosObject.insert( [ dict( @@ -294,6 +296,12 @@ class StateScriptFile(SpyglassMixin, dj.Imported): _nwb_table = Nwbfile def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" """Add a new row to the StateScriptFile table.""" nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) @@ -309,6 +317,7 @@ def make(self, key): ) return # See #849 + script_inserts = [] for associated_file_obj in associated_files.data_interfaces.values(): if not isinstance( associated_file_obj, ndx_franklab_novela.AssociatedFiles @@ -337,10 +346,13 @@ def make(self, key): # find the file associated with this epoch if str(key["epoch"]) in epoch_list: key["file_object_id"] = associated_file_obj.object_id - self.insert1(key) + script_inserts.append(key.copy()) else: logger.info("not a statescript file") + if script_inserts: + self.insert(script_inserts, allow_direct_insert=True) + @schema class VideoFile(SpyglassMixin, dj.Imported): diff --git a/src/spyglass/common/common_dio.py b/src/spyglass/common/common_dio.py index 3db854e6a..629adef47 100644 --- a/src/spyglass/common/common_dio.py +++ 
b/src/spyglass/common/common_dio.py @@ -27,6 +27,12 @@ class DIOEvents(SpyglassMixin, dj.Imported): _nwb_table = Nwbfile def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) @@ -45,10 +51,17 @@ def make(self, key): key["interval_list_name"] = ( Raw() & {"nwb_file_name": nwb_file_name} ).fetch1("interval_list_name") + + dio_inserts = [] for event_series in behav_events.time_series.values(): key["dio_event_name"] = event_series.name key["dio_object_id"] = event_series.object_id - self.insert1(key, skip_duplicates=True) + dio_inserts.append(key.copy()) + self.insert( + dio_inserts, + skip_duplicates=True, + allow_direct_insert=True, + ) def plot_all_dio_events(self, return_fig=False): """Plot all DIO events in the session. 
diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py index 1880340a9..d03f6edff 100644 --- a/src/spyglass/common/common_ephys.py +++ b/src/spyglass/common/common_ephys.py @@ -45,6 +45,12 @@ class ElectrodeGroup(SpyglassMixin, dj.Imported): """ def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) @@ -69,7 +75,7 @@ def make(self, key): else: # if negative x coordinate # define target location as left hemisphere key["target_hemisphere"] = "Left" - self.insert1(key, skip_duplicates=True) + self.insert1(key, skip_duplicates=True, allow_direct_insert=True) @schema @@ -95,6 +101,12 @@ class Electrode(SpyglassMixin, dj.Imported): """ def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) @@ -108,23 +120,32 @@ def make(self, key): else: electrode_config_dicts = dict() + electrode_constants = { + "x_warped": 0, + "y_warped": 0, + "z_warped": 0, + "contacts": "", + } + + electrode_inserts = [] electrodes = nwbf.electrodes.to_dataframe() for elect_id, elect_data in electrodes.iterrows(): - key["electrode_id"] = elect_id - key["name"] = str(elect_id) - key["electrode_group_name"] = elect_data.group_name - key["region_id"] = BrainRegion.fetch_add( - region_name=elect_data.group.location + key.update( + { + "electrode_id": elect_id, + "name": str(elect_id), + "electrode_group_name": elect_data.group_name, + "region_id": BrainRegion.fetch_add( + region_name=elect_data.group.location + ), + "x": elect_data.x, + 
"y": elect_data.y, + "z": elect_data.z, + "filtering": elect_data.filtering, + "impedance": elect_data.get("imp"), + **electrode_constants, + } ) - key["x"] = elect_data.x - key["y"] = elect_data.y - key["z"] = elect_data.z - key["x_warped"] = 0 - key["y_warped"] = 0 - key["z_warped"] = 0 - key["contacts"] = "" - key["filtering"] = elect_data.filtering - key["impedance"] = elect_data.get("imp") # rough check of whether the electrodes table was created by # rec_to_nwb and has the appropriate custom columns used by @@ -140,13 +161,17 @@ and "bad_channel" in elect_data and "ref_elect_id" in elect_data ): - key["probe_id"] = elect_data.group.device.probe_type - key["probe_shank"] = elect_data.probe_shank - key["probe_electrode"] = elect_data.probe_electrode - key["bad_channel"] = ( - "True" if elect_data.bad_channel else "False" + key.update( + { + "probe_id": elect_data.group.device.probe_type, + "probe_shank": elect_data.probe_shank, + "probe_electrode": elect_data.probe_electrode, + "bad_channel": ( + "True" if elect_data.bad_channel else "False" + ), + "original_reference_electrode": elect_data.ref_elect_id, + } ) - key["original_reference_electrode"] = elect_data.ref_elect_id # override with information from the config YAML based on primary # key (electrode id) @@ -163,8 +188,13 @@ ) else: key.update(electrode_config_dicts[elect_id]) + electrode_inserts.append(key.copy()) - self.insert1(key, skip_duplicates=True) + self.insert( + electrode_inserts, + skip_duplicates=True, + allow_direct_insert=True, # for no_transaction, pop_all_common + ) @classmethod def create_from_config(cls, nwb_file_name: str): @@ -246,10 +276,17 @@ class Raw(SpyglassMixin, dj.Imported): _nwb_table = Nwbfile def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" nwb_file_name = key["nwb_file_name"] nwb_file_abspath = 
Nwbfile.get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) raw_interval_name = "raw data valid times" + # get the acquisition object try: # TODO this assumes there is a single item in NWBFile.acquisition @@ -261,19 +298,21 @@ def make(self, key): + f"Skipping entry in {self.full_table_name}" ) return + if rawdata.rate is not None: - sampling_rate = rawdata.rate + key["sampling_rate"] = rawdata.rate else: logger.info("Estimating sampling rate...") # NOTE: Only use first 1e6 timepoints to save time - sampling_rate = estimate_sampling_rate( + key["sampling_rate"] = estimate_sampling_rate( np.asarray(rawdata.timestamps[: int(1e6)]), 1.5, verbose=True ) - key["sampling_rate"] = sampling_rate - interval_dict = dict() - interval_dict["nwb_file_name"] = key["nwb_file_name"] - interval_dict["interval_list_name"] = raw_interval_name + interval_dict = { + "nwb_file_name": key["nwb_file_name"], + "interval_list_name": raw_interval_name, + } + if rawdata.rate is not None: interval_dict["valid_times"] = np.array( [[0, len(rawdata.data) / rawdata.rate]] @@ -291,18 +330,25 @@ def make(self, key): # now insert each of the electrodes as an individual row, but with the # same nwb_object_id - key["raw_object_id"] = rawdata.object_id - key["sampling_rate"] = sampling_rate logger.info( - f'Importing raw data: Sampling rate:\t{key["sampling_rate"]} Hz' + f'Importing raw data: Sampling rate:\t{key["sampling_rate"]} Hz\n' + + f'Number of valid intervals:\t{len(interval_dict["valid_times"])}' ) - logger.info( - f'Number of valid intervals:\t{len(interval_dict["valid_times"])}' + + key.update( + { + "raw_object_id": rawdata.object_id, + "interval_list_name": raw_interval_name, + "comments": rawdata.comments, + "description": rawdata.description, + } + ) + + self.insert1( + key, + skip_duplicates=True, + allow_direct_insert=True, ) - key["interval_list_name"] = raw_interval_name - key["comments"] = rawdata.comments - key["description"] = rawdata.description - self.insert1(key, 
skip_duplicates=True) def nwb_object(self, key): # TODO return the nwb_object; FIX: this should be replaced with a fetch @@ -330,6 +376,12 @@ class SampleCount(SpyglassMixin, dj.Imported): _nwb_table = Nwbfile def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) @@ -343,7 +395,7 @@ def make(self, key): ) return # see #849 key["sample_count_object_id"] = sample_count.object_id - self.insert1(key) + self.insert1(key, allow_direct_insert=True) @schema diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index 19700d3b3..d5bba9e51 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -65,6 +65,7 @@ def insert_from_relative_file_name(cls, nwb_file_name): The relative path to the NWB file. 
""" nwb_file_abs_path = Nwbfile.get_abs_path(nwb_file_name, new_file=True) + assert os.path.exists( nwb_file_abs_path ), f"File does not exist: {nwb_file_abs_path}" diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py index acb4a0826..e97934122 100644 --- a/src/spyglass/common/common_session.py +++ b/src/spyglass/common/common_session.py @@ -52,6 +52,12 @@ class Experimenter(SpyglassMixin, dj.Part): """ def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" # These imports must go here to avoid cyclic dependencies # from .common_task import Task, TaskEpoch from .common_interval import IntervalList @@ -114,6 +120,7 @@ def make(self, key): "experiment_description": nwbf.experiment_description, }, skip_duplicates=True, + allow_direct_insert=True, # for populate_all_common ) logger.info("Skipping Apparatus for now...") diff --git a/src/spyglass/common/common_task.py b/src/spyglass/common/common_task.py index 0dffa4ac5..49fd7bb0e 100644 --- a/src/spyglass/common/common_task.py +++ b/src/spyglass/common/common_task.py @@ -97,6 +97,12 @@ class TaskEpoch(SpyglassMixin, dj.Imported): """ def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) @@ -120,6 +126,7 @@ def make(self, key): logger.warn(f"No tasks processing module found in {nwbf}\n") return + task_inserts = [] for task in tasks_mod.data_interfaces.values(): if self.check_task_table(task): # check if the task is in the Task table and if not, add it @@ -169,7 +176,8 @@ def make(self, key): break # TODO case when interval is not found is not handled key["interval_list_name"] 
= interval - self.insert1(key) + task_inserts.append(key.copy()) + self.insert(task_inserts, allow_direct_insert=True) @classmethod def update_entries(cls, restrict={}): diff --git a/src/spyglass/common/populate_all_common.py b/src/spyglass/common/populate_all_common.py index 2972ed145..04df52dec 100644 --- a/src/spyglass/common/populate_all_common.py +++ b/src/spyglass/common/populate_all_common.py @@ -1,3 +1,5 @@ +from typing import List, Union + import datajoint as dj from spyglass.common.common_behav import ( @@ -20,54 +22,147 @@ from spyglass.utils import logger -def populate_all_common(nwb_file_name): - """Insert all common tables for a given NWB file.""" +def log_insert_error( + table: str, err: Exception, error_constants: dict = None +) -> None: + """Log a given error to the InsertError table. + + Parameters + ---------- + table : str + The table name where the error occurred. + err : Exception + The exception that was raised. + error_constants : dict, optional + Dictionary with keys for dj_user, connection_id, and nwb_file_name. + Defaults to checking dj.conn and using "Unknown" for nwb_file_name. + """ + if error_constants is None: + error_constants = dict( + dj_user=dj.config["database.user"], + connection_id=dj.conn().connection_id, + nwb_file_name="Unknown", + ) + InsertError.insert1( + dict( + **error_constants, + table=table.__name__, + error_type=type(err).__name__, + error_message=str(err), + error_raw=str(err), + ) + ) + + +def single_transaction_make( + tables: List[dj.Table], + nwb_file_name: str, + raise_err: bool = False, + error_constants: dict = None, +): + """For each table, run the `_no_transaction_make` method. + + Requires `allow_direct_insert` set to True within each method. Uses + nwb_file_name search table key_source for relevant key. Currently assumes + all tables will have exactly one key_source entry per nwb file. 
+ """ + file_restr = {"nwb_file_name": nwb_file_name} + with Nwbfile.connection.transaction: + for table in tables: + logger.info(f"Populating {table.__name__}...") + + # If imported/computed table, get key from key_source + key_source = getattr(table, "key_source", None) + if key_source is None: # Generate key from parents + parents = table.parents(as_objects=True) + key_source = parents[0].proj() + for parent in parents[1:]: + key_source *= parent.proj() + pop_key = (key_source & file_restr).fetch1("KEY") + + try: + table()._no_transaction_make(pop_key) + except Exception as err: + if raise_err: + raise err + log_insert_error( + table=table, err=err, error_constants=error_constants + ) + + +def populate_all_common( + nwb_file_name, rollback_on_fail=False, raise_err=False +) -> Union[List, None]: + """Insert all common tables for a given NWB file. + + Parameters + ---------- + nwb_file_name : str + The name of the NWB file to populate. + rollback_on_fail : bool, optional + If True, will delete the Session entry if any errors occur. + Defaults to False. + raise_err : bool, optional + If True, will raise any errors that occur during population. + Defaults to False. This will prevent any rollback from occurring. + + Returns + ------- + List + A list of keys for InsertError entries if any errors occurred. + """ from spyglass.spikesorting.imported import ImportedSpikeSorting - key = [(Nwbfile & f"nwb_file_name LIKE '{nwb_file_name}'").proj()] - tables = [ - Session, - # NwbfileKachery, # Not used by default - ElectrodeGroup, - Electrode, - Raw, - SampleCount, - DIOEvents, - # SensorData, # Not used by default. 
Generates large files - RawPosition, - TaskEpoch, - StateScriptFile, - VideoFile, - PositionSource, - RawPosition, - ImportedSpikeSorting, - ] error_constants = dict( dj_user=dj.config["database.user"], connection_id=dj.conn().connection_id, nwb_file_name=nwb_file_name, ) - for table in tables: - logger.info(f"Populating {table.__name__}...") - try: - table.populate(key) - except Exception as e: - InsertError.insert1( - dict( - **error_constants, - table=table.__name__, - error_type=type(e).__name__, - error_message=str(e), - error_raw=str(e), - ) - ) - query = InsertError & error_constants - if query: - err_tables = query.fetch("table") + table_lists = [ + [ # Tables that can be inserted in a single transaction + Session, + ElectrodeGroup, # Depends on Session + Electrode, # Depends on ElectrodeGroup + Raw, # Depends on Session + SampleCount, # Depends on Session + DIOEvents, # Depends on Session + TaskEpoch, # Depends on Session + ImportedSpikeSorting, # Depends on Session + # NwbfileKachery, # Not used by default + # SensorData, # Not used by default. Generates large files + ], + [ # Tables that depend on above transaction + PositionSource, # Depends on Session + VideoFile, # Depends on TaskEpoch + StateScriptFile, # Depends on TaskEpoch + ], + [ + RawPosition, # Depends on PositionSource + ], + ] + + for tables in table_lists: + single_transaction_make( + tables=tables, + nwb_file_name=nwb_file_name, + raise_err=raise_err, + error_constants=error_constants, + ) + + err_query = InsertError & error_constants + nwbfile_query = Nwbfile & {"nwb_file_name": nwb_file_name} + + if err_query and nwbfile_query and rollback_on_fail: + logger.error(f"Rolling back population for {nwb_file_name}...") + # Should this be safemode=False to prevent confirmation prompt? 
+ nwbfile_query.super_delete(warn=False) + + if err_query: + err_tables = err_query.fetch("table") logger.error( f"Errors occurred during population for {nwb_file_name}:\n\t" + f"Failed tables {err_tables}\n\t" + "See common_usage.InsertError for more details" ) - return query.fetch("KEY") + return err_query.fetch("KEY") diff --git a/src/spyglass/data_import/insert_sessions.py b/src/spyglass/data_import/insert_sessions.py index 329a7be42..a5d539e8e 100644 --- a/src/spyglass/data_import/insert_sessions.py +++ b/src/spyglass/data_import/insert_sessions.py @@ -12,7 +12,11 @@ from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename -def insert_sessions(nwb_file_names: Union[str, List[str]]): +def insert_sessions( + nwb_file_names: Union[str, List[str]], + rollback_on_fail: bool = False, + raise_err: bool = False, +): """ Populate the dj database with new sessions. @@ -23,6 +27,10 @@ def insert_sessions(nwb_file_names: Union[str, List[str]]): existing .nwb files. Each file represents a session. Also accepts strings with glob wildcards (e.g., *) so long as the wildcard specifies exactly one file. + rollback_on_fail : bool, optional + If True, undo all inserts if an error occurs. Default is False. + raise_err : bool, optional + If True, raise an error if an error occurs. Default is False. 
"""  if not isinstance(nwb_file_names, list): @@ -66,7 +74,11 @@ # the raw data in the original file copy_nwb_link_raw_ephys(nwb_file_name, out_nwb_file_name) Nwbfile().insert_from_relative_file_name(out_nwb_file_name) - populate_all_common(out_nwb_file_name) + return populate_all_common( + out_nwb_file_name, + rollback_on_fail=rollback_on_fail, + raise_err=raise_err, + ) def copy_nwb_link_raw_ephys(nwb_file_name, out_nwb_file_name): diff --git a/src/spyglass/spikesorting/imported.py b/src/spyglass/spikesorting/imported.py index ca1bdc9d0..048502081 100644 --- a/src/spyglass/spikesorting/imported.py +++ b/src/spyglass/spikesorting/imported.py @@ -31,6 +31,12 @@ class Annotations(SpyglassMixin, dj.Part): """ def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" orig_key = copy.deepcopy(key) nwb_file_abs_path = Nwbfile.get_abs_path(key["nwb_file_name"]) @@ -49,7 +56,7 @@ def make(self, key): key["object_id"] = nwbfile.units.object_id - self.insert1(key, skip_duplicates=True) + self.insert1(key, skip_duplicates=True, allow_direct_insert=True) part_name = SpikeSortingOutput._part_name(self.table_name) SpikeSortingOutput._merge_insert(