Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #1103, #1057 #1184

Merged
merged 22 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
- Added - Datajoint python CLI ([#940](https://github.com/datajoint/datajoint-python/issues/940)) PR [#1095](https://github.com/datajoint/datajoint-python/pull/1095)
- Added - Ability to set hidden attributes on a table - PR [#1091](https://github.com/datajoint/datajoint-python/pull/1091)
- Added - Ability to specify a list of keys to popuate - PR [#989](https://github.com/datajoint/datajoint-python/pull/989)
- Fixed - fixed topological sort [#1057](https://github.com/datajoint/datajoint-python/issues/1057)- PR [#1184](https://github.com/datajoint/datajoint-python/pull/1184)
- Fixed - .parts() not always returning parts [#1103](https://github.com/datajoint/datajoint-python/issues/1103)- PR [#1184](https://github.com/datajoint/datajoint-python/pull/1184)

### 0.14.2 -- Aug 19, 2024
- Added - Migrate nosetests to pytest - PR [#1142](https://github.com/datajoint/datajoint-python/pull/1142)
Expand Down
96 changes: 64 additions & 32 deletions datajoint/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,64 @@
from .errors import DataJointError


def unite_master_parts(lst):
def extract_master(part_table):
"""
re-order a list of table names so that part tables immediately follow their master tables without breaking
the topological order.
Without this correction, a simple topological sort may insert other descendants between master and parts.
The input list must be topologically sorted.
:example:
unite_master_parts(
['`s`.`a`', '`s`.`a__q`', '`s`.`b`', '`s`.`c`', '`s`.`c__q`', '`s`.`b__q`', '`s`.`d`', '`s`.`a__r`']) ->
['`s`.`a`', '`s`.`a__q`', '`s`.`a__r`', '`s`.`b`', '`s`.`b__q`', '`s`.`c`', '`s`.`c__q`', '`s`.`d`']
given a part table name, return master part. None if not a part table
"""
for i in range(2, len(lst)):
name = lst[i]
match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", name)
if match: # name is a part table
master = match.group("master")
for j in range(i - 1, -1, -1):
if lst[j] == master + "`" or lst[j].startswith(master + "__"):
# move from the ith position to the (j+1)th position
lst[j + 1 : i + 1] = [name] + lst[j + 1 : i]
break
return lst
match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", part_table)
return match["master"] + "`" if match else None


def topo_sort(graph):
"""
topological sort of a dependency graph that keeps part tables together with their masters
:return: list of table names in topological order
"""

graph = nx.DiGraph(graph) # make a copy

# collapse alias nodes
alias_nodes = [node for node in graph if node.isdigit()]
for node in alias_nodes:
try:
direct_edge = (
next(x for x in graph.in_edges(node))[0],
next(x for x in graph.out_edges(node))[1],
)
except StopIteration:
pass # a disconnected alias node
else:
graph.add_edge(*direct_edge)
graph.remove_nodes_from(alias_nodes)

# Add parts' dependencies to their masters' dependencies
# to ensure correct topological ordering of the masters.
for part in graph:
# find the part's master
if (master := extract_master(part)) in graph:
for edge in graph.in_edges(part):
parent = edge[0]
if master not in (parent, extract_master(parent)):
# if parent is neither master nor part of master
ethho marked this conversation as resolved.
Show resolved Hide resolved
graph.add_edge(parent, master)
sorted_nodes = list(nx.topological_sort(graph))

# bring parts up to their masters
pos = len(sorted_nodes) - 1
placed = set()
while pos > 1:
part = sorted_nodes[pos]
if (master := extract_master(part)) not in graph or part in placed:
pos -= 1
else:
placed.add(part)
j = sorted_nodes.index(master)
if pos > j + 1:
# move the part to its master
del sorted_nodes[pos]
sorted_nodes.insert(j + 1, part)

return sorted_nodes


class Dependencies(nx.DiGraph):
Expand Down Expand Up @@ -131,6 +167,10 @@ def load(self, force=True):
raise DataJointError("DataJoint can only work with acyclic dependencies")
self._loaded = True

def topo_sort(self):
""":return: list of tables names in topological order"""
return topo_sort(self)

def parents(self, table_name, primary=None):
"""
:param table_name: `schema`.`table`
Expand Down Expand Up @@ -167,22 +207,14 @@ def descendants(self, full_table_name):
:return: all dependent tables sorted in topological order. Self is included.
"""
self.load(force=False)
nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name))
return unite_master_parts(
[full_table_name] + list(nx.algorithms.dag.topological_sort(nodes))
)
nodes = self.subgraph(nx.descendants(self, full_table_name))
return [full_table_name] + nodes.topo_sort()

def ancestors(self, full_table_name):
"""
:param full_table_name: In form `schema`.`table_name`
:return: all dependent tables sorted in topological order. Self is included.
"""
self.load(force=False)
nodes = self.subgraph(nx.algorithms.dag.ancestors(self, full_table_name))
return list(
reversed(
unite_master_parts(
list(nx.algorithms.dag.topological_sort(nodes)) + [full_table_name]
)
)
)
nodes = self.subgraph(nx.ancestors(self, full_table_name))
return reversed(nodes.topo_sort() + [full_table_name])
55 changes: 14 additions & 41 deletions datajoint/diagram.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import networkx as nx
import re
import functools
import io
import logging
import inspect
from .table import Table
from .dependencies import unite_master_parts
from .user_tables import Manual, Imported, Computed, Lookup, Part
from .dependencies import topo_sort
from .user_tables import Manual, Imported, Computed, Lookup, Part, _get_tier, _AliasNode
from .errors import DataJointError
from .table import lookup_class_name

Expand All @@ -27,29 +26,6 @@


logger = logging.getLogger(__name__.split(".")[0])
user_table_classes = (Manual, Lookup, Computed, Imported, Part)


class _AliasNode:
"""
special class to indicate aliased foreign keys
"""

pass


def _get_tier(table_name):
if not table_name.startswith("`"):
return _AliasNode
else:
try:
return next(
tier
for tier in user_table_classes
if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2])
)
except StopIteration:
return None


if not diagram_active:
Expand All @@ -59,8 +35,7 @@ class Diagram:
Entity relationship diagram, currently disabled due to the lack of required packages: matplotlib and pygraphviz.

To enable Diagram feature, please install both matplotlib and pygraphviz. For instructions on how to install
these two packages, refer to http://docs.datajoint.io/setup/Install-and-connect.html#python and
http://tutorials.datajoint.io/setting-up/datajoint-python.html
these two packages, refer to https://datajoint.com/docs/core/datajoint-python/0.14/client/install/
"""

def __init__(self, *args, **kwargs):
Expand All @@ -72,19 +47,22 @@ def __init__(self, *args, **kwargs):

class Diagram(nx.DiGraph):
"""
Entity relationship diagram.
Schema diagram showing tables and foreign keys between in the form of a directed
acyclic graph (DAG). The diagram is derived from the connection.dependencies object.

Usage:

>>> diag = Diagram(source)

source can be a base table object, a base table class, a schema, or a module that has a schema.
source can be a table object, a table class, a schema, or a module that has a schema.

>>> diag.draw()

draws the diagram using pyplot

diag1 + diag2 - combines the two diagrams.
diag1 - diag2 - difference between diagrams
diag1 * diag2 - intersection of diagrams
diag + n - expands n levels of successors
diag - n - expands n levels of predecessors
Thus dj.Diagram(schema.Table)+1-1 defines the diagram of immediate ancestors and descendants of schema.Table
Expand All @@ -94,6 +72,7 @@ class Diagram(nx.DiGraph):
"""

def __init__(self, source, context=None):

if isinstance(source, Diagram):
# copy constructor
self.nodes_to_show = set(source.nodes_to_show)
Expand Down Expand Up @@ -154,7 +133,7 @@ def from_sequence(cls, sequence):

def add_parts(self):
"""
Adds to the diagram the part tables of tables already included in the diagram
Adds to the diagram the part tables of all master tables already in the diagram
:return:
"""

Expand All @@ -179,16 +158,6 @@ def is_part(part, master):
)
return self

def topological_sort(self):
""":return: list of nodes in topological order"""
return unite_master_parts(
list(
nx.algorithms.dag.topological_sort(
nx.DiGraph(self).subgraph(self.nodes_to_show)
)
)
)

def __add__(self, arg):
"""
:param arg: either another Diagram or a positive integer.
Expand Down Expand Up @@ -256,6 +225,10 @@ def __mul__(self, arg):
self.nodes_to_show.intersection_update(arg.nodes_to_show)
return self

def topo_sort(self):
"""return nodes in lexicographical topological order"""
return topo_sort(self)

def _make_graph(self):
"""
Make the self.graph - a graph object ready for drawing
Expand Down
16 changes: 7 additions & 9 deletions datajoint/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@
import logging
import inspect
import re
import itertools
import collections
import itertools
from .connection import conn
from .diagram import Diagram, _get_tier
from .settings import config
from .errors import DataJointError, AccessError
from .jobs import JobTable
from .external import ExternalMapping
from .heading import Heading
from .utils import user_choice, to_camel_case
from .user_tables import Part, Computed, Imported, Manual, Lookup
from .user_tables import Part, Computed, Imported, Manual, Lookup, _get_tier
from .table import lookup_class_name, Log, FreeTable
import types

Expand Down Expand Up @@ -413,6 +412,7 @@ def save(self, python_filename=None):

:return: a string containing the body of a complete Python module defining this schema.
"""
self.connection.dependencies.load()
self._assert_exists()
module_count = itertools.count()
# add virtual modules for referenced modules with names vmod0, vmod1, ...
Expand Down Expand Up @@ -451,10 +451,8 @@ def replace(s):
).replace("\n", "\n " + indent),
)

diagram = Diagram(self)
body = "\n\n".join(
make_class_definition(table) for table in diagram.topological_sort()
)
tables = self.connection.dependencies.topo_sort()
body = "\n\n".join(make_class_definition(table) for table in tables)
python_code = "\n\n".join(
(
'"""This module was auto-generated by datajoint from an existing schema"""',
Expand All @@ -480,11 +478,12 @@ def list_tables(self):

:return: A list of table names from the database schema.
"""
self.connection.dependencies.load()
return [
t
for d, t in (
full_t.replace("`", "").split(".")
for full_t in Diagram(self).topological_sort()
for full_t in self.connection.dependencies.topo_sort()
)
if d == self.database
]
Expand Down Expand Up @@ -533,7 +532,6 @@ def __init__(

def list_schemas(connection=None):
"""

:param connection: a dj.Connection object
:return: list of all accessible schemas on the server
"""
Expand Down
13 changes: 7 additions & 6 deletions datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ def parents(self, primary=None, as_objects=False, foreign_key_info=False):

def children(self, primary=None, as_objects=False, foreign_key_info=False):
"""

:param primary: if None, then all children are returned. If True, then only foreign keys composed of
primary key attributes are considered. If False, return foreign keys including at least one
secondary attribute.
Expand All @@ -218,7 +217,6 @@ def children(self, primary=None, as_objects=False, foreign_key_info=False):

def descendants(self, as_objects=False):
"""

:param as_objects: False - a list of table names; True - a list of table objects.
:return: list of tables descendants in topological order.
"""
Expand All @@ -230,7 +228,6 @@ def descendants(self, as_objects=False):

def ancestors(self, as_objects=False):
"""

:param as_objects: False - a list of table names; True - a list of table objects.
:return: list of tables ancestors in topological order.
"""
Expand All @@ -246,6 +243,7 @@ def parts(self, as_objects=False):

:param as_objects: if False (default), the output is a dict describing the foreign keys. If True, return table objects.
"""
self.connection.dependencies.load(force=False)
nodes = [
node
for node in self.connection.dependencies.nodes
Expand Down Expand Up @@ -427,7 +425,8 @@ def insert(
self.connection.query(query)
return

field_list = [] # collects the field list from first row (passed by reference)
# collects the field list from first row (passed by reference)
field_list = []
rows = list(
self.__make_row_to_insert(row, field_list, ignore_extra_fields)
for row in rows
Expand Down Expand Up @@ -520,7 +519,8 @@ def cascade(table):
delete_count = table.delete_quick(get_count=True)
except IntegrityError as error:
match = foreign_key_error_regexp.match(error.args[0]).groupdict()
if "`.`" not in match["child"]: # if schema name missing, use table
# if schema name missing, use table
if "`.`" not in match["child"]:
match["child"] = "{}.{}".format(
table.full_table_name.split(".")[0], match["child"]
)
Expand Down Expand Up @@ -964,7 +964,8 @@ def lookup_class_name(name, context, depth=3):
while nodes:
node = nodes.pop(0)
for member_name, member in node["context"].items():
if not member_name.startswith("_"): # skip IPython's implicit variables
# skip IPython's implicit variables
if not member_name.startswith("_"):
if inspect.isclass(member) and issubclass(member, Table):
if member.full_table_name == name: # found it!
return ".".join([node["context_name"], member_name]).lstrip(".")
Expand Down
Loading
Loading