Skip to content

Commit

Permalink
DEV-2875: resolve link resolution by label or name (#242)
Browse files Browse the repository at this point in the history
Update logging usage to stop using root logger.
Fixes issue with resolving links by name and allows data generation to fail for invalid links
  • Loading branch information
kulgan authored Jul 29, 2024
1 parent df81670 commit dec7e6a
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 59 deletions.
4 changes: 2 additions & 2 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,10 @@
"filename": "test/test_mocks.py",
"hashed_secret": "99ffd88615aaba1c26dbae068b6b6360bd1358a4",
"is_verified": false,
"line_number": 171,
"line_number": 209,
"is_secret": false
}
]
},
"generated_at": "2024-06-26T17:01:47Z"
"generated_at": "2024-07-26T21:01:14Z"
}
152 changes: 109 additions & 43 deletions src/psqlgraph/hydrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

import psqlgraph
from psqlgraph import Node
from psqlgraph.exc import PSQLGraphError

logger = logging.getLogger(__name__)


class Randomizer:
Expand Down Expand Up @@ -168,7 +171,7 @@ def __init__(self, properties):
self.properties.get(name, {})
)
except ValueError as ve:
logging.debug(
logger.debug(
"Property: '{}' is most likely a relationship. Error: {}" "".format(name, ve)
)

Expand Down Expand Up @@ -255,7 +258,7 @@ def create(

_, value = self.property_factories[label].create(prop, override_val)
except (KeyError, ValueError):
logging.debug(f"No factory for property: '{prop}'")
logger.debug(f"No factory for property: '{prop}'")
continue

node_json["properties"][prop] = value
Expand Down Expand Up @@ -304,6 +307,7 @@ def create_from_nodes_and_edges(
edges: List[Dict[str, str]],
unique_key: str = "submitter_id",
all_props: bool = False,
strict: bool = False,
) -> List[Node]:
"""Create a graph from nodes and edges.
Expand All @@ -318,6 +322,7 @@ def create_from_nodes_and_edges(
[{'src': 'read_group_1', 'dst': 'aliquot_1'}]
unique_key: a name of the property that will be used to connect nodes
all_props: generate all node properties or not
strict: raises error if invalid links are provided
Returns:
List of psqlgraph nodes
Expand All @@ -344,10 +349,10 @@ def create_from_nodes_and_edges(
node2 = nodes_map.get(sub_id2)

if not node1 or not node2:
logging.debug(f"Could not find nodes for edge: '{sub_id1}'<->'{sub_id2}'")
logger.debug(f"Could not find nodes for edge: '{sub_id1}'<->'{sub_id2}'")
continue

self.make_association(node1, node2, edge_label)
self.make_association(node1, node2, edge_label, strict)

return list(nodes_map.values())

Expand Down Expand Up @@ -509,57 +514,118 @@ def is_parent_relation(self, label, relation):
return relation in links

def make_association(
self, src_node: Node, dst_node: Node, edge_label: Optional[str] = None
self,
src_node: Node,
dst_node: Node,
edge_label: Optional[str] = None,
strict: bool = False,
) -> None:
"""Create an Edge between 2 nodes
"""Create an Edge between two nodes
Given 2 instances of a Node, find appropriate association between the
2 nodes and create a relation between them
two nodes and create a relation between them
There are some special cases like auxiliary_files and
structural_variant_calling_workflow, there are 2 different edges between them
structural_variant_calling_workflow, there are two different edges between them
in opposite directions. In this case, we use the label to differentiate them.
Bonus: Label could be added to all edges and will make the lookup faster.
Args:
src_node: first node of the edge
dst_node: second node of the edge
edge_label: label of the edge
src_node: source node of the edge
dst_node: destination node of the edge
edge_label: identifier defined in the dictionary `links` section. It can be
the name of the edge, or the label
strict: raise error if edge is invalid
Raises:
PSQLGraphError if either no association is found or multiple associations are found
for the source and destination nodes. Specifying a label reduces the chances of finding
multiple associations.
"""
association_name = None
if edge_label:
edge_class = self.models.Edge.get_subclass(edge_label)
if not edge_class:
logging.warning(f"Edge with label {edge_label} not found")
elif (
edge_class.__src_class__ == src_node.__class__.__name__
and edge_class.__dst_class__ == dst_node.__class__.__name__
):
getattr(src_node, edge_class.__src_dst_assoc__).append(dst_node)
return
elif (
edge_class.__src_class__ == dst_node.__class__.__name__
and edge_class.__dst_class__ == src_node.__class__.__name__
):
getattr(dst_node, edge_class.__src_dst_assoc__).append(src_node)
return
else:
logging.warning(
"Edge with label {} is not allowed between nodes {} and {}".format(
edge_label, src_node.label, dst_node.label
)
)
# attempt to get association using link name - e.g., performed_on
association_name = self.get_association_by_edge_label(src_node, dst_node, edge_label)

link_found = False
for assoc_name, assoc_meta in src_node._pg_edges.items():
if isinstance(dst_node, assoc_meta["type"]):
getattr(src_node, assoc_name).append(dst_node)
link_found = True
break

if not link_found:
logging.debug(
"Could not find a direct relation between '{}'<->'{}'".format(
src_node.label, dst_node.label
if association_name:
getattr(src_node, association_name).append(dst_node)
return

# attempt to get association by going through all links defined on the source node if needed.
association_names = self.get_association_by_edge_name(src_node, dst_node, edge_label)

if len(association_names) == 1:
getattr(src_node, association_names[0]).append(dst_node)

elif len(association_names) > 1:
if strict:
raise PSQLGraphError(
f"Multiple associations '{association_names}' found between "
f"'{src_node.label}' and '{dst_node.label}'"
)
logger.warning(
"Multiple association '%s' found between '%s' and '%s'",
association_names,
src_node.label,
dst_node.label,
)
else:
if strict:
raise PSQLGraphError(
f"Could not find a direct relation (edge name or label = '{edge_label}') "
f"between '{src_node.label}' and '{dst_node.label}' "
)
logger.warning(
"Could not find a direct relation (edge name or label = '%s') between '%s' and '%s'",
edge_label,
src_node.label,
dst_node.label,
)

# try the reverse
self.make_association(dst_node, src_node, edge_label, strict=True)

def get_association_by_edge_name(
    self, src_node: psqlgraph.Node, dst_node: psqlgraph.Node, edge_name: Optional[str] = None
) -> List[str]:
    """Collect the association names on the source node that accept the destination node.

    Every entry of ``src_node._pg_edges`` is inspected; an entry matches when
    the destination node is an instance of the entry's ``"type"``.

    Args:
        src_node: the source node whose associations are scanned
        dst_node: the destination node
        edge_name: the edge name as defined in the dictionary. When given
            (truthy), only the association with exactly this name is
            considered; otherwise all associations are considered.

    Returns:
        The names of the matching associations on the source node, in the
        iteration order of ``_pg_edges``. The list is empty when nothing
        matches and may contain several names; deciding what to do in those
        cases is left to the caller.
        Example (presumably, per the data dictionary): for a `case` source
        and an `aliquot` destination the matching name would be `aliquots`.
    """
    return [
        name
        for name, meta in src_node._pg_edges.items()  # noqa
        # the name filter is checked first so that "type" is only read for
        # associations that pass it
        if (not edge_name or name == edge_name) and isinstance(dst_node, meta["type"])
    ]

def get_association_by_edge_label(
    self, src_node: psqlgraph.Node, dst_node: psqlgraph.Node, edge_label: str
) -> Optional[str]:
    """Resolve the association name for a source label / edge label / destination label triple.

    Args:
        src_node: source node; its ``label`` is used for the lookup
        dst_node: destination node; its ``label`` is used for the lookup
        edge_label: dictionary defined edge label - see links section in the dictionary

    Returns:
        The ``__src_dst_assoc__`` attribute of the edge class returned by
        ``self.models.Edge.get_unique_subclass``, or None when that lookup
        yields nothing.
    """
    logger.debug("Resolving association using edge labels")
    matched_edge = self.models.Edge.get_unique_subclass(
        src_node.label, edge_label, dst_node.label
    )
    return matched_edge.__src_dst_assoc__ if matched_edge else None
11 changes: 5 additions & 6 deletions src/psqlgraph/psql.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self, host, user, password, database, **kwargs):
user=user, password=password, host=host, database=database
)
if kwargs["isolation_level"] not in self.acceptable_isolation_levels:
logging.warning(
logger.warning(
(
"Using an isolation level '{}' that is not in the list of "
"acceptable isolation levels {} is not safe and should be "
Expand Down Expand Up @@ -229,7 +229,7 @@ def session_scope(
local.commit()

except Exception as msg:
logging.error(f"Rolling back session {msg}")
logger.error(f"Rolling back session {msg}")
local.rollback()
raise

Expand All @@ -256,10 +256,9 @@ def _configure_driver_mappers(self):
try:
configure_mappers()
except Exception as e:
logging.error(
("{}: Unable to configure mappers. " "Have you imported your models?").format(
str(e)
)
logger.error(
"{}: Unable to configure mappers. "
"Have you imported your models?".format(str(e))
)

def __expand_query(self, query=None):
Expand Down
7 changes: 4 additions & 3 deletions src/psqlgraph/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@

# PsqlNode modules
DEFAULT_RETRIES = 0
logger = logging.getLogger(__name__)


def validate(f, value, types, enum=None):
"""Validation decorator types for hybrid_properties"""
if enum:
if value not in enum and value is not None:
raise ValidationError(
("Value '{}' not in allowed value list for {} for property {}.").format(
"Value '{}' not in allowed value list for {} for property {}.".format(
value, enum, f.__name__
)
)
Expand Down Expand Up @@ -113,9 +114,9 @@ def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except IntegrityError:
logging.debug(f"Race-condition caught? ({retries}/{max_retries} retries)")
logger.debug(f"Race-condition caught? ({retries}/{max_retries} retries)")
if retries >= max_retries:
logging.error(f"Unable to execute {func}, max retries exceeded")
logger.error(f"Unable to execute {func}, max retries exceeded")
raise
retries += 1
backoff(retries, max_retries)
Expand Down
40 changes: 39 additions & 1 deletion test/test_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from psqlgraph import Edge, Node
from psqlgraph.exc import PSQLGraphError
from psqlgraph.mocks import GraphFactory, NodeFactory

STRING_MATCH = "[a-zA-Z0-9]{32}"
Expand Down Expand Up @@ -94,7 +95,10 @@ def test_init_graph_factory(gdcmodels, gdcdictionary):
_ = GraphFactory(gdcmodels, gdcdictionary)


def test_graph_factory_with_nodes_and_edges(gdcmodels, gdcdictionary):
def test_graph_factory__strict_with_invalid_edge(
gdcmodels: FakeModels, gdcdictionary: models.FakeDictionary
) -> None:
"""Confirm invalid edges raises exception when strict is set to True."""
gf = GraphFactory(gdcmodels, gdcdictionary)

foobar_uuids = [str(uuid.uuid4())]
Expand All @@ -120,6 +124,40 @@ def test_graph_factory_with_nodes_and_edges(gdcmodels, gdcdictionary):
{"src": foo_uuids[1], "dst": foobar_uuids[0]}, # f1 -> fb0
]

with pytest.raises(PSQLGraphError):
gf.create_from_nodes_and_edges(
nodes=nodes, edges=edges, unique_key="node_id", strict=True
)


def test_graph_factory_with_nodes_and_edges(
gdcmodels: FakeModels, gdcdictionary: models.FakeDictionary
) -> None:
"""Test GraphFactory can successfully create nodes and edges."""
gf = GraphFactory(gdcmodels, gdcdictionary)

foobar_uuids = [str(uuid.uuid4())]
foo_uuids = [str(uuid.uuid4()), str(uuid.uuid4())]
test_uuids = [str(uuid.uuid4()), str(uuid.uuid4()), str(uuid.uuid4())]

nodes = [
{"label": "test", "node_id": test_uuids[0]},
{"label": "test", "node_id": test_uuids[1]},
{"label": "test", "node_id": test_uuids[2]},
{"label": "foo", "node_id": foo_uuids[0]},
{"label": "foo", "node_id": foo_uuids[1]},
{"label": "foo_bar", "node_id": foobar_uuids[0]},
]

edges = [
{"src": test_uuids[0], "dst": test_uuids[1]}, # t0 -> t1
{"src": test_uuids[0], "dst": foo_uuids[0]}, # t0 -> f0
{"src": test_uuids[1], "dst": foo_uuids[1]}, # t1 -> f1
{"src": test_uuids[2], "dst": foo_uuids[1]}, # t2 -> f1
{"src": foo_uuids[0], "dst": foobar_uuids[0]}, # f0 -> fb0
{"src": foo_uuids[1], "dst": foobar_uuids[0]}, # f1 -> fb0
]

created_nodes = gf.create_from_nodes_and_edges(nodes=nodes, edges=edges, unique_key="node_id")

expected_adjacency = defaultdict(set)
Expand Down
3 changes: 0 additions & 3 deletions test/test_psqlgraph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import random
import unittest
import uuid
Expand All @@ -17,8 +16,6 @@
from psqlgraph import PolyNode as PsqlNode
from psqlgraph import VoidedEdge, sanitize

logging.basicConfig(level=logging.DEBUG)


def timestamp():
return str(datetime.now())
Expand Down
1 change: 0 additions & 1 deletion test/test_psqlgraph2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from psqlgraph import PsqlGraphDriver
from psqlgraph.exc import SessionClosedError, ValidationError

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)


Expand Down

0 comments on commit dec7e6a

Please sign in to comment.