Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Feb 9, 2024
1 parent 228f375 commit b0b47cc
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 41 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
- Add overview of Spyglass to docs. #779
- Update linting for Black 24. #808
- Streamline dependency management. #822
- Add catch errorst during `populate_all_common`, log in `common_usage`. #XXX
- Add catch errors during `populate_all_common`, log in `common_usage`. #XXX
- Merge UUIDs #XXX
- Revise Merge table uuid generation to include source.
- Remove mutual exclusivity logic due to new UUID generation.

### Pipelines

Expand Down
23 changes: 15 additions & 8 deletions src/spyglass/common/common_behav.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,24 @@ class SpatialSeries(SpyglassMixin, dj.Part):
name=null: varchar(32) # name of spatial series
"""

def populate(self, key=None):
def populate(self, keys=None):
"""Insert position source data from NWB file.
WARNING: populate method on Manual table is not protected by transaction
protections like other DataJoint tables.
"""
nwb_file_name = key.get("nwb_file_name")
if not nwb_file_name:
raise ValueError(
"PositionSource.populate is an alias for a non-computed table "
+ "and must be passed a key with nwb_file_name"
)
self.insert_from_nwbfile(nwb_file_name)
if not isinstance(keys, list):
keys = [keys]
if isinstance(keys[0], dj.Table):
keys = [k for tbl in keys for k in tbl.fetch("KEY", as_dict=True)]
for key in keys:
nwb_file_name = key.get("nwb_file_name")
if not nwb_file_name:
raise ValueError(
"PositionSource.populate is an alias for a non-computed table "
+ "and must be passed a key with nwb_file_name"
)
self.insert_from_nwbfile(nwb_file_name)

@classmethod
def insert_from_nwbfile(cls, nwb_file_name):
Expand Down Expand Up @@ -496,6 +501,7 @@ def _no_transaction_make(self, key):

# Skip populating if no pos interval list names
if len(pos_intervals) == 0:
# TODO: Now that populate_all accept errors, raise here?
logger.error(f"NO POS INTERVALS FOR {key}; {no_pop_msg}")
return

Expand Down Expand Up @@ -533,6 +539,7 @@ def _no_transaction_make(self, key):

# Check that each pos interval was matched to only one epoch
if len(matching_pos_intervals) != 1:
# TODO: Now that populate_all accept errors, raise here?
logger.error(
f"Found {len(matching_pos_intervals)} pos intervals for {key}; "
+ f"{no_pop_msg}\n{matching_pos_intervals}"
Expand Down
72 changes: 40 additions & 32 deletions src/spyglass/utils/dj_merge_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,8 @@ def _proj_part(part):
return query

@classmethod
def _merge_insert(
cls, rows: list, part_name: str = None, mutual_exclusvity=True, **kwargs
) -> None:
"""Insert rows into merge, ensuring db integrity and mutual exclusivity
def _merge_insert(cls, rows: list, part_name: str = None, **kwargs) -> None:
"""Insert rows into merge, ensuring data exists in part parent(s).
Parameters
---------
Expand All @@ -291,18 +289,17 @@ def _merge_insert(
TypeError
If rows is not a list of dicts
ValueError
If entry already exists, mutual exclusivity errors
If data doesn't exist in part parents, integrity error
"""
cls._ensure_dependencies_loaded()

type_err_msg = "Input `rows` must be a list of dictionaries"
try:
for r in iter(rows):
assert isinstance(
r, dict
), 'Input "rows" must be a list of dictionaries'
if not isinstance(r, dict):
raise TypeError(type_err_msg)
except TypeError:
raise TypeError('Input "rows" must be a list of dictionaries')
raise TypeError(type_err_msg)

parts = cls._merge_restrict_parts(as_objects=True)
if part_name:
Expand All @@ -315,30 +312,24 @@ def _merge_insert(
master_entries = []
parts_entries = {p: [] for p in parts}
for row in rows:
keys = [] # empty to-be-inserted key
keys = [] # empty to-be-inserted keys
for part in parts: # check each part
part_parent = part.parents(as_objects=True)[-1]
part_name = cls._part_name(part)
part_parent = part.parents(as_objects=True)[-1]
if part_parent & row: # if row is in part parent
if keys and mutual_exclusvity: # if key from other part
raise ValueError(
"Mutual Exclusivity Error! Entry exists in more "
+ f"than one table - Entry: {row}"
)

keys = (part_parent & row).fetch("KEY") # get pk
if len(keys) > 1:
raise ValueError(
"Ambiguous entry. Data has mult rows in "
+ f"{part_name}:\n\tData:{row}\n\t{keys}"
)
master_pk = { # make uuid
cls()._reserved_pk: dj.hash.key_hash(keys[0]),
}
parts_entries[part].append({**master_pk, **keys[0]})
master_entries.append(
{**master_pk, cls()._reserved_sk: part_name}
)
key = keys[0]
master_sk = {cls()._reserved_sk: part_name}
uuid = dj.hash.key_hash(key | master_sk)
master_pk = {cls()._reserved_pk: uuid}

master_entries.append({**master_pk, **master_sk})
parts_entries[part].append({**master_pk, **key})

if not keys:
raise ValueError(
Expand Down Expand Up @@ -369,27 +360,22 @@ def _ensure_dependencies_loaded(cls) -> None:
if not dj.conn.connection.dependencies._loaded:
dj.conn.connection.dependencies.load()

def insert(self, rows: list, mutual_exclusvity=True, **kwargs):
"""Merges table specific insert
Ensuring db integrity and mutual exclusivity
def insert(self, rows: list, **kwargs):
"""Merges table specific insert, ensuring data exists in part parents.
Parameters
---------
rows: List[dict]
An iterable where an element is a dictionary.
mutual_exclusvity: bool
Check for mutual exclusivity before insert. Default True.
Raises
------
TypeError
If rows is not a list of dicts
ValueError
If entry already exists, mutual exclusivity errors
If data doesn't exist in part parents, integrity error
"""
self._merge_insert(rows, mutual_exclusvity=mutual_exclusvity, **kwargs)
self._merge_insert(rows, **kwargs)

@classmethod
def merge_view(cls, restriction: str = True):
Expand Down Expand Up @@ -586,6 +572,8 @@ def merge_get_part(
+ "Try adding a restriction before invoking `get_part`.\n\t"
+ "Or permitting multiple sources with `multi_source=True`."
)
if len(sources) == 0:
return None

parts = [
(
Expand Down Expand Up @@ -777,6 +765,26 @@ def merge_populate(source: str, key=None):
+ "part_parent `make` and then inserting all entries into Merge"
)

def delete(self, force_permission=False, *args, **kwargs):
"""Alias for cautious_delete, overwrites datajoint.table.Table.delete"""
raise NotImplementedError(
"Please use delete_downstream_merge or cautious_delete "
+ "to clear merge entries."
)
# for part in self.merge_get_part(
# restriction=self.restriction,
# multi_source=True,
# return_empties=False,
# ):
# part.delete(force_permission=force_permission, *args, **kwargs)

    def super_delete(self, *args, **kwargs):
        """Escape hatch to datajoint.table.Table.delete.

        Added to support the MRO of SpyglassMixin. Bypasses the class's
        overridden `delete` (and its cautious_delete safeguards), logging a
        warning so the bypass is visible in usage logs.
        """
        logger.warning("!! Using super_delete. Bypassing cautious_delete !!")
        super().delete(*args, **kwargs)


_Merge = Merge  # underscore alias; presumably kept for backward-compatible imports — TODO confirm

Expand Down

0 comments on commit b0b47cc

Please sign in to comment.