From dec7e6ab9a89803252e2b52e57745d23afe719bc Mon Sep 17 00:00:00 2001 From: Rowland Ogwara Date: Mon, 29 Jul 2024 09:39:10 -0500 Subject: [PATCH] DEV-2875: resolve link resolution by label or name (#242) Update logging usage to stop using root logger. Fixes issue with resolving links by name and allow data generation to fail for invalid links --- .secrets.baseline | 4 +- src/psqlgraph/hydrator.py | 152 +++++++++++++++++++++++++++----------- src/psqlgraph/psql.py | 11 ++- src/psqlgraph/util.py | 7 +- test/test_mocks.py | 40 +++++++++- test/test_psqlgraph.py | 3 - test/test_psqlgraph2.py | 1 - 7 files changed, 159 insertions(+), 59 deletions(-) diff --git a/.secrets.baseline b/.secrets.baseline index fe2f4c4..954f8ca 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -144,10 +144,10 @@ "filename": "test/test_mocks.py", "hashed_secret": "99ffd88615aaba1c26dbae068b6b6360bd1358a4", "is_verified": false, - "line_number": 171, + "line_number": 209, "is_secret": false } ] }, - "generated_at": "2024-06-26T17:01:47Z" + "generated_at": "2024-07-26T21:01:14Z" } diff --git a/src/psqlgraph/hydrator.py b/src/psqlgraph/hydrator.py index 3f7a195..c2dea67 100644 --- a/src/psqlgraph/hydrator.py +++ b/src/psqlgraph/hydrator.py @@ -10,6 +10,9 @@ import psqlgraph from psqlgraph import Node +from psqlgraph.exc import PSQLGraphError + +logger = logging.getLogger(__name__) class Randomizer: @@ -168,7 +171,7 @@ def __init__(self, properties): self.properties.get(name, {}) ) except ValueError as ve: - logging.debug( + logger.debug( "Property: '{}' is most likely a relationship. 
Error: {}" "".format(name, ve) ) @@ -255,7 +258,7 @@ def create( _, value = self.property_factories[label].create(prop, override_val) except (KeyError, ValueError): - logging.debug(f"No factory for property: '{prop}'") + logger.debug(f"No factory for property: '{prop}'") continue node_json["properties"][prop] = value @@ -304,6 +307,7 @@ def create_from_nodes_and_edges( edges: List[Dict[str, str]], unique_key: str = "submitter_id", all_props: bool = False, + strict: bool = False, ) -> List[Node]: """Create a graph from nodes and edges. @@ -318,6 +322,7 @@ def create_from_nodes_and_edges( [{'src': 'read_group_1', 'dst': 'aliquot_1'}] unique_key: a name of the property that will be used to connect nodes all_props: generate all node properties or not + strict: raises error if invalid links are provided Returns: List of psqlgraph nodes @@ -344,10 +349,10 @@ def create_from_nodes_and_edges( node2 = nodes_map.get(sub_id2) if not node1 or not node2: - logging.debug(f"Could not find nodes for edge: '{sub_id1}'<->'{sub_id2}'") + logger.debug(f"Could not find nodes for edge: '{sub_id1}'<->'{sub_id2}'") continue - self.make_association(node1, node2, edge_label) + self.make_association(node1, node2, edge_label, strict) return list(nodes_map.values()) @@ -509,57 +514,118 @@ def is_parent_relation(self, label, relation): return relation in links def make_association( - self, src_node: Node, dst_node: Node, edge_label: Optional[str] = None + self, + src_node: Node, + dst_node: Node, + edge_label: Optional[str] = None, + strict: bool = False, ) -> None: - """Create an Edge between 2 nodes + """Create an Edge between two nodes Given 2 instances of a Node, find appropriate association between the - 2 nodes and create a relation between them + two nodes and create a relation between them There are some special cases like auxiliary_files and - structural_variant_calling_workflow, there are 2 different edges between them + structural_variant_calling_workflow, there are two different 
edges between them in opposite directions. In this case, we use the label to differentiate them. Bonus: Label could be added to all edges and will make the lookup faster. Args: - src_node: first node of the edge - dst_node: second node of the edge - edge_label: label of the edge + src_node: source node of the edge + dst_node: destination node of the edge + edge_label: identifier defined in the dictionary `links` section. It can be + the name of the edge, or the label + strict: raise error if edge is invalid + Raises: + PSQLGraphError if either no association is found or multiple associations are found + for the source and destination nodes. Specifying a label reduces the chances of finding + multiple associations. """ + association_name = None if edge_label: - edge_class = self.models.Edge.get_subclass(edge_label) - if not edge_class: - logging.warning(f"Edge with label {edge_label} not found") - elif ( - edge_class.__src_class__ == src_node.__class__.__name__ - and edge_class.__dst_class__ == dst_node.__class__.__name__ - ): - getattr(src_node, edge_class.__src_dst_assoc__).append(dst_node) - return - elif ( - edge_class.__src_class__ == dst_node.__class__.__name__ - and edge_class.__dst_class__ == src_node.__class__.__name__ - ): - getattr(dst_node, edge_class.__src_dst_assoc__).append(src_node) - return - else: - logging.warning( - "Edge with label {} is not allowed between nodes {} and {}".format( - edge_label, src_node.label, dst_node.label - ) - ) + # attempt to get association using link name - e.g., performed_on + association_name = self.get_association_by_edge_label(src_node, dst_node, edge_label) - link_found = False - for assoc_name, assoc_meta in src_node._pg_edges.items(): - if isinstance(dst_node, assoc_meta["type"]): - getattr(src_node, assoc_name).append(dst_node) - link_found = True - break - - if not link_found: - logging.debug( - "Could not find a direct relation between '{}'<->'{}'".format( - src_node.label, dst_node.label + if association_name:
+ getattr(src_node, association_name).append(dst_node) + return + + # attempt to get association by going through all links defined on the source node if needed. + association_names = self.get_association_by_edge_name(src_node, dst_node, edge_label) + + if len(association_names) == 1: + getattr(src_node, association_names[0]).append(dst_node) + + elif len(association_names) > 1: + if strict: + raise PSQLGraphError( + f"Multiple associations '{association_names}' found between " + f"'{src_node.label}' and '{dst_node.label}'" ) + logger.warning( + "Multiple association '%s' found between '%s' and '%s'", + association_names, + src_node.label, + dst_node.label, ) + else: + if strict: + raise PSQLGraphError( + f"Could not find a direct relation (edge name or label = '{edge_label}') " + f"between '{src_node.label}' and '{dst_node.label}' " + ) + logger.warning( + "Could not find a direct relation (edge name or label = '%s') between '%s' and '%s'", + edge_label, + src_node.label, + dst_node.label, + ) + + # try the reverse + self.make_association(dst_node, src_node, edge_label, strict=True) + + def get_association_by_edge_name( + self, src_node: psqlgraph.Node, dst_node: psqlgraph.Node, edge_name: Optional[str] = None + ) -> List[str]: + """Get the association name used to link the src and dst nodes + Args: + src_node: the source node + dst_node: the destination node + edge_name: the edge name as defined in the dictionary. + If None, a unique association between the two nodes will be used. + An exception is raised if no unique association is found. + Returns: + The name of the edge from the source node. 
+ For example, if the source node is case and the destination + node is aliquot, the name of the edge from the `case` node is `aliquots` + """ + association_names = [] + for assoc_name, assoc_meta in src_node._pg_edges.items(): # noqa + if edge_name and assoc_name != edge_name: + continue + + if isinstance(dst_node, assoc_meta["type"]): + association_names.append(assoc_name) + return association_names + + def get_association_by_edge_label( + self, src_node: psqlgraph.Node, dst_node: psqlgraph.Node, edge_label: str + ) -> Optional[str]: + """Get association name for the unique combination of src and dst node + + Args: + src_node: source node + dst_node: destination node + edge_label: dictionary defined edge label - see links section in the dictionary + + Returns: + association name + """ + logger.debug("Resolving association using edge labels") + edge_class = self.models.Edge.get_unique_subclass( + src_node.label, edge_label, dst_node.label + ) + if edge_class: + return edge_class.__src_dst_assoc__ + return None diff --git a/src/psqlgraph/psql.py b/src/psqlgraph/psql.py index 767d218..310832a 100644 --- a/src/psqlgraph/psql.py +++ b/src/psqlgraph/psql.py @@ -70,7 +70,7 @@ def __init__(self, host, user, password, database, **kwargs): user=user, password=password, host=host, database=database ) if kwargs["isolation_level"] not in self.acceptable_isolation_levels: - logging.warning( + logger.warning( ( "Using an isolation level '{}' that is not in the list of " "acceptable isolation levels {} is not safe and should be " @@ -229,7 +229,7 @@ def session_scope( local.commit() except Exception as msg: - logging.error(f"Rolling back session {msg}") + logger.error(f"Rolling back session {msg}") local.rollback() raise @@ -256,10 +256,9 @@ def _configure_driver_mappers(self): try: configure_mappers() except Exception as e: - logging.error( - ("{}: Unable to configure mappers. 
" "Have you imported your models?").format( - str(e) - ) + logger.error( + "{}: Unable to configure mappers. " + "Have you imported your models?".format(str(e)) ) def __expand_query(self, query=None): diff --git a/src/psqlgraph/util.py b/src/psqlgraph/util.py index 979ff72..7480ec2 100644 --- a/src/psqlgraph/util.py +++ b/src/psqlgraph/util.py @@ -10,6 +10,7 @@ # PsqlNode modules DEFAULT_RETRIES = 0 +logger = logging.getLogger(__name__) def validate(f, value, types, enum=None): @@ -17,7 +18,7 @@ def validate(f, value, types, enum=None): if enum: if value not in enum and value is not None: raise ValidationError( - ("Value '{}' not in allowed value list for {} for property {}.").format( + "Value '{}' not in allowed value list for {} for property {}.".format( value, enum, f.__name__ ) ) @@ -113,9 +114,9 @@ def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except IntegrityError: - logging.debug(f"Race-condition caught? ({retries}/{max_retries} retries)") + logger.debug(f"Race-condition caught? 
({retries}/{max_retries} retries)") if retries >= max_retries: - logging.error(f"Unable to execute {func}, max retries exceeded") + logger.error(f"Unable to execute {func}, max retries exceeded") raise retries += 1 backoff(retries, max_retries) diff --git a/test/test_mocks.py b/test/test_mocks.py index ce82bf1..f402213 100644 --- a/test/test_mocks.py +++ b/test/test_mocks.py @@ -7,6 +7,7 @@ import pytest from psqlgraph import Edge, Node +from psqlgraph.exc import PSQLGraphError from psqlgraph.mocks import GraphFactory, NodeFactory STRING_MATCH = "[a-zA-Z0-9]{32}" @@ -94,7 +95,10 @@ def test_init_graph_factory(gdcmodels, gdcdictionary): _ = GraphFactory(gdcmodels, gdcdictionary) -def test_graph_factory_with_nodes_and_edges(gdcmodels, gdcdictionary): +def test_graph_factory__strict_with_invalid_edge( + gdcmodels: FakeModels, gdcdictionary: models.FakeDictionary +) -> None: + """Confirm invalid edges raises exception when strict is set to True.""" gf = GraphFactory(gdcmodels, gdcdictionary) foobar_uuids = [str(uuid.uuid4())] @@ -120,6 +124,40 @@ def test_graph_factory_with_nodes_and_edges(gdcmodels, gdcdictionary): {"src": foo_uuids[1], "dst": foobar_uuids[0]}, # f1 -> fb0 ] + with pytest.raises(PSQLGraphError): + gf.create_from_nodes_and_edges( + nodes=nodes, edges=edges, unique_key="node_id", strict=True + ) + + +def test_graph_factory_with_nodes_and_edges( + gdcmodels: FakeModels, gdcdictionary: models.FakeDictionary +) -> None: + """Test GraphFactory can successfully create nodes and edges.""" + gf = GraphFactory(gdcmodels, gdcdictionary) + + foobar_uuids = [str(uuid.uuid4())] + foo_uuids = [str(uuid.uuid4()), str(uuid.uuid4())] + test_uuids = [str(uuid.uuid4()), str(uuid.uuid4()), str(uuid.uuid4())] + + nodes = [ + {"label": "test", "node_id": test_uuids[0]}, + {"label": "test", "node_id": test_uuids[1]}, + {"label": "test", "node_id": test_uuids[2]}, + {"label": "foo", "node_id": foo_uuids[0]}, + {"label": "foo", "node_id": foo_uuids[1]}, + {"label": "foo_bar", 
"node_id": foobar_uuids[0]}, + ] + + edges = [ + {"src": test_uuids[0], "dst": test_uuids[1]}, # t0 -> t1 + {"src": test_uuids[0], "dst": foo_uuids[0]}, # t0 -> f0 + {"src": test_uuids[1], "dst": foo_uuids[1]}, # t1 -> f1 + {"src": test_uuids[2], "dst": foo_uuids[1]}, # t2 -> f1 + {"src": foo_uuids[0], "dst": foobar_uuids[0]}, # f0 -> fb0 + {"src": foo_uuids[1], "dst": foobar_uuids[0]}, # f1 -> fb0 + ] + created_nodes = gf.create_from_nodes_and_edges(nodes=nodes, edges=edges, unique_key="node_id") expected_adjacency = defaultdict(set) diff --git a/test/test_psqlgraph.py b/test/test_psqlgraph.py index c5591bb..b13295d 100644 --- a/test/test_psqlgraph.py +++ b/test/test_psqlgraph.py @@ -1,4 +1,3 @@ -import logging import random import unittest import uuid @@ -17,8 +16,6 @@ from psqlgraph import PolyNode as PsqlNode from psqlgraph import VoidedEdge, sanitize -logging.basicConfig(level=logging.DEBUG) - def timestamp(): return str(datetime.now()) diff --git a/test/test_psqlgraph2.py b/test/test_psqlgraph2.py index 2ed33c9..3efc4b2 100644 --- a/test/test_psqlgraph2.py +++ b/test/test_psqlgraph2.py @@ -11,7 +11,6 @@ from psqlgraph import PsqlGraphDriver from psqlgraph.exc import SessionClosedError, ValidationError -logging.basicConfig(level=logging.DEBUG) log = logging.getLogger(__name__)