Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Periph table fallback on TableChain for experimenter summary #1035

Merged
merged 12 commits into from
Aug 27, 2024
18 changes: 3 additions & 15 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,20 +1,6 @@
# Change Log

## [0.5.3] (Unreleased)

## Release Notes

<!-- Running draft to be removed immediately prior to release. -->

```python
import datajoint as dj
from spyglass.common.common_behav import PositionIntervalMap
from spyglass.decoding.v1.core import PositionGroup

dj.schema("common_ripple").drop()
PositionIntervalMap.alter()
PositionGroup.alter()
```
## [0.5.3] (August 27, 2024)

### Infrastructure

Expand Down Expand Up @@ -46,6 +32,8 @@ PositionGroup.alter()
- Installation instructions -> Setup notebook. #1029
- Migrate SQL export tools to `utils` to support exporting `DandiPath` #1048
- Add tool for checking threads for metadata locks on a table #1063
- Use peripheral tables as fallback in `TableChains` #1035
- Ignore non-Spyglass tables during descendant check for `part_masters` #1035

### Pipelines

Expand Down
43 changes: 34 additions & 9 deletions src/spyglass/utils/dj_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def _get_ft(self, table, with_restr=False, warn=True):

return ft & restr

def _is_out(self, table, warn=True):
def _is_out(self, table, warn=True, keep_alias=False):
"""Check if table is outside of spyglass."""
table = ensure_names(table)
if self.graph.nodes.get(table):
Expand Down Expand Up @@ -805,7 +805,8 @@ class TableChain(RestrGraph):
Returns path OrderedDict of full table names in chain. If directed is
True, uses directed graph. If False, uses undirected graph. Undirected
excludes PERIPHERAL_TABLES like interval_list, nwbfile, etc. to maintain
valid joins.
valid joins by default. If no path is found, another search is attempted
with PERIPHERAL_TABLES included.
cascade(restriction: str = None, direction: str = "up")
Given a restriction at the beginning, return a restricted FreeTable
object at the end of the chain. If direction is 'up', start at the child
Expand Down Expand Up @@ -835,8 +836,12 @@ def __init__(
super().__init__(seed_table=seed_table, verbose=verbose)

self._ignore_peripheral(except_tables=[self.parent, self.child])
self._ignore_outside_spy(except_tables=[self.parent, self.child])

self.no_visit.update(ensure_names(banned_tables) or [])

self.no_visit.difference_update(set([self.parent, self.child]))

self.searched_tables = set()
self.found_restr = False
self.link_type = None
Expand Down Expand Up @@ -872,7 +877,19 @@ def _ignore_peripheral(self, except_tables: List[str] = None):
except_tables = ensure_names(except_tables)
ignore_tables = set(PERIPHERAL_TABLES) - set(except_tables or [])
self.no_visit.update(ignore_tables)
self.undirect_graph.remove_nodes_from(ignore_tables)

def _ignore_outside_spy(self, except_tables: List[str] = None):
    """Add tables outside shared Spyglass prefixes to the no-visit set.

    Any node of the undirected graph that is outside Spyglass (per
    `_is_out`) and not explicitly excepted is marked as not visitable
    during path searches.

    Parameters
    ----------
    except_tables : List[str], optional
        Tables to keep visitable even if outside Spyglass. Default None.
    """
    # Use module-level ensure_names for consistency with
    # _ignore_peripheral; guard against None so membership test is safe
    except_tables = set(ensure_names(except_tables) or [])
    ignore_tables = set(  # Ignore tables not in shared modules
        t
        for t in self.undirect_graph.nodes
        if t not in except_tables
        and self._is_out(t, warn=False, keep_alias=True)
    )
    self.no_visit.update(ignore_tables)

# --------------------------- Dunder Properties ---------------------------

Expand Down Expand Up @@ -1066,9 +1083,9 @@ def find_path(self, directed=True) -> List[str]:
List of names in the path.
"""
source, target = self.parent, self.child
search_graph = self.graph if directed else self.undirect_graph

search_graph.remove_nodes_from(self.no_visit)
search_graph = ( # Copy to ensure orig not modified by no_visit
self.graph.copy() if directed else self.undirect_graph.copy()
)

try:
path = shortest_path(search_graph, source, target)
Expand Down Expand Up @@ -1096,6 +1113,12 @@ def path(self) -> list:
self.link_type = "directed"
elif path := self.find_path(directed=False):
self.link_type = "undirected"
else: # Search with peripheral
self.no_visit.difference_update(PERIPHERAL_TABLES)
if path := self.find_path(directed=True):
self.link_type = "directed with peripheral"
elif path := self.find_path(directed=False):
self.link_type = "undirected with peripheral"
self.searched_path = True

return path
Expand Down Expand Up @@ -1126,9 +1149,11 @@ def cascade(
# Cascade will stop if any restriction is empty, so set rest to None
# This would cause issues if we want a table partway through the chain
# but that's not a typical use case, where the start and end are desired
non_numeric = [t for t in self.path if not t.isnumeric()]
if any(self._get_restr(t) is None for t in non_numeric):
for table in non_numeric:
safe_tbls = [
t for t in self.path if not t.isnumeric() and not self._is_out(t)
]
if any(self._get_restr(t) is None for t in safe_tbls):
for table in safe_tbls:
if table is not start:
self._set_restr(table, False, replace=True)

Expand Down
117 changes: 58 additions & 59 deletions src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,52 +261,41 @@ def fetch_pynapple(self, *attrs, **kwargs):

# ------------------------ delete_downstream_parts ------------------------

def _import_part_masters(self):
"""Import tables that may constrain a RestrGraph. See #1002"""
from spyglass.decoding.decoding_merge import DecodingOutput # noqa F401
from spyglass.decoding.v0.clusterless import (
UnitMarksIndicatorSelection,
) # noqa F401
from spyglass.decoding.v0.sorted_spikes import (
SortedSpikesIndicatorSelection,
) # noqa F401
from spyglass.decoding.v1.core import PositionGroup # noqa F401
from spyglass.lfp.analysis.v1 import LFPBandSelection # noqa F401
from spyglass.lfp.lfp_merge import LFPOutput # noqa F401
from spyglass.linearization.merge import ( # noqa F401
LinearizedPositionOutput,
LinearizedPositionV1,
)
from spyglass.mua.v1.mua import MuaEventsV1 # noqa F401
from spyglass.position.position_merge import PositionOutput # noqa F401
from spyglass.ripple.v1.ripple import RippleTimesV1 # noqa F401
from spyglass.spikesorting.analysis.v1.group import (
SortedSpikesGroup,
) # noqa F401
from spyglass.spikesorting.spikesorting_merge import (
SpikeSortingOutput,
) # noqa F401
from spyglass.spikesorting.v0.figurl_views import (
SpikeSortingRecordingView,
) # noqa F401

_ = (
DecodingOutput(),
LFPBandSelection(),
LFPOutput(),
LinearizedPositionOutput(),
LinearizedPositionV1(),
MuaEventsV1(),
PositionGroup(),
PositionOutput(),
RippleTimesV1(),
SortedSpikesGroup(),
SortedSpikesIndicatorSelection(),
SpikeSortingOutput(),
SpikeSortingRecordingView(),
UnitMarksIndicatorSelection(),
def load_shared_schemas(self, additional_prefixes: list = None) -> None:
    """Load shared schemas to include in graph traversal.

    Queries information_schema for every schema whose name starts with a
    shared prefix (SHARED_MODULES, the current database user, 'file',
    'sharing', plus any extras) and loads each schema's datajoint
    dependencies so they appear in the connection graph.

    Parameters
    ----------
    additional_prefixes : list, optional
        Additional prefixes to load. Default None.
    """
    all_shared = [
        *SHARED_MODULES,
        dj.config["database.user"],
        "file",
        "sharing",
    ]

    if additional_prefixes:
        all_shared.extend(additional_prefixes)

    # NOTE(review): prefixes are interpolated directly into the SQL
    # below; additional_prefixes must never come from untrusted input.
    # Get a list of all shared schemas in spyglass
    schemas = dj.conn().query(
        "SELECT DISTINCT table_schema "  # Unique schemas
        + "FROM information_schema.key_column_usage "
        + "WHERE"
        + ' table_name not LIKE "~%%"'  # Exclude hidden
        + " AND constraint_name='PRIMARY'"  # Only primary keys
        # Leading space required: without it the SQL fuses to 'PRIMARY'AND
        + " AND ("  # Only shared schemas
        + " OR ".join([f"table_schema LIKE '{s}_%%'" for s in all_shared])
        + ") "
        + "ORDER BY table_schema;"
    )

    # Load the dependencies for all shared schemas
    for schema in schemas:
        dj.schema(schema[0]).connection.dependencies.load()

@cached_property
def _part_masters(self) -> set:
"""Set of master tables downstream of self.
Expand All @@ -318,23 +307,25 @@ def _part_masters(self) -> set:
part_masters = set()

def search_descendants(parent):
for desc in parent.descendants(as_objects=True):
for desc_name in parent.descendants():
if ( # Check if has master, is part
not (master := get_master(desc.full_table_name))
# has other non-master parent
or not set(desc.parents()) - set([master])
not (master := get_master(desc_name))
or master in part_masters # already in cache
or desc_name.replace("`", "").split("_")[0]
not in SHARED_MODULES
):
continue
if master not in part_masters:
part_masters.add(master)
search_descendants(dj.FreeTable(self.connection, master))
desc = dj.FreeTable(self.connection, desc_name)
if not set(desc.parents()) - set([master]): # no other parent
continue
part_masters.add(master)
search_descendants(dj.FreeTable(self.connection, master))

try:
_ = search_descendants(self)
except NetworkXError:
try: # Attempt to import missing table
self._import_part_masters()
try: # Attempt to import failing schema
self.load_shared_schemas()
_ = search_descendants(self)
except NetworkXError as e:
table_name = "".join(e.args[0].split("`")[1:4])
Expand Down Expand Up @@ -484,7 +475,7 @@ def _delete_deps(self) -> List[Table]:
self._member_pk = LabMember.primary_key[0]
return [LabMember, LabTeam, Session, schema.external, IntervalList]

def _get_exp_summary(self):
def _get_exp_summary(self) -> Union[QueryExpression, None]:
"""Get summary of experimenters for session(s), including NULL.

Parameters
Expand All @@ -494,9 +485,12 @@ def _get_exp_summary(self):

Returns
-------
str
Summary of experimenters for session(s).
Union[QueryExpression, None]
    dj.Union object summarizing experimenters for the session(s). If
    there is no link to Session, returns None.
"""
if not self._session_connection.has_link:
return None

Session = self._delete_deps[2]
SesExp = Session.Experimenter
Expand All @@ -521,8 +515,7 @@ def _session_connection(self):
"""Path from Session table to self. False if no connection found."""
from spyglass.utils.dj_graph import TableChain # noqa F401

connection = TableChain(parent=self._delete_deps[2], child=self)
return connection if connection.has_link else False
return TableChain(parent=self._delete_deps[2], child=self, verbose=True)

@cached_property
def _test_mode(self) -> bool:
Expand Down Expand Up @@ -564,7 +557,13 @@ def _check_delete_permission(self) -> None:
)
return

sess_summary = self._get_exp_summary()
if not (sess_summary := self._get_exp_summary()):
logger.warn(
f"Could not find a connection from {self.camel_name} "
+ "to Session.\n Be careful not to delete others' data."
)
return

experimenters = sess_summary.fetch(self._member_pk)
if None in experimenters:
raise PermissionError(
Expand Down
Loading