Skip to content

Commit

Permalink
Periph table fallback on TableChain for experimenter summary (#1035)
Browse files Browse the repository at this point in the history
* Periph table fallback on TableChain

* Update Changelog

* Rely on search to remove no_visit, not id step

* Include generic load_shared_schemas

* Update changelog for release

* Allow adding custom prefixes for loading schemas

* Fix merge error
  • Loading branch information
CBroz1 authored Aug 27, 2024
1 parent 012ea30 commit ecf468e
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 83 deletions.
18 changes: 3 additions & 15 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,20 +1,6 @@
# Change Log

## [0.5.3] (Unreleased)

## Release Notes

<!-- Running draft to be removed immediately prior to release. -->

```python
import datajoint as dj
from spyglass.common.common_behav import PositionIntervalMap
from spyglass.decoding.v1.core import PositionGroup

dj.schema("common_ripple").drop()
PositionIntervalMap.alter()
PositionGroup.alter()
```
## [0.5.3] (August 27, 2024)

### Infrastructure

Expand Down Expand Up @@ -46,6 +32,8 @@ PositionGroup.alter()
- Installation instructions -> Setup notebook. #1029
- Migrate SQL export tools to `utils` to support exporting `DandiPath` #1048
- Add tool for checking threads for metadata locks on a table #1063
- Use peripheral tables as fallback in `TableChains` #1035
- Ignore non-Spyglass tables during descendant check for `part_masters` #1035

### Pipelines

Expand Down
43 changes: 34 additions & 9 deletions src/spyglass/utils/dj_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def _get_ft(self, table, with_restr=False, warn=True):

return ft & restr

def _is_out(self, table, warn=True):
def _is_out(self, table, warn=True, keep_alias=False):
"""Check if table is outside of spyglass."""
table = ensure_names(table)
if self.graph.nodes.get(table):
Expand Down Expand Up @@ -805,7 +805,8 @@ class TableChain(RestrGraph):
Returns path OrderedDict of full table names in chain. If directed is
True, uses directed graph. If False, uses undirected graph. Undirected
excludes PERIPHERAL_TABLES like interval_list, nwbfile, etc. to maintain
valid joins.
valid joins by default. If no path is found, another search is attempted
with PERIPHERAL_TABLES included.
cascade(restriction: str = None, direction: str = "up")
Given a restriction at the beginning, return a restricted FreeTable
object at the end of the chain. If direction is 'up', start at the child
Expand Down Expand Up @@ -835,8 +836,12 @@ def __init__(
super().__init__(seed_table=seed_table, verbose=verbose)

self._ignore_peripheral(except_tables=[self.parent, self.child])
self._ignore_outside_spy(except_tables=[self.parent, self.child])

self.no_visit.update(ensure_names(banned_tables) or [])

self.no_visit.difference_update(set([self.parent, self.child]))

self.searched_tables = set()
self.found_restr = False
self.link_type = None
Expand Down Expand Up @@ -872,7 +877,19 @@ def _ignore_peripheral(self, except_tables: List[str] = None):
except_tables = ensure_names(except_tables)
ignore_tables = set(PERIPHERAL_TABLES) - set(except_tables or [])
self.no_visit.update(ignore_tables)
self.undirect_graph.remove_nodes_from(ignore_tables)

def _ignore_outside_spy(self, except_tables: List[str] = None):
    """Mark tables outside shared spyglass prefixes as non-visitable.

    Adds every node of the undirected graph that is outside spyglass
    (per ``_is_out``) to ``self.no_visit``, except those explicitly
    exempted via ``except_tables``.

    Parameters
    ----------
    except_tables : List[str], optional
        Tables to keep visitable even if outside spyglass. Default None.
    """
    # Guard against None, consistent with _ignore_peripheral's handling
    except_tables = ensure_names(except_tables) or []
    ignore_tables = {  # Ignore tables not in shared modules
        t
        for t in self.undirect_graph.nodes
        if t not in except_tables
        and self._is_out(t, warn=False, keep_alias=True)
    }
    self.no_visit.update(ignore_tables)

# --------------------------- Dunder Properties ---------------------------

Expand Down Expand Up @@ -1066,9 +1083,9 @@ def find_path(self, directed=True) -> List[str]:
List of names in the path.
"""
source, target = self.parent, self.child
search_graph = self.graph if directed else self.undirect_graph

search_graph.remove_nodes_from(self.no_visit)
search_graph = ( # Copy to ensure orig not modified by no_visit
self.graph.copy() if directed else self.undirect_graph.copy()
)

try:
path = shortest_path(search_graph, source, target)
Expand Down Expand Up @@ -1096,6 +1113,12 @@ def path(self) -> list:
self.link_type = "directed"
elif path := self.find_path(directed=False):
self.link_type = "undirected"
else: # Search with peripheral
self.no_visit.difference_update(PERIPHERAL_TABLES)
if path := self.find_path(directed=True):
self.link_type = "directed with peripheral"
elif path := self.find_path(directed=False):
self.link_type = "undirected with peripheral"
self.searched_path = True

return path
Expand Down Expand Up @@ -1126,9 +1149,11 @@ def cascade(
# Cascade will stop if any restriction is empty, so set rest to None
# This would cause issues if we want a table partway through the chain
# but that's not a typical use case, were the start and end are desired
non_numeric = [t for t in self.path if not t.isnumeric()]
if any(self._get_restr(t) is None for t in non_numeric):
for table in non_numeric:
safe_tbls = [
t for t in self.path if not t.isnumeric() and not self._is_out(t)
]
if any(self._get_restr(t) is None for t in safe_tbls):
for table in safe_tbls:
if table is not start:
self._set_restr(table, False, replace=True)

Expand Down
117 changes: 58 additions & 59 deletions src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,52 +261,41 @@ def fetch_pynapple(self, *attrs, **kwargs):

# ------------------------ delete_downstream_parts ------------------------

def _import_part_masters(self):
"""Import tables that may constrain a RestrGraph. See #1002"""
from spyglass.decoding.decoding_merge import DecodingOutput # noqa F401
from spyglass.decoding.v0.clusterless import (
UnitMarksIndicatorSelection,
) # noqa F401
from spyglass.decoding.v0.sorted_spikes import (
SortedSpikesIndicatorSelection,
) # noqa F401
from spyglass.decoding.v1.core import PositionGroup # noqa F401
from spyglass.lfp.analysis.v1 import LFPBandSelection # noqa F401
from spyglass.lfp.lfp_merge import LFPOutput # noqa F401
from spyglass.linearization.merge import ( # noqa F401
LinearizedPositionOutput,
LinearizedPositionV1,
)
from spyglass.mua.v1.mua import MuaEventsV1 # noqa F401
from spyglass.position.position_merge import PositionOutput # noqa F401
from spyglass.ripple.v1.ripple import RippleTimesV1 # noqa F401
from spyglass.spikesorting.analysis.v1.group import (
SortedSpikesGroup,
) # noqa F401
from spyglass.spikesorting.spikesorting_merge import (
SpikeSortingOutput,
) # noqa F401
from spyglass.spikesorting.v0.figurl_views import (
SpikeSortingRecordingView,
) # noqa F401

_ = (
DecodingOutput(),
LFPBandSelection(),
LFPOutput(),
LinearizedPositionOutput(),
LinearizedPositionV1(),
MuaEventsV1(),
PositionGroup(),
PositionOutput(),
RippleTimesV1(),
SortedSpikesGroup(),
SortedSpikesIndicatorSelection(),
SpikeSortingOutput(),
SpikeSortingRecordingView(),
UnitMarksIndicatorSelection(),
def load_shared_schemas(self, additional_prefixes: list = None) -> None:
    """Load dependencies of all shared schemas for graph traversal.

    Queries ``information_schema`` for schemas whose names start with a
    shared prefix (``SHARED_MODULES``, the current database user,
    ``file``, or ``sharing``) and loads each matching schema's
    dependencies into the connection's dependency graph.

    Parameters
    ----------
    additional_prefixes : list, optional
        Additional schema prefixes to load. Default None.
    """
    all_shared = [
        *SHARED_MODULES,
        dj.config["database.user"],
        "file",
        "sharing",
    ]

    if additional_prefixes:
        all_shared.extend(additional_prefixes)

    # Get a list of all shared schemas in spyglass.
    # NOTE: each clause below must keep its leading space; without the
    # space before "AND (" the SQL reads ...'PRIMARY'AND and fails.
    schemas = dj.conn().query(
        "SELECT DISTINCT table_schema "  # Unique schemas
        + "FROM information_schema.key_column_usage "
        + "WHERE"
        + ' table_name not LIKE "~%%"'  # Exclude hidden
        + " AND constraint_name='PRIMARY'"  # Only primary keys
        + " AND ("  # Only shared schemas
        + " OR ".join([f"table_schema LIKE '{s}_%%'" for s in all_shared])
        + ") "
        + "ORDER BY table_schema;"
    )

    # Load the dependencies for all shared schemas
    for schema in schemas:
        dj.schema(schema[0]).connection.dependencies.load()

@cached_property
def _part_masters(self) -> set:
"""Set of master tables downstream of self.
Expand All @@ -318,23 +307,25 @@ def _part_masters(self) -> set:
part_masters = set()

def search_descendants(parent):
for desc in parent.descendants(as_objects=True):
for desc_name in parent.descendants():
if ( # Check if has master, is part
not (master := get_master(desc.full_table_name))
# has other non-master parent
or not set(desc.parents()) - set([master])
not (master := get_master(desc_name))
or master in part_masters # already in cache
or desc_name.replace("`", "").split("_")[0]
not in SHARED_MODULES
):
continue
if master not in part_masters:
part_masters.add(master)
search_descendants(dj.FreeTable(self.connection, master))
desc = dj.FreeTable(self.connection, desc_name)
if not set(desc.parents()) - set([master]): # no other parent
continue
part_masters.add(master)
search_descendants(dj.FreeTable(self.connection, master))

try:
_ = search_descendants(self)
except NetworkXError:
try: # Attempt to import missing table
self._import_part_masters()
try: # Attempt to import failing schema
self.load_shared_schemas()
_ = search_descendants(self)
except NetworkXError as e:
table_name = "".join(e.args[0].split("`")[1:4])
Expand Down Expand Up @@ -484,7 +475,7 @@ def _delete_deps(self) -> List[Table]:
self._member_pk = LabMember.primary_key[0]
return [LabMember, LabTeam, Session, schema.external, IntervalList]

def _get_exp_summary(self):
def _get_exp_summary(self) -> Union[QueryExpression, None]:
"""Get summary of experimenters for session(s), including NULL.
Parameters
Expand All @@ -494,9 +485,12 @@ def _get_exp_summary(self):
Returns
-------
str
Summary of experimenters for session(s).
Union[QueryExpression, None]
    dj.Union summary of experimenters for the session(s). Returns
    None if there is no link to Session.
"""
if not self._session_connection.has_link:
return None

Session = self._delete_deps[2]
SesExp = Session.Experimenter
Expand All @@ -521,8 +515,7 @@ def _session_connection(self):
"""Path from Session table to self. False if no connection found."""
from spyglass.utils.dj_graph import TableChain # noqa F401

connection = TableChain(parent=self._delete_deps[2], child=self)
return connection if connection.has_link else False
return TableChain(parent=self._delete_deps[2], child=self, verbose=True)

@cached_property
def _test_mode(self) -> bool:
Expand Down Expand Up @@ -564,7 +557,13 @@ def _check_delete_permission(self) -> None:
)
return

sess_summary = self._get_exp_summary()
if not (sess_summary := self._get_exp_summary()):
logger.warn(
f"Could not find a connection from {self.camel_name} "
+ "to Session.\n Be careful not to delete others' data."
)
return

experimenters = sess_summary.fetch(self._member_pk)
if None in experimenters:
raise PermissionError(
Expand Down

0 comments on commit ecf468e

Please sign in to comment.