From cb3600bfa04e6c4b9cbb90463e325e079182b413 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 13 Sep 2024 17:45:57 -0500 Subject: [PATCH 01/21] Fix #1103, fix test_blob (for some numpy versions) --- datajoint/table.py | 12 +++++++----- tests/test_blob.py | 34 +++++++++++++++++++++++----------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/datajoint/table.py b/datajoint/table.py index 96e38082..b4daab9a 100644 --- a/datajoint/table.py +++ b/datajoint/table.py @@ -196,7 +196,6 @@ def parents(self, primary=None, as_objects=False, foreign_key_info=False): def children(self, primary=None, as_objects=False, foreign_key_info=False): """ - :param primary: if None, then all children are returned. If True, then only foreign keys composed of primary key attributes are considered. If False, return foreign keys including at least one secondary attribute. @@ -230,7 +229,6 @@ def descendants(self, as_objects=False): def ancestors(self, as_objects=False): """ - :param as_objects: False - a list of table names; True - a list of table objects. :return: list of tables ancestors in topological order. """ @@ -246,6 +244,7 @@ def parts(self, as_objects=False): :param as_objects: if False (default), the output is a dict describing the foreign keys. If True, return table objects. """ + self.connection.dependencies.load(force=False) nodes = [ node for node in self.connection.dependencies.nodes @@ -427,7 +426,8 @@ def insert( self.connection.query(query) return - field_list = [] # collects the field list from first row (passed by reference) + # collects the field list from first row (passed by reference) + field_list = [] rows = list( self.__make_row_to_insert(row, field_list, ignore_extra_fields) for row in rows @@ -520,7 +520,8 @@ def cascade(table): delete_count = table.delete_quick(get_count=True) except IntegrityError as error: match = foreign_key_error_regexp.match(error.args[0]).groupdict() - if "`.`" not in match["child"]: # if schema name missing, use table + # if schema name missing, use table + if "`.`" not in match["child"]: match["child"] = "{}.{}".format( table.full_table_name.split(".")[0], match["child"] ) @@ -962,7 +963,8 @@ def lookup_class_name(name, context, depth=3): while nodes: node = nodes.pop(0) for member_name, member in node["context"].items(): - if not member_name.startswith("_"): # skip IPython's implicit variables + # skip IPython's implicit variables + if not member_name.startswith("_"): if inspect.isclass(member) and issubclass(member, Table): if member.full_table_name == name: # found it! return ".".join([node["context_name"], member_name]).lstrip(".") diff --git a/tests/test_blob.py b/tests/test_blob.py index 12039f7f..3bce2f5d 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -124,16 +124,20 @@ def test_pack(): assert x == unpack(pack(x)), "Set did not pack/unpack correctly" x = tuple(range(10)) - assert x == unpack(pack(range(10))), "Iterator did not pack/unpack correctly" + assert x == unpack( + pack(range(10))), "Iterator did not pack/unpack correctly" x = Decimal("1.24") - assert x == approx(unpack(pack(x))), "Decimal object did not pack/unpack correctly" + assert x == approx( + unpack(pack(x))), "Decimal object did not pack/unpack correctly" x = datetime.now() - assert x == unpack(pack(x)), "Datetime object did not pack/unpack correctly" + assert x == unpack( + pack(x)), "Datetime object did not pack/unpack correctly" x = np.bool_(True) - assert x == unpack(pack(x)), "Numpy bool object did not pack/unpack correctly" + assert x == unpack( + pack(x)), "Numpy bool object did not pack/unpack correctly" x = "test" assert x == unpack(pack(x)), "String object did not pack/unpack correctly" @@ -154,13 +158,15 @@ def test_recarrays(): x = x.view(np.recarray) assert_array_equal(x, unpack(pack(x))) - x = np.array([(3, 4)], dtype=[("tmp0", float), ("tmp1", "O")]).view(np.recarray) + x = np.array([(3, 4)], dtype=[("tmp0", float), + ("tmp1", "O")]).view(np.recarray) assert_array_equal(x, unpack(pack(x))) def test_object_arrays(): x = np.array(((1, 2, 3), True), dtype="object") - assert_array_equal(x, unpack(pack(x)), "Object array did not serialize correctly") + assert_array_equal(x, unpack(pack(x)), + "Object array did not serialize correctly") def test_complex(): @@ -170,10 +176,12 @@ def test_complex(): z = np.random.randn(10) + 1j * np.random.randn(10) assert_array_equal(z, unpack(pack(z)), "Arrays do not match!") - x = np.float32(np.random.randn(3, 4, 5)) + 1j * np.float32(np.random.randn(3, 4, 5)) + x = np.float32(np.random.randn(3, 4, 5)) + 1j * \ + np.float32(np.random.randn(3, 4, 5)) assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") - x = np.int16(np.random.randn(1, 2, 3)) + 1j * np.int16(np.random.randn(1, 2, 3)) + x = np.int16(np.random.randn(1, 2, 3)) + 1j * \ + np.int16(np.random.randn(1, 2, 3)) assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") @@ -185,7 +193,8 @@ def test_insert_longblob(schema_any): query_mym_blob = {"id": 1, "data": np.array([1, 2, 3])} Longblob.insert1(query_mym_blob) - assert (Longblob & "id=1").fetch1()["data"].all() == query_mym_blob["data"].all() + assert_array_equal( + (Longblob & "id=1").fetch1()["data"], query_mym_blob["data"]) (Longblob & "id=1").delete() @@ -214,11 +223,14 @@ def test_insert_longblob_32bit(schema_any, enable_feature_32bit_dims): ) ] ], - dtype=[("hits", "O"), ("sides", "O"), ("tasks", "O"), ("stage", "O")], + dtype=[("hits", "O"), ("sides", "O"), + ("tasks", "O"), ("stage", "O")], ), } assert fetched["id"] == expected["id"] - assert np.array_equal(fetched["data"], expected["data"]) + for name in expected['data'][0][0].dtype.names: + assert_array_equal( + expected['data'][0][0][name], fetched['data'][0][0][name]) (Longblob & "id=1").delete() From a056a00bda80cc047e901cd0e0b72fa4f3c2bfc7 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 13 Sep 2024 18:00:59 -0500 Subject: [PATCH 02/21] black formatting --- tests/test_blob.py | 35 ++++++++++++----------------------- 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/tests/test_blob.py b/tests/test_blob.py index 3bce2f5d..db03c687 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -124,20 +124,16 @@ def test_pack(): assert x == unpack(pack(x)), "Set did not pack/unpack correctly" x = tuple(range(10)) - assert x == unpack( - pack(range(10))), "Iterator did not pack/unpack correctly" + assert x == unpack(pack(range(10))), "Iterator did not pack/unpack correctly" x = Decimal("1.24") - assert x == approx( - unpack(pack(x))), "Decimal object did not pack/unpack correctly" + assert x == approx(unpack(pack(x))), "Decimal object did not pack/unpack correctly" x = datetime.now() - assert x == unpack( - pack(x)), "Datetime object did not pack/unpack correctly" + assert x == unpack(pack(x)), "Datetime object did not pack/unpack correctly" x = np.bool_(True) - assert x == unpack( - pack(x)), "Numpy bool object did not pack/unpack correctly" + assert x == unpack(pack(x)), "Numpy bool object did not pack/unpack correctly" x = "test" assert x == unpack(pack(x)), "String object did not pack/unpack correctly" @@ -158,15 +154,13 @@ def test_recarrays(): x = x.view(np.recarray) assert_array_equal(x, unpack(pack(x))) - x = np.array([(3, 4)], dtype=[("tmp0", float), - ("tmp1", "O")]).view(np.recarray) + x = np.array([(3, 4)], dtype=[("tmp0", float), ("tmp1", "O")]).view(np.recarray) assert_array_equal(x, unpack(pack(x))) def test_object_arrays(): x = np.array(((1, 2, 3), True), dtype="object") - assert_array_equal(x, unpack(pack(x)), - "Object array did not serialize correctly") + assert_array_equal(x, unpack(pack(x)), "Object array did not serialize correctly") def test_complex(): @@ -176,12 +170,10 @@ def test_complex(): z = np.random.randn(10) + 1j * np.random.randn(10) assert_array_equal(z, unpack(pack(z)), "Arrays do not match!") - x = np.float32(np.random.randn(3, 4, 5)) + 1j * \ - np.float32(np.random.randn(3, 4, 5)) + x = np.float32(np.random.randn(3, 4, 5)) + 1j * np.float32(np.random.randn(3, 4, 5)) assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") - x = np.int16(np.random.randn(1, 2, 3)) + 1j * \ - np.int16(np.random.randn(1, 2, 3)) + x = np.int16(np.random.randn(1, 2, 3)) + 1j * np.int16(np.random.randn(1, 2, 3)) assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") @@ -193,8 +185,7 @@ def test_insert_longblob(schema_any): query_mym_blob = {"id": 1, "data": np.array([1, 2, 3])} Longblob.insert1(query_mym_blob) - assert_array_equal( - (Longblob & "id=1").fetch1()["data"], query_mym_blob["data"]) + assert_array_equal((Longblob & "id=1").fetch1()["data"], query_mym_blob["data"]) (Longblob & "id=1").delete() @@ -223,14 +214,12 @@ def test_insert_longblob_32bit(schema_any, enable_feature_32bit_dims): ) ] ], - dtype=[("hits", "O"), ("sides", "O"), - ("tasks", "O"), ("stage", "O")], + dtype=[("hits", "O"), ("sides", "O"), ("tasks", "O"), ("stage", "O")], ), } assert fetched["id"] == expected["id"] - for name in expected['data'][0][0].dtype.names: - assert_array_equal( - expected['data'][0][0][name], fetched['data'][0][0][name]) + for name in expected["data"][0][0].dtype.names: + assert_array_equal(expected["data"][0][0][name], fetched["data"][0][0][name]) (Longblob & "id=1").delete() From 8985877ac9fa56c1f798b57a9e90ee21f1d31bf4 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 13 Sep 2024 18:16:47 -0500 Subject: [PATCH 03/21] update CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c4e7b620..0978cf65 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - Added - Datajoint python CLI ([#940](https://github.com/datajoint/datajoint-python/issues/940)) PR [#1095](https://github.com/datajoint/datajoint-python/pull/1095) - Added - Ability to set hidden attributes on a table - PR [#1091](https://github.com/datajoint/datajoint-python/pull/1091) - Added - Ability to specify a list of keys to popuate - PR [#989](https://github.com/datajoint/datajoint-python/pull/989) +- Fixed - .parts() not always returning parts [#1103](https://github.com/datajoint/datajoint-python/issues/1103)- PR [#1184](https://github.com/datajoint/datajoint-python/pull/1184) ### 0.14.2 -- Aug 19, 2024 - Added - Migrate nosetests to pytest - PR [#1142](https://github.com/datajoint/datajoint-python/pull/1142) From 87c6884da7032c15518d8dd7f64c77b0ea13a8fc Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 13 Sep 2024 19:03:06 -0500 Subject: [PATCH 04/21] fix #1057 --- datajoint/dependencies.py | 32 ++------------------------ datajoint/diagram.py | 12 ++++------ tests/test_dependencies.py | 46 -------------------------------------- 3 files changed, 6 insertions(+), 84 deletions(-) diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index d9c425d4..84fd594b 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -5,30 +5,6 @@ from .errors import DataJointError -def unite_master_parts(lst): - """ - re-order a list of table names so that part tables immediately follow their master tables without breaking - the topological order. - Without this correction, a simple topological sort may insert other descendants between master and parts. - The input list must be topologically sorted. - :example: - unite_master_parts( - ['`s`.`a`', '`s`.`a__q`', '`s`.`b`', '`s`.`c`', '`s`.`c__q`', '`s`.`b__q`', '`s`.`d`', '`s`.`a__r`']) -> - ['`s`.`a`', '`s`.`a__q`', '`s`.`a__r`', '`s`.`b`', '`s`.`b__q`', '`s`.`c`', '`s`.`c__q`', '`s`.`d`'] - """ - for i in range(2, len(lst)): - name = lst[i] - match = re.match(r"(?P`\w+`.`#?\w+)__\w+`", name) - if match: # name is a part table - master = match.group("master") - for j in range(i - 1, -1, -1): - if lst[j] == master + "`" or lst[j].startswith(master + "__"): - # move from the ith position to the (j+1)th position - lst[j + 1 : i + 1] = [name] + lst[j + 1 : i] - break - return lst - - class Dependencies(nx.DiGraph): """ The graph of dependencies (foreign keys) between loaded tables. @@ -168,9 +144,7 @@ def descendants(self, full_table_name): """ self.load(force=False) nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name)) - return unite_master_parts( - [full_table_name] + list(nx.algorithms.dag.topological_sort(nodes)) - ) + return [full_table_name] + list(nx.algorithms.dag.topological_sort(nodes)) def ancestors(self, full_table_name): """ @@ -181,8 +155,6 @@ def ancestors(self, full_table_name): nodes = self.subgraph(nx.algorithms.dag.ancestors(self, full_table_name)) return list( reversed( - unite_master_parts( - list(nx.algorithms.dag.topological_sort(nodes)) + [full_table_name] - ) + list(nx.algorithms.dag.topological_sort(nodes)) + [full_table_name] ) ) diff --git a/datajoint/diagram.py b/datajoint/diagram.py index 7f47f746..0136ccaf 100644 --- a/datajoint/diagram.py +++ b/datajoint/diagram.py @@ -5,7 +5,6 @@ import logging import inspect from .table import Table -from .dependencies import unite_master_parts from .user_tables import Manual, Imported, Computed, Lookup, Part from .errors import DataJointError from .table import lookup_class_name @@ -59,8 +58,7 @@ class Diagram: Entity relationship diagram, currently disabled due to the lack of required packages: matplotlib and pygraphviz. To enable Diagram feature, please install both matplotlib and pygraphviz. For instructions on how to install - these two packages, refer to http://docs.datajoint.io/setup/Install-and-connect.html#python and - http://tutorials.datajoint.io/setting-up/datajoint-python.html + these two packages, refer to https://datajoint.com/docs/core/datajoint-python/0.14/client/install/ """ def __init__(self, *args, **kwargs): @@ -181,11 +179,9 @@ def is_part(part, master): def topological_sort(self): """:return: list of nodes in topological order""" - return unite_master_parts( - list( - nx.algorithms.dag.topological_sort( - nx.DiGraph(self).subgraph(self.nodes_to_show) - ) + return list( + nx.algorithms.dag.topological_sort( + nx.DiGraph(self).subgraph(self.nodes_to_show) ) ) diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py index 987acc6c..5a4acd7d 100644 --- a/tests/test_dependencies.py +++ b/tests/test_dependencies.py @@ -1,51 +1,5 @@ from datajoint import errors from pytest import raises -from datajoint.dependencies import unite_master_parts - - -def test_unite_master_parts(): - assert unite_master_parts( - [ - "`s`.`a`", - "`s`.`a__q`", - "`s`.`b`", - "`s`.`c`", - "`s`.`c__q`", - "`s`.`b__q`", - "`s`.`d`", - "`s`.`a__r`", - ] - ) == [ - "`s`.`a`", - "`s`.`a__q`", - "`s`.`a__r`", - "`s`.`b`", - "`s`.`b__q`", - "`s`.`c`", - "`s`.`c__q`", - "`s`.`d`", - ] - assert unite_master_parts( - [ - "`lab`.`#equipment`", - "`cells`.`cell_analysis_method`", - "`cells`.`cell_analysis_method_task_type`", - "`cells`.`cell_analysis_method_users`", - "`cells`.`favorite_selection`", - "`cells`.`cell_analysis_method__cell_selection_params`", - "`lab`.`#equipment__config`", - "`cells`.`cell_analysis_method__field_detect_params`", - ] - ) == [ - "`lab`.`#equipment`", - "`lab`.`#equipment__config`", - "`cells`.`cell_analysis_method`", - "`cells`.`cell_analysis_method__cell_selection_params`", - "`cells`.`cell_analysis_method__field_detect_params`", - "`cells`.`cell_analysis_method_task_type`", - "`cells`.`cell_analysis_method_users`", - "`cells`.`favorite_selection`", - ] def test_nullable_dependency(thing_tables): From 207e6e98478cdd62c0108d1d75a61066b2f3e9a7 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 13 Sep 2024 19:05:52 -0500 Subject: [PATCH 05/21] update CHANGELOG --- CHANGELOG.md | 1 + datajoint/dependencies.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0978cf65..4d90eea0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - Added - Datajoint python CLI ([#940](https://github.com/datajoint/datajoint-python/issues/940)) PR [#1095](https://github.com/datajoint/datajoint-python/pull/1095) - Added - Ability to set hidden attributes on a table - PR [#1091](https://github.com/datajoint/datajoint-python/pull/1091) - Added - Ability to specify a list of keys to popuate - PR [#989](https://github.com/datajoint/datajoint-python/pull/989) +- Fixed - fixed topological sort [#1057](https://github.com/datajoint/datajoint-python/issues/1057)- PR [#1184](https://github.com/datajoint/datajoint-python/pull/1184) - Fixed - .parts() not always returning parts [#1103](https://github.com/datajoint/datajoint-python/issues/1103)- PR [#1184](https://github.com/datajoint/datajoint-python/pull/1184) ### 0.14.2 -- Aug 19, 2024 diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index 84fd594b..06c380b5 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -1,6 +1,5 @@ import networkx as nx import itertools -import re from collections import defaultdict from .errors import DataJointError From 4cc712d4e62565aa84d65d4f7538d9071cb62524 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 13 Sep 2024 19:16:19 -0500 Subject: [PATCH 06/21] removed the functionality for schema.code and schema.save() --- datajoint/schemas.py | 71 -------------------------------------------- 1 file changed, 71 deletions(-) diff --git a/datajoint/schemas.py b/datajoint/schemas.py index 62f45fa6..9abca14a 100644 --- a/datajoint/schemas.py +++ b/datajoint/schemas.py @@ -401,77 +401,6 @@ def jobs(self): self._jobs = JobTable(self.connection, self.database) return self._jobs - @property - def code(self): - self._assert_exists() - return self.save() - - def save(self, python_filename=None): - """ - Generate the code for a module that recreates the schema. - This method is in preparation for a future release and is not officially supported. - - :return: a string containing the body of a complete Python module defining this schema. - """ - self._assert_exists() - module_count = itertools.count() - # add virtual modules for referenced modules with names vmod0, vmod1, ... - module_lookup = collections.defaultdict( - lambda: "vmod" + str(next(module_count)) - ) - db = self.database - - def make_class_definition(table): - tier = _get_tier(table).__name__ - class_name = table.split(".")[1].strip("`") - indent = "" - if tier == "Part": - class_name = class_name.split("__")[-1] - indent += " " - class_name = to_camel_case(class_name) - - def replace(s): - d, tabs = s.group(1), s.group(2) - return ("" if d == db else (module_lookup[d] + ".")) + ".".join( - to_camel_case(tab) for tab in tabs.lstrip("__").split("__") - ) - - return ("" if tier == "Part" else "\n@schema\n") + ( - "{indent}class {class_name}(dj.{tier}):\n" - '{indent} definition = """\n' - '{indent} {defi}"""' - ).format( - class_name=class_name, - indent=indent, - tier=tier, - defi=re.sub( - r"`([^`]+)`.`([^`]+)`", - replace, - FreeTable(self.connection, table).describe(), - ).replace("\n", "\n " + indent), - ) - - diagram = Diagram(self) - body = "\n\n".join( - make_class_definition(table) for table in diagram.topological_sort() - ) - python_code = "\n\n".join( - ( - '"""This module was auto-generated by datajoint from an existing schema"""', - "import datajoint as dj\n\nschema = dj.Schema('{db}')".format(db=db), - "\n".join( - "{module} = dj.VirtualModule('{module}', '{schema_name}')".format( - module=v, schema_name=k - ) - for k, v in module_lookup.items() - ), - body, - ) - ) - if python_filename is None: - return python_code - with open(python_filename, "wt") as f: - f.write(python_code) def list_tables(self): """ From 24fe65d4110744930a4e0a21054977617da46cd0 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 13 Sep 2024 19:17:43 -0500 Subject: [PATCH 07/21] formatting --- datajoint/schemas.py | 1 - 1 file changed, 1 deletion(-) diff --git a/datajoint/schemas.py b/datajoint/schemas.py index 9abca14a..9607870d 100644 --- a/datajoint/schemas.py +++ b/datajoint/schemas.py @@ -401,7 +401,6 @@ def jobs(self): self._jobs = JobTable(self.connection, self.database) return self._jobs - def list_tables(self): """ Return a list of all tables in the schema except tables with ~ in first character such From c56294695b073bce4776dd2e6422edb2e0e512af Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 13 Sep 2024 19:24:09 -0500 Subject: [PATCH 08/21] lint fix --- datajoint/schemas.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/datajoint/schemas.py b/datajoint/schemas.py index 9607870d..25c3f4b4 100644 --- a/datajoint/schemas.py +++ b/datajoint/schemas.py @@ -2,10 +2,8 @@ import logging import inspect import re -import itertools -import collections from .connection import conn -from .diagram import Diagram, _get_tier +from .diagram import Diagram from .settings import config from .errors import DataJointError, AccessError from .jobs import JobTable @@ -13,7 +11,7 @@ from .heading import Heading from .utils import user_choice, to_camel_case from .user_tables import Part, Computed, Imported, Manual, Lookup -from .table import lookup_class_name, Log, FreeTable +from .table import lookup_class_name, Log import types logger = logging.getLogger(__name__.split(".")[0]) From 13d7208119e335bb5a71b3accf12279c7e8ec6f3 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 13 Sep 2024 19:37:05 -0500 Subject: [PATCH 09/21] reduce the speedup factor in blob serialization times --- tests/test_blob.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_blob.py b/tests/test_blob.py index db03c687..6c5a6f5a 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -249,4 +249,5 @@ def test_datetime_serialization_speed(): ) print(f"python time {baseline_exe_time}") - assert optimized_exe_time * 900 < baseline_exe_time + # The time savings were much greater (x1000) but use x10 for testing + assert optimized_exe_time * 10 < baseline_exe_time From 6719f4a12e1694fcb6276689be9056b9089b09ca Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 13 Sep 2024 19:52:13 -0500 Subject: [PATCH 10/21] remove tests for schema code generation --- tests/test_schema.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test_schema.py b/tests/test_schema.py index 257de221..857c1474 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -218,14 +218,6 @@ def test_list_tables(schema_simp): assert actual == expected, f"Missing from list_tables(): {expected - actual}" -def test_schema_save_any(schema_any): - assert "class Experiment(dj.Imported)" in schema_any.code - - -def test_schema_save_empty(schema_empty): - assert "class Experiment(dj.Imported)" in schema_empty.code - - def test_uppercase_schema(db_creds_root): """ https://github.com/datajoint/datajoint-python/issues/564 From 38d88135877ce1992e79b82551cf043b0eaa1a5f Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 13 Sep 2024 22:38:27 -0500 Subject: [PATCH 11/21] add schema.code back --- datajoint/diagram.py | 4 +-- datajoint/schemas.py | 78 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 78 insertions(+), 4 deletions(-) diff --git a/datajoint/diagram.py b/datajoint/diagram.py index 0136ccaf..884c01bf 100644 --- a/datajoint/diagram.py +++ b/datajoint/diagram.py @@ -178,9 +178,9 @@ def is_part(part, master): return self def topological_sort(self): - """:return: list of nodes in topological order""" + """:return: list of nodes in lexcigraphical topological order""" return list( - nx.algorithms.dag.topological_sort( + nx.algorithms.dag.lexicographical_topological_sort( nx.DiGraph(self).subgraph(self.nodes_to_show) ) ) diff --git a/datajoint/schemas.py b/datajoint/schemas.py index 25c3f4b4..e1212221 100644 --- a/datajoint/schemas.py +++ b/datajoint/schemas.py @@ -2,8 +2,10 @@ import logging import inspect import re +import collections +import itertools from .connection import conn -from .diagram import Diagram +from .diagram import Diagram, _get_tier from .settings import config from .errors import DataJointError, AccessError from .jobs import JobTable @@ -11,7 +13,7 @@ from .heading import Heading from .utils import user_choice, to_camel_case from .user_tables import Part, Computed, Imported, Manual, Lookup -from .table import lookup_class_name, Log +from .table import lookup_class_name, Log, FreeTable import types logger = logging.getLogger(__name__.split(".")[0]) @@ -399,6 +401,78 @@ def jobs(self): self._jobs = JobTable(self.connection, self.database) return self._jobs + @property + def code(self): + self._assert_exists() + return self.save() + + def save(self, python_filename=None): + """ + Generate the code for a module that recreates the schema. + This method is in preparation for a future release and is not officially supported. + + :return: a string containing the body of a complete Python module defining this schema. + """ + self._assert_exists() + module_count = itertools.count() + # add virtual modules for referenced modules with names vmod0, vmod1, ... + module_lookup = collections.defaultdict( + lambda: "vmod" + str(next(module_count)) + ) + db = self.database + + def make_class_definition(table): + tier = _get_tier(table).__name__ + class_name = table.split(".")[1].strip("`") + indent = "" + if tier == "Part": + class_name = class_name.split("__")[-1] + indent += " " + class_name = to_camel_case(class_name) + + def replace(s): + d, tabs = s.group(1), s.group(2) + return ("" if d == db else (module_lookup[d] + ".")) + ".".join( + to_camel_case(tab) for tab in tabs.lstrip("__").split("__") + ) + + return ("" if tier == "Part" else "\n@schema\n") + ( + "{indent}class {class_name}(dj.{tier}):\n" + '{indent} definition = """\n' + '{indent} {defi}"""' + ).format( + class_name=class_name, + indent=indent, + tier=tier, + defi=re.sub( + r"`([^`]+)`.`([^`]+)`", + replace, + FreeTable(self.connection, table).describe(), + ).replace("\n", "\n " + indent), + ) + + diagram = Diagram(self) + body = "\n\n".join( + make_class_definition(table) for table in diagram.topological_sort() + ) + python_code = "\n\n".join( + ( + '"""This module was auto-generated by datajoint from an existing schema"""', + "import datajoint as dj\n\nschema = dj.Schema('{db}')".format(db=db), + "\n".join( + "{module} = dj.VirtualModule('{module}', '{schema_name}')".format( + module=v, schema_name=k + ) + for k, v in module_lookup.items() + ), + body, + ) + ) + if python_filename is None: + return python_code + with open(python_filename, "wt") as f: + f.write(python_code) + def list_tables(self): """ Return a list of all tables in the schema except tables with ~ in first character such From a470d66aeed1abf85779e907511b28eb10f7bba0 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 13 Sep 2024 23:51:56 -0500 Subject: [PATCH 12/21] optimize, fix topological sort. --- datajoint/dependencies.py | 16 ++++++++-------- datajoint/diagram.py | 8 -------- datajoint/schemas.py | 4 ++-- datajoint/table.py | 1 - tests/test_cli.py | 1 - tests/test_schema.py | 7 +++++++ 6 files changed, 17 insertions(+), 20 deletions(-) diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index 06c380b5..cb0fdbd4 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -106,6 +106,10 @@ def load(self, force=True): raise DataJointError("DataJoint can only work with acyclic dependencies") self._loaded = True + def topo_sort(self): + """:return: list of nodes in lexcigraphical topological order""" + return list(nx.algorithms.dag.lexicographical_topological_sort(self)) + def parents(self, table_name, primary=None): """ :param table_name: `schema`.`table` @@ -142,8 +146,8 @@ def descendants(self, full_table_name): :return: all dependent tables sorted in topological order. Self is included. """ self.load(force=False) - nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name)) - return [full_table_name] + list(nx.algorithms.dag.topological_sort(nodes)) + nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name)).copy() + return [full_table_name] + nodes.topo_sort() def ancestors(self, full_table_name): """ @@ -151,9 +155,5 @@ def ancestors(self, full_table_name): :return: all dependent tables sorted in topological order. Self is included. """ self.load(force=False) - nodes = self.subgraph(nx.algorithms.dag.ancestors(self, full_table_name)) - return list( - reversed( - list(nx.algorithms.dag.topological_sort(nodes)) + [full_table_name] - ) - ) + nodes = self.subgraph(nx.algorithms.dag.ancestors(self, full_table_name)).copy() + return reversed(nodes.topo_sort() + [full_table_name]) diff --git a/datajoint/diagram.py b/datajoint/diagram.py index 884c01bf..6ed3824b 100644 --- a/datajoint/diagram.py +++ b/datajoint/diagram.py @@ -177,14 +177,6 @@ def is_part(part, master): ) return self - def topological_sort(self): - """:return: list of nodes in lexcigraphical topological order""" - return list( - nx.algorithms.dag.lexicographical_topological_sort( - nx.DiGraph(self).subgraph(self.nodes_to_show) - ) - ) - def __add__(self, arg): """ :param arg: either another Diagram or a positive integer. diff --git a/datajoint/schemas.py b/datajoint/schemas.py index e1212221..650634b8 100644 --- a/datajoint/schemas.py +++ b/datajoint/schemas.py @@ -453,7 +453,7 @@ def replace(s): diagram = Diagram(self) body = "\n\n".join( - make_class_definition(table) for table in diagram.topological_sort() + make_class_definition(table) for table in diagram.topo_sort() ) python_code = "\n\n".join( ( @@ -484,7 +484,7 @@ def list_tables(self): t for d, t in ( full_t.replace("`", "").split(".") - for full_t in Diagram(self).topological_sort() + for full_t in Diagram(self).topo_sort() ) if d == self.database ] diff --git a/datajoint/table.py b/datajoint/table.py index a597956e..db9eaffa 100644 --- a/datajoint/table.py +++ b/datajoint/table.py @@ -217,7 +217,6 @@ def children(self, primary=None, as_objects=False, foreign_key_info=False): def descendants(self, as_objects=False): """ - :param as_objects: False - a list of table names; True - a list of table objects. :return: list of tables descendants in topological order. """ diff --git a/tests/test_cli.py b/tests/test_cli.py index 3f0fd00c..29fedf22 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -3,7 +3,6 @@ """ import json -import ast import subprocess import pytest import datajoint as dj diff --git a/tests/test_schema.py b/tests/test_schema.py index 857c1474..e44ac6ad 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -217,6 +217,13 @@ def test_list_tables(schema_simp): actual = set(schema_simp.list_tables()) assert actual == expected, f"Missing from list_tables(): {expected - actual}" +def test_schema_save_any(schema_any): + assert "class Experiment(dj.Imported)" in schema_any.code + + +def test_schema_save_empty(schema_empty): + assert "class Experiment(dj.Imported)" in schema_empty.code + def test_uppercase_schema(db_creds_root): """ From 4be8e39727a4cda1acee76c311957c433f43239f Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Sun, 15 Sep 2024 16:39:20 -0500 Subject: [PATCH 13/21] fix topological sort --- datajoint/dependencies.py | 64 +++++++++++++++++++++++++++++++++++++-- datajoint/diagram.py | 18 ++++++++--- 2 files changed, 75 insertions(+), 7 deletions(-) diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index cb0fdbd4..a9df0c4f 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -1,9 +1,65 @@ import networkx as nx import itertools +import re from collections import defaultdict from .errors import DataJointError +def topo_sort(graph): + """ + topological sort of a dependency graph that keeps part tables together with their masters + :return: list of table names in topological order + """ + graph = nx.DiGraph(graph) # make a copy + + # collapse alias nodes + alias_nodes = [node for node in graph if node.isdigit()] + for node in alias_nodes: + direct_edge = ( + next(x for x in graph.in_edges(node))[0], + next(x for x in graph.out_edges(node))[1], + ) + graph.add_edge(*direct_edge) + graph.remove_nodes_from(alias_nodes) + + # Add parts' dependencies to their masters' dependencies + # to ensure correct topological ordering of the masters. + part_pattern = re.compile(r"(?P`\w+`.`#?\w+)__\w+`") + for part in graph: + # print part tables and their master + match = part_pattern.match(part) + if match: + master = match["master"] + "`" + for edge in graph.in_edges(part): + if edge[0] != master: + graph.add_edge(edge[0], master) + + sorted_nodes = list(nx.algorithms.topological_sort(graph)) + + # bring parts up to their masters + pos = len(sorted_nodes) + while pos > 0: + pos -= 1 + part = sorted_nodes[pos] + match = part_pattern.match(part) + if match: + master = match["master"] + "`" + print(part, master) + try: + j = sorted_nodes.index(master) + except ValueError: + # master not found + continue + if pos > j + 1: + print(pos, j) + # move the part to its master + del sorted_nodes[pos] + sorted_nodes.insert(j + 1, part) + pos += 1 + + return sorted_nodes + + class Dependencies(nx.DiGraph): """ The graph of dependencies (foreign keys) between loaded tables. @@ -107,8 +163,8 @@ def load(self, force=True): self._loaded = True def topo_sort(self): - """:return: list of nodes in lexcigraphical topological order""" - return list(nx.algorithms.dag.lexicographical_topological_sort(self)) + """:return: list of tables names in topological order""" + return topo_sort(self) def parents(self, table_name, primary=None): """ @@ -146,7 +202,9 @@ def descendants(self, full_table_name): :return: all dependent tables sorted in topological order. Self is included. """ self.load(force=False) - nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name)).copy() + nodes = self.subgraph( + nx.algorithms.dag.descendants(self, full_table_name) + ).copy() return [full_table_name] + nodes.topo_sort() def ancestors(self, full_table_name): diff --git a/datajoint/diagram.py b/datajoint/diagram.py index 6ed3824b..0f8717e4 100644 --- a/datajoint/diagram.py +++ b/datajoint/diagram.py @@ -5,6 +5,7 @@ import logging import inspect from .table import Table +from .dependencies import topo_sort from .user_tables import Manual, Imported, Computed, Lookup, Part from .errors import DataJointError from .table import lookup_class_name @@ -38,6 +39,7 @@ class _AliasNode: def _get_tier(table_name): + """given the table name, return""" if not table_name.startswith("`"): return _AliasNode else: @@ -70,19 +72,22 @@ def __init__(self, *args, **kwargs): class Diagram(nx.DiGraph): """ - Entity relationship diagram. + Schema diagram showing tables and foreign keys between in the form of a directed + acyclic graph (DAG). The diagram is derived from the connection.dependencies object. Usage: >>> diag = Diagram(source) - source can be a base table object, a base table class, a schema, or a module that has a schema. + source can be a table object, a table class, a schema, or a module that has a schema. >>> diag.draw() draws the diagram using pyplot diag1 + diag2 - combines the two diagrams. + diag1 - diag2 - differente between diagrams + diag1 * diag2 - intersction of diagrams diag + n - expands n levels of successors diag - n - expands n levels of predecessors Thus dj.Diagram(schema.Table)+1-1 defines the diagram of immediate ancestors and descendants of schema.Table @@ -91,7 +96,8 @@ class Diagram(nx.DiGraph): Only those tables that are loaded in the connection object are displayed """ - def __init__(self, source, context=None): + def __init__(self, source=None, context=None): + if isinstance(source, Diagram): # copy constructor self.nodes_to_show = set(source.nodes_to_show) @@ -152,7 +158,7 @@ def from_sequence(cls, sequence): def add_parts(self): """ - Adds to the diagram the part tables of tables already included in the diagram + Adds to the diagram the part tables of all master tables already in the diagram :return: """ @@ -244,6 +250,10 @@ def __mul__(self, arg): self.nodes_to_show.intersection_update(arg.nodes_to_show) return self + def topo_sort(self): + """return nodes in lexicographical topological order""" + return topo_sort(self) + def _make_graph(self): """ Make the self.graph - a graph object ready for drawing From adfdc653c2c4f40e14665598658e2b308bd1a281 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Sun, 15 Sep 2024 18:17:12 -0500 Subject: [PATCH 14/21] fix topological sort --- datajoint/dependencies.py | 74 ++++++++++++++++++++++----------------- datajoint/diagram.py | 26 +------------- datajoint/schemas.py | 9 ++--- datajoint/user_tables.py | 27 ++++++++++++++ 4 files changed, 73 insertions(+), 63 deletions(-) diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index a9df0c4f..4ad58527 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -4,58 +4,70 @@ from collections import defaultdict from .errors import DataJointError +def extract_master(part_table): + """ + given a part table name, return master part. None if not a part table + """ + match = re.match(r"(?P`\w+`.`#?\w+)__\w+`", part_table) + return match['master'] + '`' if match else None + + def topo_sort(graph): """ topological sort of a dependency graph that keeps part tables together with their masters :return: list of table names in topological order """ + graph = nx.DiGraph(graph) # make a copy # collapse alias nodes alias_nodes = [node for node in graph if node.isdigit()] for node in alias_nodes: - direct_edge = ( - next(x for x in graph.in_edges(node))[0], - next(x for x in graph.out_edges(node))[1], - ) - graph.add_edge(*direct_edge) + try: + direct_edge = ( + next(x for x in graph.in_edges(node))[0], + next(x for x in graph.out_edges(node))[1], + ) + except StopIteration: + pass # a disconnected alias node + else: + graph.add_edge(*direct_edge) graph.remove_nodes_from(alias_nodes) # Add parts' dependencies to their masters' dependencies # to ensure correct topological ordering of the masters. - part_pattern = re.compile(r"(?P`\w+`.`#?\w+)__\w+`") for part in graph: - # print part tables and their master - match = part_pattern.match(part) - if match: - master = match["master"] + "`" + # find the part's master + master = extract_master(part) + if master: for edge in graph.in_edges(part): - if edge[0] != master: - graph.add_edge(edge[0], master) + parent = edge[0] + if parent != master and extract_master(parent) != master: + graph.add_edge(parent, master) - sorted_nodes = list(nx.algorithms.topological_sort(graph)) + sorted_nodes = list(nx.topological_sort(graph)) # bring parts up to their masters - pos = len(sorted_nodes) - while pos > 0: - pos -= 1 + pos = len(sorted_nodes) - 1 + placed = set() + while pos > 1: part = sorted_nodes[pos] - match = part_pattern.match(part) - if match: - master = match["master"] + "`" - print(part, master) + master = extract_master(part) + if not master or part in placed: + pos -= 1 + else: + placed.add(part) try: j = sorted_nodes.index(master) except ValueError: # master not found - continue - if pos > j + 1: - print(pos, j) - # move the part to its master - del sorted_nodes[pos] - sorted_nodes.insert(j + 1, part) - pos += 1 + pass + else: + if pos > j + 1: + # move the part to its master + del sorted_nodes[pos] + sorted_nodes.insert(j + 1, part) return sorted_nodes @@ -202,10 +214,8 @@ def descendants(self, full_table_name): :return: all dependent tables sorted in topological order. Self is included. """ self.load(force=False) - nodes = self.subgraph( - nx.algorithms.dag.descendants(self, full_table_name) - ).copy() - return [full_table_name] + nodes.topo_sort() + nodes = self.subgraph(nx.descendants(self, full_table_name)) + return [full_table_name] + nodes.topo_sort() def ancestors(self, full_table_name): """ @@ -213,5 +223,5 @@ def ancestors(self, full_table_name): :return: all dependent tables sorted in topological order. Self is included. """ self.load(force=False) - nodes = self.subgraph(nx.algorithms.dag.ancestors(self, full_table_name)).copy() + nodes = self.subgraph(nx.ancestors(self, full_table_name)) return reversed(nodes.topo_sort() + [full_table_name]) diff --git a/datajoint/diagram.py b/datajoint/diagram.py index 0f8717e4..ca1df82b 100644 --- a/datajoint/diagram.py +++ b/datajoint/diagram.py @@ -6,7 +6,7 @@ import inspect from .table import Table from .dependencies import topo_sort -from .user_tables import Manual, Imported, Computed, Lookup, Part +from .user_tables import Manual, Imported, Computed, Lookup, Part, _get_tier, _AliasNode from .errors import DataJointError from .table import lookup_class_name @@ -27,30 +27,6 @@ logger = logging.getLogger(__name__.split(".")[0]) -user_table_classes = (Manual, Lookup, Computed, Imported, Part) - - -class _AliasNode: - """ - special class to indicate aliased foreign keys - """ - - pass - - -def _get_tier(table_name): - """given the table name, return""" - if not table_name.startswith("`"): - return _AliasNode - else: - try: - return next( - tier - for tier in user_table_classes - if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2]) - ) - except StopIteration: - return None if not diagram_active: diff --git a/datajoint/schemas.py b/datajoint/schemas.py index 650634b8..7545f828 100644 --- a/datajoint/schemas.py +++ b/datajoint/schemas.py @@ -5,14 +5,13 @@ import collections import itertools from .connection import conn -from .diagram import Diagram, _get_tier from .settings import config from .errors import DataJointError, AccessError from .jobs import JobTable from .external import ExternalMapping from .heading import Heading from .utils import user_choice, to_camel_case -from .user_tables import Part, Computed, Imported, Manual, Lookup +from .user_tables import Part, Computed, Imported, Manual, Lookup, _get_tier from .table import lookup_class_name, Log, FreeTable import types @@ -451,10 +450,8 @@ def replace(s): ).replace("\n", "\n " + indent), ) - diagram = Diagram(self) - body = "\n\n".join( - make_class_definition(table) for table in diagram.topo_sort() - ) + tables = self.connection.dependencies.topo_sort() + body = "\n\n".join(make_class_definition(table) for table in tables) python_code = "\n\n".join( ( '"""This module was auto-generated by datajoint from an existing schema"""', diff --git a/datajoint/user_tables.py b/datajoint/user_tables.py index bcb6a027..0a784560 100644 --- a/datajoint/user_tables.py +++ b/datajoint/user_tables.py @@ -2,6 +2,7 @@ Hosts the table tiers, user tables should be derived from. """ +import re from .table import Table from .autopopulate import AutoPopulate from .utils import from_camel_case, ClassProperty @@ -242,3 +243,29 @@ def drop(self, force=False): def alter(self, prompt=True, context=None): # without context, use declaration context which maps master keyword to master table super().alter(prompt=prompt, context=context or self.declaration_context) + + +user_table_classes = (Manual, Lookup, Computed, Imported, Part) + + +class _AliasNode: + """ + special class to indicate aliased foreign keys + """ + + pass + + +def _get_tier(table_name): + """given the table name, return""" + if not table_name.startswith("`"): + return _AliasNode + else: + try: + return next( + tier + for tier in user_table_classes + if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2]) + ) + except StopIteration: + return None From 24c090d5bf19daffb6f15abf627b9dea3855c6e5 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Sun, 15 Sep 2024 18:26:49 -0500 Subject: [PATCH 15/21] debugged topological sort --- datajoint/dependencies.py | 19 ++++++++----------- tests/test_schema.py | 1 + 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index 4ad58527..4f78ad4f 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -4,13 +4,13 @@ from collections import defaultdict from .errors import DataJointError + def extract_master(part_table): """ - given a part table name, return master part. None if not a part table + given a part table name, return master part. None if not a part table """ match = re.match(r"(?P`\w+`.`#?\w+)__\w+`", part_table) - return match['master'] + '`' if match else None - + return match["master"] + "`" if match else None def topo_sort(graph): @@ -39,13 +39,11 @@ def topo_sort(graph): # to ensure correct topological ordering of the masters. for part in graph: # find the part's master - master = extract_master(part) - if master: + if (master := extract_master(part)) in graph: for edge in graph.in_edges(part): parent = edge[0] if parent != master and extract_master(parent) != master: graph.add_edge(parent, master) - sorted_nodes = list(nx.topological_sort(graph)) # bring parts up to their masters @@ -53,8 +51,7 @@ def topo_sort(graph): placed = set() while pos > 1: part = sorted_nodes[pos] - master = extract_master(part) - if not master or part in placed: + if not (master := extract_master) or part in placed: pos -= 1 else: placed.add(part) @@ -63,7 +60,7 @@ def topo_sort(graph): except ValueError: # master not found pass - else: + else: if pos > j + 1: # move the part to its master del sorted_nodes[pos] @@ -214,8 +211,8 @@ def descendants(self, full_table_name): :return: all dependent tables sorted in topological order. Self is included. """ self.load(force=False) - nodes = self.subgraph(nx.descendants(self, full_table_name)) - return [full_table_name] + nodes.topo_sort() + nodes = self.subgraph(nx.descendants(self, full_table_name)) + return [full_table_name] + nodes.topo_sort() def ancestors(self, full_table_name): """ diff --git a/tests/test_schema.py b/tests/test_schema.py index e44ac6ad..257de221 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -217,6 +217,7 @@ def test_list_tables(schema_simp): actual = set(schema_simp.list_tables()) assert actual == expected, f"Missing from list_tables(): {expected - actual}" + def test_schema_save_any(schema_any): assert "class Experiment(dj.Imported)" in schema_any.code From 92bfd4a1b32ea039d0e096b2f7211d7c471d9cb1 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Sun, 15 Sep 2024 18:48:32 -0500 Subject: [PATCH 16/21] debug topological sort --- datajoint/schemas.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datajoint/schemas.py b/datajoint/schemas.py index 7545f828..c3894ba2 100644 --- a/datajoint/schemas.py +++ b/datajoint/schemas.py @@ -412,6 +412,7 @@ def save(self, python_filename=None): :return: a string containing the body of a complete Python module defining this schema. """ + self.connection.dependencies.load() self._assert_exists() module_count = itertools.count() # add virtual modules for referenced modules with names vmod0, vmod1, ... @@ -477,11 +478,12 @@ def list_tables(self): :return: A list of table names from the database schema. """ + self.connection.dependencies.load() return [ t for d, t in ( full_t.replace("`", "").split(".") - for full_t in Diagram(self).topo_sort() + for full_t in self.connection.dependencies.topo_sort() ) if d == self.database ] @@ -530,7 +532,6 @@ def __init__( def list_schemas(connection=None): """ - :param connection: a dj.Connection object :return: list of all accessible schemas on the server """ From b5e7cf94d8c75ea1348ad80c3a06bb3c1d465ee6 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Sun, 15 Sep 2024 18:52:41 -0500 Subject: [PATCH 17/21] lint fix --- datajoint/diagram.py | 1 - 1 file changed, 1 deletion(-) diff --git a/datajoint/diagram.py b/datajoint/diagram.py index ca1df82b..451b50a4 100644 --- a/datajoint/diagram.py +++ b/datajoint/diagram.py @@ -1,5 +1,4 @@ import networkx as nx -import re import functools import io import logging From 224785e1f2c86fe3250a41876ad96e138bc3dbf4 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Mon, 16 Sep 2024 00:03:37 +0000 Subject: [PATCH 18/21] optimize topographical sort --- datajoint/dependencies.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index 4f78ad4f..aefb1bd2 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -51,20 +51,15 @@ def topo_sort(graph): placed = set() while pos > 1: part = sorted_nodes[pos] - if not (master := extract_master) or part in placed: + if (master := extract_master(part)) not in graph or part in placed: pos -= 1 else: placed.add(part) - try: - j = sorted_nodes.index(master) - except ValueError: - # master not found - pass - else: - if pos > j + 1: - # move the part to its master - del sorted_nodes[pos] - sorted_nodes.insert(j + 1, part) + j = sorted_nodes.index(master) + if pos > j + 1: + # move the part to its master + del sorted_nodes[pos] + sorted_nodes.insert(j + 1, part) return sorted_nodes From 76d40ae2a737eead18ad3a5869f3011caceed0ae Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Sun, 15 Sep 2024 19:57:43 -0500 Subject: [PATCH 19/21] improve comments in topological sort --- datajoint/dependencies.py | 3 ++- datajoint/diagram.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index aefb1bd2..5a34dc15 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -42,7 +42,8 @@ def topo_sort(graph): if (master := extract_master(part)) in graph: for edge in graph.in_edges(part): parent = edge[0] - if parent != master and extract_master(parent) != master: + if master not in (parent, extract_master(parent)): + # if parent is neither master nor part of master graph.add_edge(parent, master) sorted_nodes = list(nx.topological_sort(graph)) diff --git a/datajoint/diagram.py b/datajoint/diagram.py index 451b50a4..0425256d 100644 --- a/datajoint/diagram.py +++ b/datajoint/diagram.py @@ -71,7 +71,7 @@ class Diagram(nx.DiGraph): Only those tables that are loaded in the connection object are displayed """ - def __init__(self, source=None, context=None): + def __init__(self, source, context=None): if isinstance(source, Diagram): # copy constructor From 69a8c258c0e628ba235f6be35b1239212db3a156 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Mon, 16 Sep 2024 14:37:30 -0500 Subject: [PATCH 20/21] Update datajoint/diagram.py Co-authored-by: Ethan Ho <53266718+ethho@users.noreply.github.com> --- datajoint/diagram.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint/diagram.py b/datajoint/diagram.py index 0425256d..1edc62c6 100644 --- a/datajoint/diagram.py +++ b/datajoint/diagram.py @@ -61,7 +61,7 @@ class Diagram(nx.DiGraph): draws the diagram using pyplot diag1 + diag2 - combines the two diagrams. - diag1 - diag2 - differente between diagrams + diag1 - diag2 - difference between diagrams diag1 * diag2 - intersction of diagrams diag + n - expands n levels of successors diag - n - expands n levels of predecessors From 477c326af63fc87cc3bc9ca97dccc8ece17ef7e3 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Mon, 16 Sep 2024 14:37:39 -0500 Subject: [PATCH 21/21] Update datajoint/diagram.py Co-authored-by: Ethan Ho <53266718+ethho@users.noreply.github.com> --- datajoint/diagram.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint/diagram.py b/datajoint/diagram.py index 1edc62c6..aeced065 100644 --- a/datajoint/diagram.py +++ b/datajoint/diagram.py @@ -62,7 +62,7 @@ class Diagram(nx.DiGraph): diag1 + diag2 - combines the two diagrams. diag1 - diag2 - difference between diagrams - diag1 * diag2 - intersction of diagrams + diag1 * diag2 - intersection of diagrams diag + n - expands n levels of successors diag - n - expands n levels of predecessors Thus dj.Diagram(schema.Table)+1-1 defines the diagram of immediate ancestors and descendants of schema.Table