Skip to content

Commit

Permalink
Merge branch 'develop' into fix_1525_hanging_ssh
Browse files Browse the repository at this point in the history
  • Loading branch information
ramirezfranciscof authored Jun 4, 2021
2 parents c8cbb6e + 124fece commit 406a3bc
Show file tree
Hide file tree
Showing 4 changed files with 259 additions and 13 deletions.
60 changes: 48 additions & 12 deletions aiida/engine/processes/workchains/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import collections.abc
import functools
import logging
from typing import Any, List, Optional, Sequence, Union, TYPE_CHECKING
from typing import Any, List, Optional, Sequence, Union, Tuple, TYPE_CHECKING

from plumpy.persistence import auto_persist
from plumpy.process_states import Wait, Continue
Expand Down Expand Up @@ -121,25 +121,59 @@ def on_run(self):
super().on_run()
self.node.set_stepper_state_info(str(self._stepper))

def _resolve_nested_context(self, key: str) -> Tuple[AttributeDict, str]:
"""
Returns a reference to a sub-dictionary of the context and the last key,
after resolving a potentially segmented key where required sub-dictionaries are created as needed.
:param key: A key into the context, where words before a dot are interpreted as a key for a sub-dictionary
"""
ctx = self.ctx
ctx_path = key.split('.')

for index, path in enumerate(ctx_path[:-1]):
try:
ctx = ctx[path]
except KeyError: # see below why this is the only exception we have to catch here
ctx[path] = AttributeDict() # create the sub-dict and update the context
ctx = ctx[path]
continue

# Notes:
# * the first ctx (self.ctx) is guaranteed to be an AttributeDict, hence the post-"dereference" checking
# * the values can be many different things: on insertion they are either AtrributeDict, List or Awaitables
# (subclasses of AttributeDict) but after resolution of an Awaitable this will be the value itself
# * assumption: a resolved value is never a plain AttributeDict, on the other hand if a resolved Awaitable
# would be an AttributeDict we can append things to it since the order of tasks is maintained.
if type(ctx) != AttributeDict: # pylint: disable=C0123
raise ValueError(
f'Can not update the context for key `{key}`:'
f' found instance of `{type(ctx)}` at `{".".join(ctx_path[:index+1])}`, expected AttributeDict'
)

return ctx, ctx_path[-1]

def insert_awaitable(self, awaitable: Awaitable) -> None:
"""Insert an awaitable that should be terminated before before continuing to the next step.
:param awaitable: the thing to await
:type awaitable: :class:`aiida.engine.processes.workchains.awaitable.Awaitable`
"""
self._awaitables.append(awaitable)
ctx, key = self._resolve_nested_context(awaitable.key)

# Already assign the awaitable itself to the location in the context container where it is supposed to end up
# once it is resolved. This is especially important for the `APPEND` action, since it needs to maintain the
# order, but the awaitables will not necessarily be resolved in the order in which they are added. By using the
# awaitable as a placeholder, in the `resolve_awaitable`, it can be found and replaced by the resolved value.
if awaitable.action == AwaitableAction.ASSIGN:
self.ctx[awaitable.key] = awaitable
ctx[key] = awaitable
elif awaitable.action == AwaitableAction.APPEND:
self.ctx.setdefault(awaitable.key, []).append(awaitable)
ctx.setdefault(key, []).append(awaitable)
else:
assert f'Unknown awaitable action: {awaitable.action}'
raise AssertionError(f'Unsupported awaitable action: {awaitable.action}')

self._awaitables.append(
awaitable
) # add only if everything went ok, otherwise we end up in an inconsistent state
self._update_process_status()

def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None:
Expand All @@ -149,23 +183,25 @@ def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None:
:param awaitable: the awaitable to resolve
"""
self._awaitables.remove(awaitable)

ctx, key = self._resolve_nested_context(awaitable.key)

if awaitable.action == AwaitableAction.ASSIGN:
self.ctx[awaitable.key] = value
ctx[key] = value
elif awaitable.action == AwaitableAction.APPEND:
# Find the same awaitable inserted in the context
container = self.ctx[awaitable.key]
container = ctx[key]
for index, placeholder in enumerate(container):
if placeholder.pk == awaitable.pk and isinstance(placeholder, Awaitable):
if isinstance(placeholder, Awaitable) and placeholder.pk == awaitable.pk:
container[index] = value
break
else:
assert f'Awaitable `{awaitable.pk} was not found in `ctx.{awaitable.pk}`'
raise AssertionError(f'Awaitable `{awaitable.pk} was not found in `ctx.{awaitable.key}`')
else:
assert f'Unknown awaitable action: {awaitable.action}'
raise AssertionError(f'Unsupported awaitable action: {awaitable.action}')

awaitable.resolved = True
self._awaitables.remove(awaitable) # remove only if everything went ok, otherwise we may lose track

if not self.has_terminated():
# the process may be terminated, for example, if the process was killed or excepted
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
from aiida.engine import WorkChain


class SomeWorkChain(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(
cls.submit_workchains,
cls.inspect_workchains,
)

def submit_workchains(self):
for i in range(3):
future = self.submit(SomeWorkChain)
key = f'workchain.sub{i}'
self.to_context(**{key: future})

def inspect_workchains(self):
for i in range(3):
assert self.ctx.workchain[f'sub{i}'].is_finished_ok
11 changes: 11 additions & 0 deletions docs/source/topics/workflows/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,17 @@ The ``self.ctx.workchains`` now contains a list with the nodes of the completed
Note that the use of ``append_`` is not just limited to the ``to_context`` method.
You can also use it in exactly the same way with ``ToContext`` to append a process to a list in the context in multiple outline steps.

Nested context keys
^^^^^^^^^^^^^^^^^^^

To simplify the organization of the context, the keys may contain dots ``.``, transparently creating namespaces in the process.
As an example compare the following to the parallel submission example above:

.. include:: include/snippets/workchains/run_workchain_submit_append.py
:code: python

This allows to create intuitively grouped and easily accessible structures of child calculations or workchains.

.. _topics:workflows:usage:workchains:reporting:

Reporting
Expand Down
178 changes: 177 additions & 1 deletion tests/engine/test_work_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from aiida.common import exceptions
from aiida.common.links import LinkType
from aiida.common.utils import Capturing
from aiida.engine import ExitCode, Process, ToContext, WorkChain, if_, while_, return_, launch, calcfunction
from aiida.engine import ExitCode, Process, ToContext, WorkChain, if_, while_, return_, launch, calcfunction, append_
from aiida.engine.persistence import ObjectLoader
from aiida.manage.manager import get_manager
from aiida.orm import load_node, Bool, Float, Int, Str
Expand Down Expand Up @@ -780,6 +780,182 @@ def result(self):

run_and_check_success(Workchain)

def test_nested_to_context(self):
val = Int(5).store()

test_case = self

class SimpleWc(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.result)
spec.outputs.dynamic = True

def result(self):
self.out('result', val)

class Workchain(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.begin, cls.result)

def begin(self):
self.to_context(**{'sub1.sub2.result_a': self.submit(SimpleWc)})
return ToContext(**{'sub1.sub2.result_b': self.submit(SimpleWc)})

def result(self):
test_case.assertEqual(self.ctx.sub1.sub2.result_a.outputs.result, val)
test_case.assertEqual(self.ctx.sub1.sub2.result_b.outputs.result, val)

run_and_check_success(Workchain)

def test_nested_to_context_with_append(self):
val1 = Int(5).store()
val2 = Int(6).store()

test_case = self

class SimpleWc1(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.result)
spec.outputs.dynamic = True

def result(self):
self.out('result', val1)

class SimpleWc2(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.result)
spec.outputs.dynamic = True

def result(self):
self.out('result', val2)

class Workchain(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.begin, cls.result)

def begin(self):
self.to_context(**{'sub1.workchains': append_(self.submit(SimpleWc1))})
return ToContext(**{'sub1.workchains': append_(self.submit(SimpleWc2))})

def result(self):
test_case.assertEqual(self.ctx.sub1.workchains[0].outputs.result, val1)
test_case.assertEqual(self.ctx.sub1.workchains[1].outputs.result, val2)

run_and_check_success(Workchain)

def test_nested_to_context_no_overlap(self):
val = Int(5).store()

class SimpleWc(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.result)
spec.outputs.dynamic = True

def result(self):
self.out('result', val)

class Workchain(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.begin, cls.result)

def begin(self):
self.to_context(**{'result_a': self.submit(SimpleWc)})
return ToContext(**{'result_a.sub1': self.submit(SimpleWc)})

def result(self):
raise RuntimeError('Never reached: the second to_context above should fail')

process = Workchain()
with pytest.raises(ValueError):
launch.run(process)

def test_nested_to_context_no_overlap_with_append(self):
val = Int(5).store()

class SimpleWc(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.result)
spec.outputs.dynamic = True

def result(self):
self.out('result', val)

class Workchain(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.begin, cls.result)

def begin(self):
self.to_context(workchains=append_(self.submit(SimpleWc))) # make the workchains point to a list
return ToContext(**{'workchains.sub1.sub2': self.submit(SimpleWc)}) # now try to treat it as a sub-dict

def result(self):
raise RuntimeError('Never reached: the second to_context above should fail')

process = Workchain()
with pytest.raises(ValueError):
launch.run(process)

def test_nested_to_context_no_overlap_with_append2(self):
val = Int(5).store()

class SimpleWc(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.result)
spec.outputs.dynamic = True

def result(self):
self.out('result', val)

class Workchain(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.begin, cls.result)

def begin(self):
self.to_context(workchains=append_(self.submit(SimpleWc))) # make the workchains point to a list
return ToContext(
**{'workchains.sub1': self.submit(SimpleWc)}
) # now try to treat the final path element it as a sub-dict

def result(self):
raise RuntimeError('Never reached: the second to_context above should fail')

process = Workchain()
with pytest.raises(ValueError):
launch.run(process)

def test_namespace_nondb_mapping(self):
"""
Regression test for a bug in _flatten_inputs
Expand Down

0 comments on commit 406a3bc

Please sign in to comment.