diff --git a/aiida/engine/processes/workchains/workchain.py b/aiida/engine/processes/workchains/workchain.py
index 4978e3594f..82e1b50837 100644
--- a/aiida/engine/processes/workchains/workchain.py
+++ b/aiida/engine/processes/workchains/workchain.py
@@ -11,7 +11,7 @@
 import collections.abc
 import functools
 import logging
-from typing import Any, List, Optional, Sequence, Union, TYPE_CHECKING
+from typing import Any, List, Optional, Sequence, Union, Tuple, TYPE_CHECKING
 
 from plumpy.persistence import auto_persist
 from plumpy.process_states import Wait, Continue
@@ -121,25 +121,59 @@ def on_run(self):
         super().on_run()
         self.node.set_stepper_state_info(str(self._stepper))
 
+    def _resolve_nested_context(self, key: str) -> Tuple[AttributeDict, str]:
+        """
+        Returns a reference to a sub-dictionary of the context and the last key,
+        after resolving a potentially segmented key where required sub-dictionaries are created as needed.
+
+        :param key: A key into the context, where words before a dot are interpreted as a key for a sub-dictionary
+        """
+        ctx = self.ctx
+        ctx_path = key.split('.')
+
+        for index, path in enumerate(ctx_path[:-1]):
+            try:
+                ctx = ctx[path]
+            except KeyError:  # see below why this is the only exception we have to catch here
+                ctx[path] = AttributeDict()  # create the sub-dict and update the context
+                ctx = ctx[path]
+                continue
+
+            # Notes:
+            # * the first ctx (self.ctx) is guaranteed to be an AttributeDict, hence the post-"dereference" checking
+            # * the values can be many different things: on insertion they are either AttributeDict, List or Awaitables
+            #   (subclasses of AttributeDict) but after resolution of an Awaitable this will be the value itself
+            # * assumption: a resolved value is never a plain AttributeDict, on the other hand if a resolved Awaitable
+            #   would be an AttributeDict we can append things to it since the order of tasks is maintained.
+            if type(ctx) != AttributeDict:  # pylint: disable=C0123
+                raise ValueError(
+                    f'Can not update the context for key `{key}`:'
+                    f' found instance of `{type(ctx)}` at `{".".join(ctx_path[:index+1])}`, expected AttributeDict'
+                )
+
+        return ctx, ctx_path[-1]
+
     def insert_awaitable(self, awaitable: Awaitable) -> None:
         """Insert an awaitable that should be terminated before before continuing to the next step.
 
         :param awaitable: the thing to await
-        :type awaitable: :class:`aiida.engine.processes.workchains.awaitable.Awaitable`
         """
-        self._awaitables.append(awaitable)
+        ctx, key = self._resolve_nested_context(awaitable.key)
 
         # Already assign the awaitable itself to the location in the context container where it is supposed to end up
         # once it is resolved. This is especially important for the `APPEND` action, since it needs to maintain the
         # order, but the awaitables will not necessarily be resolved in the order in which they are added. By using the
         # awaitable as a placeholder, in the `resolve_awaitable`, it can be found and replaced by the resolved value.
         if awaitable.action == AwaitableAction.ASSIGN:
-            self.ctx[awaitable.key] = awaitable
+            ctx[key] = awaitable
         elif awaitable.action == AwaitableAction.APPEND:
-            self.ctx.setdefault(awaitable.key, []).append(awaitable)
+            ctx.setdefault(key, []).append(awaitable)
         else:
-            assert f'Unknown awaitable action: {awaitable.action}'
+            raise AssertionError(f'Unsupported awaitable action: {awaitable.action}')
 
+        self._awaitables.append(
+            awaitable
+        )  # add only if everything went ok, otherwise we end up in an inconsistent state
         self._update_process_status()
 
     def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None:
@@ -149,23 +183,25 @@ def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None:
 
         :param awaitable: the awaitable to resolve
         """
-        self._awaitables.remove(awaitable)
+
+        ctx, key = self._resolve_nested_context(awaitable.key)
 
         if awaitable.action == AwaitableAction.ASSIGN:
-            self.ctx[awaitable.key] = value
+            ctx[key] = value
         elif awaitable.action == AwaitableAction.APPEND:
             # Find the same awaitable inserted in the context
-            container = self.ctx[awaitable.key]
+            container = ctx[key]
             for index, placeholder in enumerate(container):
-                if placeholder.pk == awaitable.pk and isinstance(placeholder, Awaitable):
+                if isinstance(placeholder, Awaitable) and placeholder.pk == awaitable.pk:
                     container[index] = value
                     break
             else:
-                assert f'Awaitable `{awaitable.pk} was not found in `ctx.{awaitable.pk}`'
+                raise AssertionError(f'Awaitable `{awaitable.pk}` was not found in `ctx.{awaitable.key}`')
         else:
-            assert f'Unknown awaitable action: {awaitable.action}'
+            raise AssertionError(f'Unsupported awaitable action: {awaitable.action}')
 
         awaitable.resolved = True
+        self._awaitables.remove(awaitable)  # remove only if everything went ok, otherwise we may lose track
 
         if not self.has_terminated():
             # the process may be terminated, for example, if the process was killed or excepted
diff --git a/docs/source/topics/workflows/include/snippets/workchains/run_workchain_submit_parallel_nested.py b/docs/source/topics/workflows/include/snippets/workchains/run_workchain_submit_parallel_nested.py
new file mode 100644
index 0000000000..e71af06912
--- /dev/null
+++ b/docs/source/topics/workflows/include/snippets/workchains/run_workchain_submit_parallel_nested.py
@@ -0,0 +1,23 @@
+# -*- coding: utf-8 -*-
+from aiida.engine import WorkChain
+
+
+class SomeWorkChain(WorkChain):
+
+    @classmethod
+    def define(cls, spec):
+        super().define(spec)
+        spec.outline(
+            cls.submit_workchains,
+            cls.inspect_workchains,
+        )
+
+    def submit_workchains(self):
+        for i in range(3):
+            future = self.submit(SomeWorkChain)
+            key = f'workchain.sub{i}'
+            self.to_context(**{key: future})
+
+    def inspect_workchains(self):
+        for i in range(3):
+            assert self.ctx.workchain[f'sub{i}'].is_finished_ok
diff --git a/docs/source/topics/workflows/usage.rst b/docs/source/topics/workflows/usage.rst
index 8f7160fe90..abcd62d959 100644
--- a/docs/source/topics/workflows/usage.rst
+++ b/docs/source/topics/workflows/usage.rst
@@ -417,6 +417,17 @@ The ``self.ctx.workchains`` now contains a list with the nodes of the completed
 Note that the use of ``append_`` is not just limited to the ``to_context`` method.
 You can also use it in exactly the same way with ``ToContext`` to append a process to a list in the context in multiple outline steps.
 
+Nested context keys
+^^^^^^^^^^^^^^^^^^^
+
+To simplify the organization of the context, the keys may contain dots ``.``, transparently creating namespaces in the process.
+As an example, compare the following to the parallel submission example above:
+
+.. include:: include/snippets/workchains/run_workchain_submit_parallel_nested.py
+   :code: python
+
+This allows you to create intuitively grouped and easily accessible structures of child calculations or workchains.
+
 .. _topics:workflows:usage:workchains:reporting:
 
 Reporting
diff --git a/tests/engine/test_work_chain.py b/tests/engine/test_work_chain.py
index b08e9cf3b8..79946ad679 100644
--- a/tests/engine/test_work_chain.py
+++ b/tests/engine/test_work_chain.py
@@ -20,7 +20,7 @@
 from aiida.common import exceptions
 from aiida.common.links import LinkType
 from aiida.common.utils import Capturing
-from aiida.engine import ExitCode, Process, ToContext, WorkChain, if_, while_, return_, launch, calcfunction
+from aiida.engine import ExitCode, Process, ToContext, WorkChain, if_, while_, return_, launch, calcfunction, append_
 from aiida.engine.persistence import ObjectLoader
 from aiida.manage.manager import get_manager
 from aiida.orm import load_node, Bool, Float, Int, Str
@@ -780,6 +780,182 @@ def result(self):
 
         run_and_check_success(Workchain)
 
+    def test_nested_to_context(self):
+        val = Int(5).store()
+
+        test_case = self
+
+        class SimpleWc(WorkChain):
+
+            @classmethod
+            def define(cls, spec):
+                super().define(spec)
+                spec.outline(cls.result)
+                spec.outputs.dynamic = True
+
+            def result(self):
+                self.out('result', val)
+
+        class Workchain(WorkChain):
+
+            @classmethod
+            def define(cls, spec):
+                super().define(spec)
+                spec.outline(cls.begin, cls.result)
+
+            def begin(self):
+                self.to_context(**{'sub1.sub2.result_a': self.submit(SimpleWc)})
+                return ToContext(**{'sub1.sub2.result_b': self.submit(SimpleWc)})
+
+            def result(self):
+                test_case.assertEqual(self.ctx.sub1.sub2.result_a.outputs.result, val)
+                test_case.assertEqual(self.ctx.sub1.sub2.result_b.outputs.result, val)
+
+        run_and_check_success(Workchain)
+
+    def test_nested_to_context_with_append(self):
+        val1 = Int(5).store()
+        val2 = Int(6).store()
+
+        test_case = self
+
+        class SimpleWc1(WorkChain):
+
+            @classmethod
+            def define(cls, spec):
+                super().define(spec)
+                spec.outline(cls.result)
+                spec.outputs.dynamic = True
+
+            def result(self):
+                self.out('result', val1)
+
+        class SimpleWc2(WorkChain):
+
+            @classmethod
+            def define(cls, spec):
+                super().define(spec)
+                spec.outline(cls.result)
+                spec.outputs.dynamic = True
+
+            def result(self):
+                self.out('result', val2)
+
+        class Workchain(WorkChain):
+
+            @classmethod
+            def define(cls, spec):
+                super().define(spec)
+                spec.outline(cls.begin, cls.result)
+
+            def begin(self):
+                self.to_context(**{'sub1.workchains': append_(self.submit(SimpleWc1))})
+                return ToContext(**{'sub1.workchains': append_(self.submit(SimpleWc2))})
+
+            def result(self):
+                test_case.assertEqual(self.ctx.sub1.workchains[0].outputs.result, val1)
+                test_case.assertEqual(self.ctx.sub1.workchains[1].outputs.result, val2)
+
+        run_and_check_success(Workchain)
+
+    def test_nested_to_context_no_overlap(self):
+        val = Int(5).store()
+
+        class SimpleWc(WorkChain):
+
+            @classmethod
+            def define(cls, spec):
+                super().define(spec)
+                spec.outline(cls.result)
+                spec.outputs.dynamic = True
+
+            def result(self):
+                self.out('result', val)
+
+        class Workchain(WorkChain):
+
+            @classmethod
+            def define(cls, spec):
+                super().define(spec)
+                spec.outline(cls.begin, cls.result)
+
+            def begin(self):
+                self.to_context(**{'result_a': self.submit(SimpleWc)})
+                return ToContext(**{'result_a.sub1': self.submit(SimpleWc)})
+
+            def result(self):
+                raise RuntimeError('Never reached: the second to_context above should fail')
+
+        process = Workchain()
+        with pytest.raises(ValueError):
+            launch.run(process)
+
+    def test_nested_to_context_no_overlap_with_append(self):
+        val = Int(5).store()
+
+        class SimpleWc(WorkChain):
+
+            @classmethod
+            def define(cls, spec):
+                super().define(spec)
+                spec.outline(cls.result)
+                spec.outputs.dynamic = True
+
+            def result(self):
+                self.out('result', val)
+
+        class Workchain(WorkChain):
+
+            @classmethod
+            def define(cls, spec):
+                super().define(spec)
+                spec.outline(cls.begin, cls.result)
+
+            def begin(self):
+                self.to_context(workchains=append_(self.submit(SimpleWc)))  # make the workchains point to a list
+                return ToContext(**{'workchains.sub1.sub2': self.submit(SimpleWc)})  # now try to treat it as a sub-dict
+
+            def result(self):
+                raise RuntimeError('Never reached: the second to_context above should fail')
+
+        process = Workchain()
+        with pytest.raises(ValueError):
+            launch.run(process)
+
+    def test_nested_to_context_no_overlap_with_append2(self):
+        val = Int(5).store()
+
+        class SimpleWc(WorkChain):
+
+            @classmethod
+            def define(cls, spec):
+                super().define(spec)
+                spec.outline(cls.result)
+                spec.outputs.dynamic = True
+
+            def result(self):
+                self.out('result', val)
+
+        class Workchain(WorkChain):
+
+            @classmethod
+            def define(cls, spec):
+                super().define(spec)
+                spec.outline(cls.begin, cls.result)
+
+            def begin(self):
+                self.to_context(workchains=append_(self.submit(SimpleWc)))  # make the workchains point to a list
+                return ToContext(
+                    **{'workchains.sub1': self.submit(SimpleWc)}
+                )  # now try to treat the final path element as a sub-dict
+
+            def result(self):
+                raise RuntimeError('Never reached: the second to_context above should fail')
+
+        process = Workchain()
+        with pytest.raises(ValueError):
+            launch.run(process)
+
     def test_namespace_nondb_mapping(self):
         """
         Regression test for a bug in _flatten_inputs
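For reference, here is a standalone sketch of the dotted-key resolution that the patch introduces in `_resolve_nested_context`. The `AttributeDict` stand-in, the module-level `resolve_nested_context` helper and the `__main__` usage are illustrative assumptions for this sketch only; they are not part of the patch or of aiida's API.

# -*- coding: utf-8 -*-
"""Illustrative sketch only: mimics the dotted-key resolution of ``_resolve_nested_context``."""


class AttributeDict(dict):
    """Simplified stand-in for ``aiida.common.AttributeDict``: attribute access on top of a plain dict."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exception:
            raise AttributeError(name) from exception


def resolve_nested_context(ctx, key):
    """Walk ``ctx`` along the dot-separated ``key``, creating missing sub-dictionaries on the way.

    Returns the innermost container and the final key segment, so the caller can assign or append to it.
    """
    path = key.split('.')

    for index, segment in enumerate(path[:-1]):
        if segment not in ctx:
            # Create the missing namespace and descend into it; a freshly created dict needs no type check.
            ctx[segment] = AttributeDict()
            ctx = ctx[segment]
            continue

        ctx = ctx[segment]

        # The patch uses an exact type check here, because awaitable placeholders subclass AttributeDict
        # and an existing non-namespace value must not be silently treated as a namespace.
        if type(ctx) is not AttributeDict:  # pylint: disable=unidiomatic-typecheck
            raise ValueError(f'cannot update key `{key}`: `{".".join(path[:index + 1])}` is not an AttributeDict')

    return ctx, path[-1]


if __name__ == '__main__':
    context = AttributeDict()
    container, last_key = resolve_nested_context(context, 'sub1.sub2.result_a')
    container[last_key] = 'placeholder'
    assert context.sub1.sub2.result_a == 'placeholder'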