Skip to content

Commit

Permalink
get-part multi-source flag
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Jul 3, 2023
1 parent e44ece3 commit 5b33b74
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions src/spyglass/utils/dj_merge_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _merge_restrict_parts(
if (
not return_empties
and isinstance(restr_str, str)
and cls()._reserved_sk in restr_str
and f"`{cls()._reserved_sk}`" in restr_str
):
parts_all = [
part
Expand Down Expand Up @@ -455,6 +455,7 @@ def merge_get_part(
restriction: str = True,
join_master: bool = False,
restrict_part=True,
multi_source=False,
) -> dj.Table:
"""Retrieve part table from a restricted Merge table.
Expand All @@ -471,11 +472,13 @@ def merge_get_part(
restrict_part: bool
Apply restriction to part. Default True. If False, return the
native part table.
multi_source: bool
Return multiple parts. Default False.
Returns
------
dj.Table
Native part table of Merge Table master
Union[dj.Table, List[dj.Table]]
Native part table(s) of Merge. If `multi_source`, returns list.
Example
-------
Expand All @@ -485,7 +488,8 @@ def merge_get_part(
Raises
------
ValueError
If multiple sources are found, lists and suggests restricting
If multiple sources are found, but not expected lists and suggests
restricting
"""
sources = [
to_camel_case(n.split("__")[-1].strip("`")) # friendly part name
Expand All @@ -497,19 +501,23 @@ def merge_get_part(
)
]

if len(sources) != 1:
if not multi_source and len(sources) != 1:
raise ValueError(
f"Found multiple potential parts: {sources}\n\t"
+ "Try adding a restriction before invoking `get_part`."
+ "Try adding a restriction before invoking `get_part`.\n\t"
+ "Or permitting multiple sources with `multi_source=True`."
)

part = (
getattr(cls, sources[0])().restrict(restriction)
parts = [
getattr(cls, source)().restrict(restriction)
if restrict_part # Re-apply restriction or don't
else getattr(cls, sources[0])()
)
else getattr(cls, source)()
for source in sources
]
if join_master:
parts = [cls * part for part in parts]

return cls * part if join_master else part
return parts if multi_source else parts[0]

@classmethod
def merge_get_parent(
Expand Down Expand Up @@ -588,7 +596,13 @@ def merge_fetch(self, restriction: str = True, *attrs, **kwargs) -> list:
# Note: this could collapse results like merge_view, but user may call
# for recarray, pd.DataFrame, or dict, and fetched contents differ if
# attrs or "KEY" called. Intercept format, merge, and then transform?

if not results:
print(
"No merge_fetch results.\n\t"
+ "If not restriction, try: `M.merge_fetch(True,'attr')\n\t"
+ "If restricting by source, use dict: "
+ "`M.merge_fetch({'source':'X'})"
)
return results[0] if len(results) == 1 else results


Expand Down

0 comments on commit 5b33b74

Please sign in to comment.