Skip to content

Commit

Permalink
Process: Have inputs property always return AttributesFrozenDict (
Browse files Browse the repository at this point in the history
#6010)

The `Process.inputs` property as implemented in `plumpy` has as a return
type `AttributesFrozenDict | None`. This leads to a lot unnecessary
complexities in the code having to deal with the potential `None`, where
really this should never really occur. A lot of user code will never
even check for `Process.inputs` returning `None`, such as in `WorkChain`
implementations, and as a result type checkers will fail forcing a user
to either unnecessarily complicate their code by explicitly checking for
`None`, but will typically end up silencing the error.

The `inputs` property is overridden here to return an empty
`AttributesFrozenDict` in case the inputs are `None`, which allows to
simplify the return type and get rid of any type errors in downstream
code.
  • Loading branch information
sphuber authored May 15, 2023
1 parent 7ad9168 commit 60756fe
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 11 deletions.
12 changes: 5 additions & 7 deletions aiida/engine/processes/calcjobs/calcjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,11 +539,11 @@ def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wa
`Wait` command if the calcjob is to be uploaded
"""
if self.inputs.metadata.dry_run: # type: ignore[union-attr]
if self.inputs.metadata.dry_run:
self._perform_dry_run()
return plumpy.process_states.Stop(None, True)

if 'remote_folder' in self.inputs: # type: ignore[operator]
if 'remote_folder' in self.inputs:
exit_code = self._perform_import()
return exit_code

Expand Down Expand Up @@ -596,7 +596,7 @@ def _setup_inputs(self) -> None:
# will have an associated computer, but in that case the ``computer`` property should return ``None`` and
# nothing would change anyway.
if not self.node.computer:
self.node.computer = self.inputs.code.computer # type: ignore[union-attr]
self.node.computer = self.inputs.code.computer

def _perform_dry_run(self):
"""Perform a dry run.
Expand Down Expand Up @@ -640,9 +640,7 @@ def _perform_import(self):
with SandboxFolder(filepath_sandbox) as folder:
with SandboxFolder(filepath_sandbox) as retrieved_temporary_folder:
self.presubmit(folder)
self.node.set_remote_workdir(
self.inputs.remote_folder.get_remote_path() # type: ignore[union-attr]
)
self.node.set_remote_workdir(self.inputs.remote_folder.get_remote_path())
retrieve_calculation(self.node, transport, retrieved_temporary_folder.abspath)
self.node.set_state(CalcJobState.PARSING)
self.node.base.attributes.set(orm.CalcJobNode.IMMIGRATED_KEY, True)
Expand Down Expand Up @@ -821,7 +819,7 @@ def presubmit(self, folder: Folder) -> CalcInfo:

inputs = self.node.base.links.get_incoming(link_type=LinkType.INPUT_CALC)

if not self.inputs.metadata.dry_run and not self.node.is_stored: # type: ignore[union-attr]
if not self.inputs.metadata.dry_run and not self.node.is_stored:
raise InvalidOperation('calculation node is not stored.')

computer = self.node.computer
Expand Down
12 changes: 11 additions & 1 deletion aiida/engine/processes/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import plumpy.persistence
from plumpy.process_states import Finished, ProcessState
import plumpy.processes
from plumpy.utils import AttributesFrozendict

from aiida import orm
from aiida.common import exceptions
Expand Down Expand Up @@ -233,6 +234,15 @@ def uuid(self) -> str: # type: ignore[override]
"""
return self.node.uuid

@property
def inputs(self) -> AttributesFrozendict:
"""Return the inputs attribute dictionary or an empty one.
This overrides the property of the base class because that can also return ``None``. This override ensures
calling functions that they will always get an instance of ``AttributesFrozenDict``.
"""
return super().inputs or AttributesFrozendict()

@property
def metadata(self) -> AttributeDict:
"""Return the metadata that were specified when this process instance was launched.
Expand Down Expand Up @@ -953,7 +963,7 @@ def exposed_inputs(
else:
inputs = self.inputs
for part in sub_namespace.split('.'):
inputs = inputs[part] # type: ignore[index]
inputs = inputs[part]
try:
port_namespace = self.spec().inputs.get_port(sub_namespace) # type: ignore[assignment]
except KeyError:
Expand Down
6 changes: 3 additions & 3 deletions aiida/engine/processes/workchains/restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def should_run_process(self) -> bool:
This is the case as long as the last process has not finished successfully and the maximum number of restarts
has not yet been exceeded.
"""
max_iterations = self.inputs.max_iterations.value # type: ignore[union-attr]
max_iterations = self.inputs.max_iterations.value
return not self.ctx.is_finished and self.ctx.iteration < max_iterations

def run_process(self) -> ToContext:
Expand Down Expand Up @@ -311,7 +311,7 @@ def results(self) -> Optional['ExitCode']:
# We check the `is_finished` attribute of the work chain and not the successfulness of the last process
# because the error handlers in the last iteration can have qualified a "failed" process as satisfactory
# for the outcome of the work chain and so have marked it as `is_finished=True`.
max_iterations = self.inputs.max_iterations.value # type: ignore[union-attr]
max_iterations = self.inputs.max_iterations.value
if not self.ctx.is_finished and self.ctx.iteration >= max_iterations:
self.report(
f'reached the maximum number of iterations {max_iterations}: '
Expand Down Expand Up @@ -392,7 +392,7 @@ def on_terminated(self):
"""Clean the working directories of all child calculation jobs if `clean_workdir=True` in the inputs."""
super().on_terminated()

if self.inputs.clean_workdir.value is False: # type: ignore[union-attr]
if self.inputs.clean_workdir.value is False:
self.report('remote folders will not be cleaned')
return

Expand Down
1 change: 1 addition & 0 deletions docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ py:class paramiko.proxy.ProxyCommand
# These can be removed once they are properly included in the `__all__` in `plumpy`
py:class plumpy.ports.PortNamespace
py:class plumpy.utils.AttributesDict
py:class plumpy.utils.AttributesFrozendict
py:class plumpy.process_states.State
py:class plumpy.workchains._If
py:class plumpy.workchains._While
Expand Down

0 comments on commit 60756fe

Please sign in to comment.