diff --git a/CHANGELOG.md b/CHANGELOG.md
index 67aee3a82..d1ab96d4e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,8 @@
 ### Infrastructure
 
 - Add user roles to `database_settings.py`. #832
+- Revise `dj_chains` to permit undirected paths when a path involves multiple
+  Merge Tables. #846
 
 ### Pipelines
 
@@ -14,6 +16,9 @@
 - Fixes to `_convert_mp4` #834
 - Replace deprecated calls to `yaml.safe_load()` #834
 
+- Spikesorting:
+  - Bug fix in single artifact interval edge case #859
+
 ## [0.5.0] (February 9, 2024)
 
 ### Infrastructure
diff --git a/src/spyglass/spikesorting/v0/spikesorting_artifact.py b/src/spyglass/spikesorting/v0/spikesorting_artifact.py
index 3afae293f..4ea90e092 100644
--- a/src/spyglass/spikesorting/v0/spikesorting_artifact.py
+++ b/src/spyglass/spikesorting/v0/spikesorting_artifact.py
@@ -278,7 +278,8 @@ def _get_artifact_times(
             valid_timestamps[interval[1]] + half_removal_window_s,
         ]
     # make the artifact intervals disjoint
-    artifact_intervals_s = reduce(_union_concat, artifact_intervals_s)
+    if len(artifact_intervals_s) > 1:
+        artifact_intervals_s = reduce(_union_concat, artifact_intervals_s)
 
     # convert seconds back to indices
     artifact_intervals_new = []
diff --git a/src/spyglass/spikesorting/v0/spikesorting_populator.py b/src/spyglass/spikesorting/v0/spikesorting_populator.py
index a099ac167..df8926a8c 100644
--- a/src/spyglass/spikesorting/v0/spikesorting_populator.py
+++ b/src/spyglass/spikesorting/v0/spikesorting_populator.py
@@ -278,7 +278,7 @@ def spikesorting_pipeline_populator(
     curation_keys = (Curation() & sort_dict).fetch("KEY")
     for curation_key in curation_keys:
         CuratedSpikeSortingSelection.insert1(
-            curation_auto_key, skip_duplicates=True
+            curation_key, skip_duplicates=True
         )
 
     # Populate curated spike sorting
diff --git a/src/spyglass/spikesorting/v1/artifact.py b/src/spyglass/spikesorting/v1/artifact.py
index 5aa51468c..6f241578e 100644
--- a/src/spyglass/spikesorting/v1/artifact.py
+++ b/src/spyglass/spikesorting/v1/artifact.py
@@ -295,7 +295,8 @@ def _get_artifact_times(
         ]
 
     # make the artifact intervals disjoint
-    artifact_intervals_s = reduce(_union_concat, artifact_intervals_s)
+    if len(artifact_intervals_s) > 1:
+        artifact_intervals_s = reduce(_union_concat, artifact_intervals_s)
 
     # find non-artifact intervals in timestamps
     artifact_removed_valid_times = interval_list_complement(
diff --git a/src/spyglass/utils/database_settings.py b/src/spyglass/utils/database_settings.py
index f9c2cfc26..7e1834313 100755
--- a/src/spyglass/utils/database_settings.py
+++ b/src/spyglass/utils/database_settings.py
@@ -21,6 +21,7 @@
 ]
 GRANT_ALL = "GRANT ALL PRIVILEGES ON "
 GRANT_SEL = "GRANT SELECT ON "
+GRANT_SHOW = "GRANT SHOW DATABASES ON "
 CREATE_USR = "CREATE USER IF NOT EXISTS "
 CREATE_ROLE = "CREATE ROLE IF NOT EXISTS "
 TEMP_PASS = " IDENTIFIED BY 'temppass';"
diff --git a/src/spyglass/utils/dj_chains.py b/src/spyglass/utils/dj_chains.py
index f281bfb37..4e05763fc 100644
--- a/src/spyglass/utils/dj_chains.py
+++ b/src/spyglass/utils/dj_chains.py
@@ -1,3 +1,4 @@
+from collections import OrderedDict
 from functools import cached_property
 from typing import List, Union
 
@@ -5,11 +6,26 @@
 import networkx as nx
 from datajoint.expression import QueryExpression
 from datajoint.table import Table
-from datajoint.utils import get_master
+from datajoint.utils import get_master, to_camel_case
 
 from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK
 from spyglass.utils.logging import logger
 
+# Tables that should be excluded from the undirected graph when finding paths
+# to maintain valid joins.
+PERIPHERAL_TABLES = [
+    "`common_interval`.`interval_list`",
+    "`common_nwbfile`.`__analysis_nwbfile_kachery`",
+    "`common_nwbfile`.`__nwbfile_kachery`",
+    "`common_nwbfile`.`analysis_nwbfile_kachery_selection`",
+    "`common_nwbfile`.`analysis_nwbfile_kachery`",
+    "`common_nwbfile`.`analysis_nwbfile`",
+    "`common_nwbfile`.`kachery_channel`",
+    "`common_nwbfile`.`nwbfile_kachery_selection`",
+    "`common_nwbfile`.`nwbfile_kachery`",
+    "`common_nwbfile`.`nwbfile`",
+]
+
 
 class TableChains:
     """Class for representing chains from parent to Merge table via parts.
@@ -65,6 +81,11 @@ def __repr__(self):
     def __len__(self):
         return len([c for c in self.chains if c.has_link])
 
+    @property
+    def max_len(self):
+        """Return length of longest chain."""
+        return max([len(chain) for chain in self.chains])
+
     def __getitem__(self, index: Union[int, str]):
         """Return FreeTable object at index."""
         if isinstance(index, str):
@@ -106,10 +127,20 @@ class TableChain:
         Cached attribute to store whether parent is linked to child. False if
         child is not in parent.descendants or nx.NetworkXNoPath is raised by
         nx.shortest_path.
+    _has_directed_link : bool
+        True if directed graph is used to find path. False if undirected graph.
+    graph : nx.DiGraph
+        Directed graph of parent's dependencies from datajoint.connection.
     names : List[str]
-        List of full table names in chain. Generated by networkx.shortest_path.
+        List of full table names in chain.
     objects : List[dj.FreeTable]
         List of FreeTable objects for each table in chain.
+    attr_maps : List[dict]
+        List of attribute maps for each link in chain.
+    path : OrderedDict[str, Dict[str, Union[dj.FreeTable, dict]]]
+        Dictionary of full table names in chain. Keys are self.names.
+        Values are a dict of free_table (self.objects) and
+        attr_map (dict of new_name: old_name, self.attr_maps).
 
     Methods
     -------
@@ -121,14 +152,19 @@ class TableChain:
         Return number of tables in chain.
     __getitem__(index: Union[int, str])
         Return FreeTable object at index, or use substring of table name.
+    find_path(directed=True)
+        Returns path as an OrderedDict of full table names in chain. If
+        directed is True, uses directed graph. If False, uses undirected
+        graph. Undirected excludes PERIPHERAL_TABLES like interval_list,
+        nwbfile, etc. to maintain valid joins.
     join(restriction: str = None)
         Return join of tables in chain with restriction applied to parent.
     """
 
     def __init__(self, parent: Table, child: Table, connection=None):
         self._connection = connection or parent.connection
-        if not self._connection.dependencies._loaded:
-            self._connection.dependencies.load()
+        self.graph = self._connection.dependencies
+        self.graph.load()
 
         if (  # if child is a merge table
             get_master(child.full_table_name) == ""
@@ -139,14 +175,23 @@ def __init__(self, parent: Table, child: Table, connection=None):
         self._link_symbol = " -> "
         self.parent = parent
         self.child = child
-        self._has_link = child.full_table_name in parent.descendants()
+        self._has_link = True
+        self._has_directed_link = None
+
+        if child.full_table_name not in self.graph.nodes:
+            logger.warning(
+                "Can't find item in graph. Try importing: "
+                + f"{child.full_table_name}"
+            )
 
     def __str__(self):
         """Return string representation of chain: parent -> child."""
         if not self._has_link:
             return "No link"
         return (
-            self.parent.table_name + self._link_symbol + self.child.table_name
+            to_camel_case(self.parent.table_name)
+            + self._link_symbol
+            + to_camel_case(self.child.table_name)
         )
 
     def __repr__(self):
@@ -182,43 +227,104 @@ def has_link(self) -> bool:
 
     def pk_link(self, src, trg, data) -> float:
         """Return 1 if data["primary"] else float("inf").
 
-        Currently unused. Preserved for future debugging."""
+        Currently unused. Preserved for future debugging. shortest_path accepts
+        an optional weight callable parameter:
+        nx.shortest_path(G, source, target, weight=pk_link)
+        """
         return 1 if data["primary"] else float("inf")
 
-    @cached_property
-    def names(self) -> List[str]:
+    def find_path(self, directed=True) -> OrderedDict:
         """Return list of full table names in chain.
 
-        Uses networkx.shortest_path. Ignores numeric table names, which are
+        Parameters
+        ----------
+        directed : bool, optional
+            If True, use directed graph. If False, use undirected graph.
+            Defaults to True. Undirected permits paths to traverse from merge
+            part-parent -> merge part -> merge table. Undirected excludes
+            PERIPHERAL_TABLES like interval_list, nwbfile, etc.
+
+        Returns
+        -------
+        OrderedDict
+            Dictionary of full table names in chain. Keys are full table names.
+            Values are free_table (dj.FreeTable representation) and attr_map
+            (dict of new_name: old_name). Attribute maps live on the table
+            upstream of an alias node and can be used in .proj(). Returns
+            None if no path is found.
+
+        Ignores numeric table names in paths, which are
         'gaps' or alias nodes in the graph. See datajoint.Diagram._make_graph
         source code for comments on alias nodes.
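+
+        Example (hypothetical `MyParent` and `MyMerge` tables, shown for
+        illustration only):
+            >>> chain = TableChain(MyParent(), MyMerge())
+            >>> path = chain.find_path(directed=False)
+            >>> list(path.keys())  # full table names, parent -> ... -> child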
""" - if not self._has_link: - return None + source, target = self.parent.full_table_name, self.child.full_table_name + if not directed: + self.graph = self.graph.to_undirected() + self.graph.remove_nodes_from(PERIPHERAL_TABLES) try: - return [ - name - for name in nx.shortest_path( - self.parent.connection.dependencies, - self.parent.full_table_name, - self.child.full_table_name, - # weight: optional callable to determine edge weight - # weight=self.pk_link, - ) - if not name.isdigit() - ] + path = nx.shortest_path(self.graph, source, target) except nx.NetworkXNoPath: - self._has_link = False return None + ret = OrderedDict() + prev_table = None + for i, table in enumerate(path): + if table.isnumeric(): # get proj() attribute map for alias node + if not prev_table: + raise ValueError("Alias node found without prev table.") + attr_map = self.graph[table][prev_table]["attr_map"] + ret[prev_table]["attr_map"] = attr_map + else: + free_table = dj.FreeTable(self._connection, table) + ret[table] = {"free_table": free_table, "attr_map": {}} + prev_table = table + return ret + + @cached_property + def path(self) -> OrderedDict: + """Return list of full table names in chain.""" + if not self._has_link: + return None + + link = None + if link := self.find_path(directed=True): + self._has_directed_link = True + elif link := self.find_path(directed=False): + self._has_directed_link = False + + if link: + return link + + self._has_link = False + return None + + @cached_property + def names(self) -> List[str]: + """Return list of full table names in chain.""" + if self._has_link: + return list(self.path.keys()) + return None + @cached_property def objects(self) -> List[dj.FreeTable]: - """Return list of FreeTable objects for each table in chain.""" - return ( - [dj.FreeTable(self._connection, name) for name in self.names] - if self.names - else None - ) + """Return list of FreeTable objects for each table in chain. + + Unused. Preserved for future debugging. + """ + if self._has_link: + return [v["free_table"] for v in self.path.values()] + return None + + @cached_property + def attr_maps(self) -> List[dict]: + """Return list of attribute maps for each table in chain. + + Unused. Preserved for future debugging. 
+ """ + # + if self._has_link: + return [v["attr_map"] for v in self.path.values()] + return None def join( self, restriction: str = None, reverse_order: bool = False @@ -236,16 +342,23 @@ def join( if not self._has_link: return None - objects = self.objects[::-1] if reverse_order else self.objects restriction = restriction or self.parent.restriction or True - join = objects[0] & restriction - for table in objects[1:]: + path = ( + OrderedDict(reversed(self.path.items())) + if reverse_order + else self.path + ).copy() + + _, first_val = path.popitem(last=False) + join = first_val["free_table"] & restriction + for i, val in enumerate(path.values()): + attr_map, free_table = val["attr_map"], val["free_table"] try: - join = join.proj() * table + join = (join.proj() * free_table).proj(**attr_map) except dj.DataJointError as e: attribute = str(e).split("attribute ")[-1] logger.error( - f"{str(self)} at {table.table_name} with {attribute}" + f"{str(self)} at {free_table.table_name} with {attribute}" ) return None return join diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 1b4b24ff6..29978ae88 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -1,3 +1,4 @@ +from collections import OrderedDict from functools import cached_property from time import time from typing import Dict, List, Union @@ -130,17 +131,23 @@ def _merge_tables(self) -> Dict[str, dj.FreeTable]: Cache of items in parents of self.descendants(as_objects=True). Both descendant and parent must have the reserved primary key 'merge_id'. """ - self.connection.dependencies.load() merge_tables = {} - for desc in self.descendants(as_objects=True): - if MERGE_PK not in desc.heading.names or not ( - master_name := get_master(desc.full_table_name) - ): - continue - master = dj.FreeTable(self.connection, master_name) - if MERGE_PK in master.heading.names: - merge_tables[master_name] = master + + def search_descendants(parent): + for desc in parent.descendants(as_objects=True): + if ( + MERGE_PK not in desc.heading.names + or not (master_name := get_master(desc.full_table_name)) + or master_name in merge_tables + ): + continue + master = dj.FreeTable(self.connection, master_name) + if MERGE_PK in master.heading.names: + merge_tables[master_name] = master + search_descendants(master) + + _ = search_descendants(self) logger.info( f"Building merge cache for {self.table_name}.\n\t" @@ -150,7 +157,7 @@ def _merge_tables(self) -> Dict[str, dj.FreeTable]: return merge_tables @cached_property - def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]: + def _merge_chains(self) -> OrderedDict[str, List[dj.FreeTable]]: """Dict of chains to merges downstream of self Format: {full_table_name: TableChains}. @@ -158,14 +165,24 @@ def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]: For each merge table found in _merge_tables, find the path from self to merge via merge parts. If the path is valid, add it to the dict. Cache prevents need to recompute whenever delete_downstream_merge is called - with a new restriction. To recompute, add `reload_cache=True` to call. + with a new restriction. To recompute, add `reload_cache=True` to + delete_downstream_merge call. 
""" merge_chains = {} for name, merge_table in self._merge_tables.items(): chains = TableChains(self, merge_table, connection=self.connection) if len(chains): merge_chains[name] = chains - return merge_chains + + # This is ordered by max_len of chain from self to merge, which assumes + # that the merge table with the longest chain is the most downstream. + # A more sophisticated approach would order by length from self to + # each merge part independently, but this is a good first approximation. + return OrderedDict( + sorted( + merge_chains.items(), key=lambda x: x[1].max_len, reverse=True + ) + ) def _get_chain(self, substring) -> TableChains: """Return chain from self to merge table with substring in name."""