Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Log when file accessed #941

Merged
merged 10 commits into from
Apr 22, 2024
11 changes: 4 additions & 7 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
# Change Log

## [0.5.2] (Unreleased)

### Release Notes

<!-- Running draft to be removed immediately prior to release. -->
## [0.5.2] (April 22, 2024)

### Infrastructure

Expand All @@ -20,14 +16,15 @@
- Prioritize datajoint filepath entry for defining abs_path of analysis nwbfile
#918
- Fix potential duplicate entries in Merge part tables #922
- Add logging of AnalysisNwbfile creation time and size #937
- Add log of AnalysisNwbfile creation time, size, and access count #937, #941

### Pipelines

- Spikesorting
- Update calls in v0 pipeline for spikeinterface>=0.99 #893
- Fix method type of `get_spike_times` #904
- Add helper functions for restricting spikesorting results and linking to probe info #910
- Add helper functions for restricting spikesorting results and linking to
probe info #910
- Decoding
- Handle dimensions of clusterless `get_ahead_behind_distance` #904
- Fix improper handling of nwb file names with .strip #929
Expand Down
4 changes: 2 additions & 2 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,5 @@ keywords:
- spike sorting
- kachery
license: MIT
version: 0.5.1
date-released: '2024-03-07'
version: 0.5.2
date-released: '2024-04-22'
54 changes: 48 additions & 6 deletions src/spyglass/common/common_nwbfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,21 +693,37 @@ def log(self, analysis_file_name, table=None):
table=table,
)

def increment_access(self, keys, table=None):
    """Passthrough to the AnalysisNwbfileLog table. Avoid new imports.

    Parameters
    ----------
    keys : Union[str, dict, Iterable[Union[str, dict]]]
        A single analysis file name or key, or an iterable of them.
    table : str, optional
        Forwarded unchanged as the ``table`` argument of
        ``AnalysisNwbfileLog.increment_access``.
    """
    # A lone str or dict is itself iterable: looping over it directly would
    # visit its characters / field names instead of the key. Wrap only those
    # two types; any other iterable is looped over as before.
    if isinstance(keys, (str, dict)):
        keys = [keys]

    for key in keys:
        AnalysisNwbfileLog().increment_access(key, table=table)


@schema
class AnalysisNwbfileLog(dj.Manual):
definition = """
id: int auto_increment
---
dj_user: varchar(64)
-> AnalysisNwbfile
table=null: varchar(64)
timestamp = CURRENT_TIMESTAMP : timestamp
time_delta=null: float
file_size=null: float
dj_user : varchar(64) # user who created the file
timestamp = CURRENT_TIMESTAMP : timestamp # when the file was created
table = null : varchar(64) # creating table
time_delta = null : float # how long it took to create
file_size = null : float # size of the file in bytes
accessed = 0 : int # n times accessed
unique index (analysis_file_name)
"""

def log(self, analysis_file_name, time_delta, file_size, table=None):
def log(
self,
analysis_file_name=None,
time_delta=None,
file_size=None,
table=None,
):
"""Log the creation of an analysis NWB file.

Parameters
Expand All @@ -724,3 +740,29 @@ def log(self, analysis_file_name, time_delta, file_size, table=None):
"table": table,
}
)

def increment_access(self, key, table=None):
    """Add one to the ``accessed`` count of every entry matching ``key``.

    Parameters
    ----------
    key : Union[str, dict]
        The name of the analysis NWB file, or a key to the table.
    table : str, optional
        Table name written to matching entries whose ``table`` is empty.
    """
    restriction = {"analysis_file_name": key} if isinstance(key, str) else key

    matches = self & restriction
    if not matches:
        # Nothing matches yet: create an entry via log() before counting.
        self.log(**restriction, table=table)

    rows = matches.fetch(as_dict=True)  # runs after any log() call above
    for row in rows:
        row["accessed"] += 1
        if table and not row.get("table"):
            row["table"] = table

    # Write the modified rows back over the existing entries.
    self.insert(rows, replace=True)
7 changes: 2 additions & 5 deletions src/spyglass/decoding/v1/waveform_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def make(self, key):
nwb_file_name,
key["analysis_file_name"],
)
AnalysisNwbfile().log(key, table=self.full_table_name)

self.insert1(key)

@staticmethod
Expand Down Expand Up @@ -392,9 +394,4 @@ def _write_waveform_features_to_nwb(
units_object_id = nwbf.units.object_id
io.write(nwbf)

AnalysisNwbfile().log(
analysis_nwb_file,
table="`decoding_waveform_features`.`__unit_waveform_features`",
)

return analysis_nwb_file, units_object_id
5 changes: 2 additions & 3 deletions src/spyglass/spikesorting/v0/spikesorting_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,6 @@ def save_sorting_nwb(
else:
units_object_id = object_ids[0]

AnalysisNwbfile().log(
analysis_file_name, table="`spikesorting_curation`.`curation`"
)
return analysis_file_name, units_object_id


Expand Down Expand Up @@ -1003,6 +1000,8 @@ def make(self, key):
unit_ids=accepted_units,
labels=labels,
)

AnalysisNwbfile().log(key, table=self.full_table_name)
self.insert1(key)

# now add the units
Expand Down
4 changes: 1 addition & 3 deletions src/spyglass/spikesorting/v1/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def insert_curation(
key,
skip_duplicates=True,
)
AnalysisNwbfile().log(analysis_file_name, table=cls.full_table_name)

return key

Expand Down Expand Up @@ -425,9 +426,6 @@ def _write_sorting_to_nwb_with_curation(

units_object_id = nwbf.units.object_id
io.write(nwbf)
AnalysisNwbfile().log(
analysis_nwb_file, table="`spikesorting_v1_sorting`.`__spike_sorting`"
)
return analysis_nwb_file, units_object_id


Expand Down
5 changes: 1 addition & 4 deletions src/spyglass/spikesorting/v1/metric_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def make(self, key):
nwb_file_name,
key["analysis_file_name"],
)
AnalysisNwbfile().log(key, table=self.full_table_name)
self.insert1(key)

@classmethod
Expand Down Expand Up @@ -586,8 +587,4 @@ def _write_metric_curation_to_nwb(

units_object_id = nwbf.units.object_id
io.write(nwbf)
AnalysisNwbfile().log(
analysis_nwb_file,
table="`spikesorting_v1_metric_curation`.`__metric_curation`",
)
return analysis_nwb_file, units_object_id
7 changes: 3 additions & 4 deletions src/spyglass/spikesorting/v1/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,9 @@ def make(self, key):
(SpikeSortingRecordingSelection & key).fetch1("nwb_file_name"),
key["analysis_file_name"],
)
AnalysisNwbfile().log(
recording_nwb_file_name, table=self.full_table_name
)
self.insert1(key)

@classmethod
Expand Down Expand Up @@ -651,10 +654,6 @@ def _write_recording_to_nwb(
"ProcessedElectricalSeries"
].object_id
io.write(nwbfile)
AnalysisNwbfile().log(
analysis_nwb_file,
table="`spikesorting_v1_sorting`.`__spike_sorting_recording`",
)
return analysis_nwb_file, recording_object_id


Expand Down
4 changes: 1 addition & 3 deletions src/spyglass/spikesorting/v1/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def make(self, key: dict):
(SpikeSortingSelection & key).fetch1("nwb_file_name"),
key["analysis_file_name"],
)
AnalysisNwbfile().log(key, table=self.full_table_name)
self.insert1(key, skip_duplicates=True)

@classmethod
Expand Down Expand Up @@ -405,7 +406,4 @@ def _write_sorting_to_nwb(
)
units_object_id = nwbf.units.object_id
io.write(nwbf)
AnalysisNwbfile().log(
analysis_nwb_file, table="`spikesorting_v1_curation`.`curation_v1`"
)
return analysis_nwb_file, units_object_id
28 changes: 28 additions & 0 deletions src/spyglass/utils/dj_helper_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

import datajoint as dj
import numpy as np
from datajoint.user_tables import UserTable

from spyglass.utils.dj_chains import PERIPHERAL_TABLES
from spyglass.utils.logging import logger
from spyglass.utils.nwb_helper_fn import get_nwb_file

Expand Down Expand Up @@ -110,6 +112,26 @@ def dj_replace(original_table, new_values, key_column, replace_column):
return original_table


def get_fetching_table_from_stack(stack):
    """Return the name of the single table found in a call stack, if any.

    Scans the local variables of every frame in ``stack`` for ``UserTable``
    instances and collects their ``full_table_name``, ignoring names listed
    in ``PERIPHERAL_TABLES``.

    Parameters
    ----------
    stack : list
        Frame records exposing ``.frame.f_locals``, e.g. the result of
        ``inspect.stack()``.

    Returns
    -------
    str or None
        The table name when exactly one distinct name was found. None when
        no table was found, or when several were found (a warning is logged
        in that case).
    """
    classes = set()
    for frame_info in stack:
        for obj in frame_info.frame.f_locals.values():
            if not isinstance(obj, UserTable):
                continue  # skip non-tables
            if (name := obj.full_table_name) in PERIPHERAL_TABLES:
                continue  # skip common_nwbfile tables
            classes.add(name)

    if len(classes) > 1:
        # Logger.warn is a deprecated alias; warning() is the supported call.
        logger.warning(
            f"Multiple classes found in stack: {classes}. "
            "Please submit a bug report with the snippet used."
        )
        return None  # predict only one but not sure, so return None

    return next(iter(classes), None)


def get_nwb_table(query_expression, tbl, attr_name, *attrs, **kwargs):
"""Get the NWB file name and path from the given DataJoint query.

Expand Down Expand Up @@ -150,6 +172,11 @@ def get_nwb_table(query_expression, tbl, attr_name, *attrs, **kwargs):
query_expression * tbl.proj(nwb2load_filepath=attr_name)
).fetch(file_name_str)

if which == "analysis": # log access of analysis files to log table
AnalysisNwbfile().increment_access(
nwb_files, table=get_fetching_table_from_stack(inspect.stack())
)

return nwb_files, file_path_fn


Expand Down Expand Up @@ -185,6 +212,7 @@ def fetch_nwb(query_expression, nwb_master, *attrs, **kwargs):
nwb_files, file_path_fn = get_nwb_table(
query_expression, tbl, attr_name, *attrs, **kwargs
)

for file_name in nwb_files:
file_path = file_path_fn(file_name)
if not os.path.exists(file_path): # retrieve the file from kachery.
Expand Down
1 change: 1 addition & 0 deletions src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def fetch_nwb(self, *attrs, **kwargs):
Additional logic support Export table logging.
"""
table, tbl_attr = self._nwb_table_tuple

if self.export_id and "analysis" in tbl_attr:
tbl_pk = "analysis_file_name"
fnames = (self * table).fetch(tbl_pk)
Expand Down
Loading