Skip to content

Commit

Permalink
♻️ REFACTOR: EntityExtrasMixin -> EntityExtras (#5445)
Browse files Browse the repository at this point in the history
This commit is part of the `Node` namespace restructure. It moves all
methods related to extras to a new `EntityExtras` class. This is made
available for the `Node` and `Group` entities through the `base.extras`
attribute.
  • Loading branch information
chrisjsewell authored Apr 8, 2022
1 parent 18e3626 commit d34c3ad
Show file tree
Hide file tree
Showing 28 changed files with 518 additions and 414 deletions.
8 changes: 4 additions & 4 deletions .github/system_tests/test_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,15 +200,15 @@ def validate_cached(cached_calcs):
print_report(calc.pk)
valid = False

if '_aiida_cached_from' not in calc.extras or calc.get_hash() != calc.get_extra('_aiida_hash'):
if '_aiida_cached_from' not in calc.base.extras or calc.get_hash() != calc.base.extras.get('_aiida_hash'):
print(f'Cached calculation<{calc.pk}> has invalid hash')
print_report(calc.pk)
valid = False

if isinstance(calc, CalcJobNode):
original_calc = load_node(calc.get_extra('_aiida_cached_from'))
files_original = original_calc.list_object_names()
files_cached = calc.list_object_names()
original_calc = load_node(calc.base.extras.get('_aiida_cached_from'))
files_original = original_calc.base.repository.list_object_names()
files_cached = calc.base.repository.list_object_names()

if not files_cached:
print(f'Cached calculation <{calc.pk}> does not have any raw inputs files')
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ repos:
aiida/orm/implementation/users.py|
aiida/orm/implementation/querybuilder.py|
aiida/orm/entities.py|
aiida/orm/extras.py|
aiida/orm/authinfos.py|
aiida/orm/comments.py|
aiida/orm/computers.py|
Expand Down
2 changes: 1 addition & 1 deletion aiida/cmdline/commands/cmd_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def echo_node_dict(nodes, keys, fmt, identifier, raw, use_attrs=True):
node_dict = node.base.attributes.all
dict_name = 'attributes'
else:
node_dict = node.extras
node_dict = node.base.extras.all
dict_name = 'extras'

if keys is not None:
Expand Down
4 changes: 2 additions & 2 deletions aiida/engine/processes/workchains/restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,9 @@ def run_process(self) -> ToContext:
# Add a new empty list to the `BaseRestartWorkChain._considered_handlers_extra` extra. This will contain the
# name and return value of all class methods, decorated with `process_handler`, that are called during
# the `inspect_process` outline step.
considered_handlers = self.node.get_extra(self._considered_handlers_extra, [])
considered_handlers = self.node.base.extras.get(self._considered_handlers_extra, [])
considered_handlers.append([])
self.node.set_extra(self._considered_handlers_extra, considered_handlers)
self.node.base.extras.set(self._considered_handlers_extra, considered_handlers)

self.report(f'launching {self.ctx.process_name}<{node.pk}> iteration #{self.ctx.iteration}')

Expand Down
4 changes: 2 additions & 2 deletions aiida/engine/processes/workchains/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def wrapper(wrapped, instance, args, kwargs):

# Append the name and return value of the current process handler to the `considered_handlers` extra.
try:
considered_handlers = instance.node.get_extra(instance._considered_handlers_extra, []) # pylint: disable=protected-access
considered_handlers = instance.node.base.extras.get(instance._considered_handlers_extra, []) # pylint: disable=protected-access
current_process = considered_handlers[-1]
except IndexError:
# The extra was never initialized, so we skip this functionality
Expand All @@ -132,7 +132,7 @@ def wrapper(wrapped, instance, args, kwargs):
if isinstance(serialized, ProcessHandlerReport):
serialized = {'do_break': serialized.do_break, 'exit_status': serialized.exit_code.status}
current_process.append((wrapped.__name__, serialized))
instance.node.set_extra(instance._considered_handlers_extra, considered_handlers) # pylint: disable=protected-access
instance.node.base.extras.set(instance._considered_handlers_extra, considered_handlers) # pylint: disable=protected-access

return result

Expand Down
3 changes: 2 additions & 1 deletion aiida/orm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .comments import *
from .computers import *
from .entities import *
from .extras import *
from .groups import *
from .logs import *
from .nodes import *
Expand Down Expand Up @@ -51,7 +52,7 @@
'Data',
'Dict',
'Entity',
'EntityExtrasMixin',
'EntityExtras',
'EntityTypes',
'EnumData',
'Float',
Expand Down
155 changes: 2 additions & 153 deletions aiida/orm/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
###########################################################################
"""Module for all common top level AiiDA entity classes and methods"""
import abc
import copy
from enum import Enum
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Protocol, Type, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, cast

from plumpy.base.utils import call_with_super_check, super_check

Expand All @@ -23,14 +22,12 @@
from aiida.orm.implementation import BackendEntity, StorageBackend
from aiida.orm.querybuilder import FilterType, OrderByType, QueryBuilder

__all__ = ('Entity', 'Collection', 'EntityExtrasMixin', 'EntityTypes')
__all__ = ('Entity', 'Collection', 'EntityTypes')

CollectionType = TypeVar('CollectionType', bound='Collection')
EntityType = TypeVar('EntityType', bound='Entity')
BackendEntityType = TypeVar('BackendEntityType', bound='BackendEntity')

_NO_DEFAULT: Any = tuple()


class EntityTypes(Enum):
"""Enum for referring to ORM entities in a backend-agnostic manner."""
Expand Down Expand Up @@ -245,151 +242,3 @@ def backend(self) -> 'StorageBackend':
def backend_entity(self) -> BackendEntityType:
"""Get the implementing class for this object"""
return self._backend_entity


class EntityProtocol(Protocol):
"""Protocol for attributes required by Entity mixins."""

@property
def backend_entity(self) -> 'BackendEntity':
...

@property
def is_stored(self) -> bool:
...


class EntityExtrasMixin:
"""Mixin class that adds all methods for the extras column to an entity."""

@property
def extras(self: EntityProtocol) -> Dict[str, Any]:
"""Return the complete extras dictionary.
.. warning:: While the entity is unstored, this will return references of the extras on the database model,
meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will
automatically be reflected on the database model as well. As soon as the entity is stored, the returned
extras will be a deep copy and mutations of the database extras will have to go through the appropriate set
methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you only need the keys
or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and
`get_extra_many` instead.
:return: the extras as a dictionary
"""
extras = self.backend_entity.extras

if self.is_stored:
extras = copy.deepcopy(extras)

return extras

def get_extra(self: EntityProtocol, key: str, default: Any = _NO_DEFAULT) -> Any:
"""Return the value of an extra.
.. warning:: While the entity is unstored, this will return a reference of the extra on the database model,
meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will
automatically be reflected on the database model as well. As soon as the entity is stored, the returned
extra will be a deep copy and mutations of the database extras will have to go through the appropriate set
methods.
:param key: name of the extra
:param default: return this value instead of raising if the attribute does not exist
:return: the value of the extra
:raises AttributeError: if the extra does not exist and no default is specified
"""
try:
extra = self.backend_entity.get_extra(key)
except AttributeError:
if default is _NO_DEFAULT:
raise
extra = default

if self.is_stored:
extra = copy.deepcopy(extra)

return extra

def get_extra_many(self: EntityProtocol, keys: List[str]) -> List[Any]:
"""Return the values of multiple extras.
.. warning:: While the entity is unstored, this will return references of the extras on the database model,
meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will
automatically be reflected on the database model as well. As soon as the entity is stored, the returned
extras will be a deep copy and mutations of the database extras will have to go through the appropriate set
methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you only need the keys
or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and
`get_extra_many` instead.
:param keys: a list of extra names
:return: a list of extra values
:raises AttributeError: if at least one extra does not exist
"""
extras = self.backend_entity.get_extra_many(keys)

if self.is_stored:
extras = copy.deepcopy(extras)

return extras

def set_extra(self: EntityProtocol, key: str, value: Any) -> None:
"""Set an extra to the given value.
:param key: name of the extra
:param value: value of the extra
:raise aiida.common.ValidationError: if the key is invalid, i.e. contains periods
"""
self.backend_entity.set_extra(key, value)

def set_extra_many(self: EntityProtocol, extras: Dict[str, Any]) -> None:
"""Set multiple extras.
.. note:: This will override any existing extras that are present in the new dictionary.
:param extras: a dictionary with the extras to set
:raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods
"""
self.backend_entity.set_extra_many(extras)

def reset_extras(self: EntityProtocol, extras: Dict[str, Any]) -> None:
"""Reset the extras.
.. note:: This will completely clear any existing extras and replace them with the new dictionary.
:param extras: a dictionary with the extras to set
:raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods
"""
self.backend_entity.reset_extras(extras)

def delete_extra(self: EntityProtocol, key: str) -> None:
"""Delete an extra.
:param key: name of the extra
:raises AttributeError: if the extra does not exist
"""
self.backend_entity.delete_extra(key)

def delete_extra_many(self: EntityProtocol, keys: List[str]) -> None:
"""Delete multiple extras.
:param keys: names of the extras to delete
:raises AttributeError: if at least one of the extra does not exist
"""
self.backend_entity.delete_extra_many(keys)

def clear_extras(self: EntityProtocol) -> None:
"""Delete all extras."""
self.backend_entity.clear_extras()

def extras_items(self: EntityProtocol):
"""Return an iterator over the extras.
:return: an iterator with extra key value pairs
"""
return self.backend_entity.extras_items()

def extras_keys(self: EntityProtocol):
"""Return an iterator over the extra keys.
:return: an iterator with extra keys
"""
return self.backend_entity.extras_keys()
Loading

0 comments on commit d34c3ad

Please sign in to comment.