Skip to content

Commit

Permalink
Merge branch 'master' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 authored Mar 7, 2024
2 parents 8a1ff2d + 48cdf00 commit 3bfc33b
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 50 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
### Infrastructure

- Add user roles to `database_settings.py`. #832
- Revise `dj_chains` to permit undirected paths for paths with multiple Merge
Tables. #846

### Pipelines

Expand All @@ -14,6 +16,9 @@
- Fixes to `_convert_mp4` #834
- Replace deprecated calls to `yaml.safe_load()` #834

- Spikesorting:
- Bug fix in single artifact interval edge case #859

## [0.5.0] (February 9, 2024)

### Infrastructure
Expand Down
3 changes: 2 additions & 1 deletion src/spyglass/spikesorting/v0/spikesorting_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ def _get_artifact_times(
valid_timestamps[interval[1]] + half_removal_window_s,
]
# make the artifact intervals disjoint
artifact_intervals_s = reduce(_union_concat, artifact_intervals_s)
if len(artifact_intervals_s) > 1:
artifact_intervals_s = reduce(_union_concat, artifact_intervals_s)

# convert seconds back to indices
artifact_intervals_new = []
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/spikesorting/v0/spikesorting_populator.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def spikesorting_pipeline_populator(
curation_keys = (Curation() & sort_dict).fetch("KEY")
for curation_key in curation_keys:
CuratedSpikeSortingSelection.insert1(
curation_auto_key, skip_duplicates=True
curation_key, skip_duplicates=True
)

# Populate curated spike sorting
Expand Down
3 changes: 2 additions & 1 deletion src/spyglass/spikesorting/v1/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ def _get_artifact_times(
]

# make the artifact intervals disjoint
artifact_intervals_s = reduce(_union_concat, artifact_intervals_s)
if len(artifact_intervals_s) > 1:
artifact_intervals_s = reduce(_union_concat, artifact_intervals_s)

# find non-artifact intervals in timestamps
artifact_removed_valid_times = interval_list_complement(
Expand Down
1 change: 1 addition & 0 deletions src/spyglass/utils/database_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
]
GRANT_ALL = "GRANT ALL PRIVILEGES ON "
GRANT_SEL = "GRANT SELECT ON "
GRANT_SHOW = "GRANT SHOW DATABASES ON "
CREATE_USR = "CREATE USER IF NOT EXISTS "
CREATE_ROLE = "CREATE ROLE IF NOT EXISTS "
TEMP_PASS = " IDENTIFIED BY 'temppass';"
Expand Down
183 changes: 148 additions & 35 deletions src/spyglass/utils/dj_chains.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,31 @@
from collections import OrderedDict
from functools import cached_property
from typing import List, Union

import datajoint as dj
import networkx as nx
from datajoint.expression import QueryExpression
from datajoint.table import Table
from datajoint.utils import get_master
from datajoint.utils import get_master, to_camel_case

from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK
from spyglass.utils.logging import logger

# Tables that should be excluded from the undirected graph when finding paths
# to maintain valid joins.
PERIPHERAL_TABLES = [
"`common_interval`.`interval_list`",
"`common_nwbfile`.`__analysis_nwbfile_kachery`",
"`common_nwbfile`.`__nwbfile_kachery`",
"`common_nwbfile`.`analysis_nwbfile_kachery_selection`",
"`common_nwbfile`.`analysis_nwbfile_kachery`",
"`common_nwbfile`.`analysis_nwbfile`",
"`common_nwbfile`.`kachery_channel`",
"`common_nwbfile`.`nwbfile_kachery_selection`",
"`common_nwbfile`.`nwbfile_kachery`",
"`common_nwbfile`.`nwbfile`",
]


class TableChains:
"""Class for representing chains from parent to Merge table via parts.
Expand Down Expand Up @@ -65,6 +81,11 @@ def __repr__(self):
def __len__(self):
return len([c for c in self.chains if c.has_link])

@property
def max_len(self):
    """Length of the longest chain managed by this object.

    Iterates ``self.chains`` and returns the maximum ``len()`` among them.
    Raises ValueError when ``self.chains`` is empty, matching ``max()``.
    """
    return max(len(chain) for chain in self.chains)

def __getitem__(self, index: Union[int, str]):
"""Return FreeTable object at index."""
if isinstance(index, str):
Expand Down Expand Up @@ -106,10 +127,20 @@ class TableChain:
Cached attribute to store whether parent is linked to child. False if
child is not in parent.descendants or nx.NetworkXNoPath is raised by
nx.shortest_path.
_has_directed_link : bool
True if directed graph is used to find path. False if undirected graph.
graph : nx.DiGraph
Directed graph of parent's dependencies from datajoint.connection.
names : List[str]
List of full table names in chain. Generated by networkx.shortest_path.
List of full table names in chain.
objects : List[dj.FreeTable]
List of FreeTable objects for each table in chain.
attr_maps : List[dict]
List of attribute maps for each link in chain.
path : OrderedDict[str, Dict[str, Union[dj.FreeTable,dict]]]
Dictionary of full table names in chain. Keys are self.names
Values are a dict of free_table (self.objects) and
attr_map (dict of new_name: old_name, self.attr_map).
Methods
-------
Expand All @@ -121,14 +152,19 @@ class TableChain:
Return number of tables in chain.
__getitem__(index: Union[int, str])
Return FreeTable object at index, or use substring of table name.
find_path(directed=True)
Returns path OrderedDict of full table names in chain. If directed is
True, uses directed graph. If False, uses undirected graph. Undirected
excludes PERIPHERAL_TABLES like interval_list, nwbfile, etc. to maintain
valid joins.
join(restriction: str = None)
Return join of tables in chain with restriction applied to parent.
"""

def __init__(self, parent: Table, child: Table, connection=None):
self._connection = connection or parent.connection
if not self._connection.dependencies._loaded:
self._connection.dependencies.load()
self.graph = self._connection.dependencies
self.graph.load()

if ( # if child is a merge table
get_master(child.full_table_name) == ""
Expand All @@ -139,14 +175,23 @@ def __init__(self, parent: Table, child: Table, connection=None):
self._link_symbol = " -> "
self.parent = parent
self.child = child
self._has_link = child.full_table_name in parent.descendants()
self._has_link = True
self._has_directed_link = None

if child.full_table_name not in self.graph.nodes:
logger.warning(
"Can't find item in graph. Try importing: "
+ f"{child.full_table_name}"
)

def __str__(self):
"""Return string representation of chain: parent -> child."""
if not self._has_link:
return "No link"
return (
self.parent.table_name + self._link_symbol + self.child.table_name
to_camel_case(self.parent.table_name)
+ self._link_symbol
+ to_camel_case(self.child.table_name)
)

def __repr__(self):
Expand Down Expand Up @@ -182,43 +227,104 @@ def has_link(self) -> bool:
def pk_link(self, src, trg, data) -> float:
"""Return 1 if data["primary"] else float("inf").
Currently unused. Preserved for future debugging."""
Currently unused. Preserved for future debugging. shortest_path accepts
an optional weight callable parameter.
nx.shortest_path(G, source, target,weight=pk_link)
"""
return 1 if data["primary"] else float("inf")

@cached_property
def names(self) -> List[str]:
def find_path(self, directed=True) -> OrderedDict:
"""Return list of full table names in chain.
Uses networkx.shortest_path. Ignores numeric table names, which are
Parameters
----------
directed : bool, optional
If True, use directed graph. If False, use undirected graph.
Defaults to True. Undirected permits paths to traverse from merge
part-parent -> merge part -> merge table. Undirected excludes
PERIPHERAL_TABLES like interval_list, nwbfile, etc.
Returns
-------
OrderedDict
Dictionary of full table names in chain. Keys are full table names.
Values are free_table (dj.FreeTable representation) and attr_map
(dict of new_name: old_name). Attribute maps on the table upstream
of an alias node that can be used in .proj(). Returns None if no
path is found.
Ignores numeric table names in paths, which are
'gaps' or alias nodes in the graph. See datajoint.Diagram._make_graph
source code for comments on alias nodes.
"""
if not self._has_link:
return None
source, target = self.parent.full_table_name, self.child.full_table_name
if not directed:
self.graph = self.graph.to_undirected()
self.graph.remove_nodes_from(PERIPHERAL_TABLES)
try:
return [
name
for name in nx.shortest_path(
self.parent.connection.dependencies,
self.parent.full_table_name,
self.child.full_table_name,
# weight: optional callable to determine edge weight
# weight=self.pk_link,
)
if not name.isdigit()
]
path = nx.shortest_path(self.graph, source, target)
except nx.NetworkXNoPath:
self._has_link = False
return None

ret = OrderedDict()
prev_table = None
for i, table in enumerate(path):
if table.isnumeric(): # get proj() attribute map for alias node
if not prev_table:
raise ValueError("Alias node found without prev table.")
attr_map = self.graph[table][prev_table]["attr_map"]
ret[prev_table]["attr_map"] = attr_map
else:
free_table = dj.FreeTable(self._connection, table)
ret[table] = {"free_table": free_table, "attr_map": {}}
prev_table = table
return ret

@cached_property
def path(self) -> OrderedDict:
    """Resolved path between parent and child as an OrderedDict.

    Keys are full table names; values hold the ``free_table`` and
    ``attr_map`` entries produced by ``find_path``. Tries the directed
    graph first, then falls back to the undirected graph, recording which
    succeeded in ``self._has_directed_link``. Returns None (and flips
    ``self._has_link`` to False) when neither search finds a path.
    """
    if not self._has_link:
        return None

    for use_directed in (True, False):
        found = self.find_path(directed=use_directed)
        if found:
            self._has_directed_link = use_directed
            return found

    self._has_link = False
    return None

@cached_property
def names(self) -> List[str]:
    """Full table names along the chain, or None when there is no link.

    The names are the keys of ``self.path`` in traversal order.
    """
    if not self._has_link:
        return None
    return list(self.path)

@cached_property
def objects(self) -> List[dj.FreeTable]:
"""Return list of FreeTable objects for each table in chain."""
return (
[dj.FreeTable(self._connection, name) for name in self.names]
if self.names
else None
)
"""Return list of FreeTable objects for each table in chain.
Unused. Preserved for future debugging.
"""
if self._has_link:
return [v["free_table"] for v in self.path.values()]
return None

@cached_property
def attr_maps(self) -> List[dict]:
    """Attribute maps (new_name: old_name) for each table in the chain.

    Unused. Preserved for future debugging.
    """
    if not self._has_link:
        return None
    return [entry["attr_map"] for entry in self.path.values()]

def join(
self, restriction: str = None, reverse_order: bool = False
Expand All @@ -236,16 +342,23 @@ def join(
if not self._has_link:
return None

objects = self.objects[::-1] if reverse_order else self.objects
restriction = restriction or self.parent.restriction or True
join = objects[0] & restriction
for table in objects[1:]:
path = (
OrderedDict(reversed(self.path.items()))
if reverse_order
else self.path
).copy()

_, first_val = path.popitem(last=False)
join = first_val["free_table"] & restriction
for i, val in enumerate(path.values()):
attr_map, free_table = val["attr_map"], val["free_table"]
try:
join = join.proj() * table
join = (join.proj() * free_table).proj(**attr_map)
except dj.DataJointError as e:
attribute = str(e).split("attribute ")[-1]
logger.error(
f"{str(self)} at {table.table_name} with {attribute}"
f"{str(self)} at {free_table.table_name} with {attribute}"
)
return None
return join
Loading

0 comments on commit 3bfc33b

Please sign in to comment.