Skip to content

Commit

Permalink
Ensure correct types for QueryBuilder().dict() with multiple projec…
Browse files Browse the repository at this point in the history
…tions

The returned results by the backend implementation of `QueryBuilder` are
passed through `get_aiida_entity_res` to convert backend instances to
front end class instances. It calls `aiida.orm.convert.get_orm_entity`
which is a singledispatch to convert all known backend types to its
corresponding front-end ORM analogue. The registered implementation for
`Mapping` contained the bug. It used a simple comprehension and did not
catch any `TypeError` that might be thrown from values that could not be
converted. This would bubble up to `get_aiida_entity_res` which would
then simply return the original value.

If the mapping contains a mixture of backend entities and normal types,
the entire converting would be undone. This was surfaced when calling
`dict` on a query builder instance with `project=['*', 'id']`. The
returned value for each match is a dictionary with one value an integer
corresponding to the `id` and the other value a backend node instance.
The integer would raise the `TypeError` in the `Mapping` converter and
since it wasn't caught, the backend node was also not converted.
  • Loading branch information
sphuber committed Dec 20, 2019
1 parent 0834731 commit 9280fd5
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
15 changes: 15 additions & 0 deletions aiida/backends/tests/orm/test_querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,21 @@ def test_simple_query_2(self):
self.assertTrue(id(query1) != id(query2))
self.assertTrue(id(query2) == id(query3))

def test_dict_multiple_projections(self):
"""Test that the `.dict()` accumulator with multiple projections returns the correct types."""
node = orm.Data().store()
builder = orm.QueryBuilder().append(orm.Data, project=['*', 'id'])
results = builder.dict()

self.assertIsInstance(results, list)
self.assertTrue(all(isinstance(value, dict) for value in results))

dictionary = list(results[0].values())[0] # `results` should have the form [{'Data_1': {'*': Node, 'id': 1}}]

self.assertIsInstance(dictionary['*'], orm.Data)
self.assertEqual(dictionary['*'].pk, node.pk)
self.assertEqual(dictionary['id'], node.pk)

def test_operators_eq_lt_gt(self):
nodes = [orm.Data() for _ in range(8)]

Expand Down
33 changes: 32 additions & 1 deletion aiida/orm/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,38 @@ def get_orm_entity(backend_entity):

@get_orm_entity.register(Mapping)
def _(backend_entity):
return {key: get_orm_entity(value) for key, value in backend_entity.items()}
"""Convert all values of the given mapping to ORM entities if they are backend ORM instances."""
converted = {}

# Note that we cannot use a simple comprehension because raised `TypeError` should be caught here otherwise only
# parts of the mapping will be converted.
for key, value in backend_entity.items():
try:
converted[key] = get_orm_entity(value)
except TypeError:
converted[key] = value

return converted


@get_orm_entity.register(list)
@get_orm_entity.register(tuple)
def _(backend_entity):
"""Convert all values of the given list or tuple to ORM entities if they are backend ORM instances.
Note that we do not register on `collections.abc.Sequence` because that will also match strings.
"""
converted = []

# Note that we cannot use a simple comprehension because raised `TypeError` should be caught here otherwise only
# parts of the mapping will be converted.
for value in backend_entity:
try:
converted.append(get_orm_entity(value))
except TypeError:
converted.append(value)

return converted


@get_orm_entity.register(BackendGroup)
Expand Down
1 change: 0 additions & 1 deletion aiida/orm/querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
An instance of one of the implementation classes becomes a member of the :func:`QueryBuilder` instance
when instantiated by the user.
"""

# Checking for correct input with the inspect module
from inspect import isclass as inspect_isclass
import copy
Expand Down

0 comments on commit 9280fd5

Please sign in to comment.