Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

function node can return the source of function rather than only the source of file #4554

Merged
merged 1 commit into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ repos:
aiida/orm/nodes/process/.*py|
aiida/orm/nodes/repository.py|
aiida/orm/utils/links.py|
aiida/orm/utils/mixins.py|
aiida/plugins/entry_point.py|
aiida/plugins/factories.py|
aiida/plugins/utils.py|
Expand Down
110 changes: 83 additions & 27 deletions aiida/orm/utils/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from __future__ import annotations

import inspect
from typing import List, Optional

from aiida.common import exceptions
from aiida.common.lang import classproperty, override
from aiida.common.lang import classproperty, override, type_check
from aiida.common.warnings import warn_deprecation


class FunctionCalculationMixin:
Expand All @@ -30,9 +30,10 @@ class FunctionCalculationMixin:
FUNCTION_NAME_KEY = 'function_name'
FUNCTION_NAMESPACE_KEY = 'function_namespace'
FUNCTION_STARTING_LINE_KEY = 'function_starting_line_number'
FUNCTION_NUMBER_OF_LINES_KEY = 'function_number_of_lines'
FUNCTION_SOURCE_FILE_PATH = 'source_file'

def store_source_info(self, func):
def store_source_info(self, func) -> None:
"""
Retrieve source information about the wrapped function `func` through the inspect module,
and store it in the attributes and repository of the node. The function name, namespace
Expand All @@ -44,10 +45,12 @@ def store_source_info(self, func):
self._set_function_name(func.__name__)

try:
_, starting_line_number = inspect.getsourcelines(func)
self._set_function_starting_line_number(starting_line_number)
source_list, starting_line_number = inspect.getsourcelines(func)
except (IOError, OSError):
pass
else:
self._set_function_starting_line_number(starting_line_number)
self._set_function_number_of_lines(len(source_list))

try:
self._set_function_namespace(func.__globals__['__name__'])
Expand All @@ -56,55 +59,78 @@ def store_source_info(self, func):

try:
source_file_path = inspect.getsourcefile(func)
with open(source_file_path, 'rb') as handle:
self.base.repository.put_object_from_filelike(handle, self.FUNCTION_SOURCE_FILE_PATH)
if source_file_path:
with open(source_file_path, 'rb') as handle:
self.base.repository.put_object_from_filelike( # type: ignore[attr-defined]
handle, self.FUNCTION_SOURCE_FILE_PATH
)
except (IOError, OSError):
pass

@property
def function_name(self):
def function_name(self) -> str | None:
"""Return the function name of the wrapped function.

:returns: the function name or None
"""
return self.base.attributes.get(self.FUNCTION_NAME_KEY, None)
return self.base.attributes.get(self.FUNCTION_NAME_KEY, None) # type: ignore[attr-defined]

def _set_function_name(self, function_name):
def _set_function_name(self, function_name: str):
"""Set the function name of the wrapped function.

:param function_name: the function name
"""
self.base.attributes.set(self.FUNCTION_NAME_KEY, function_name)
self.base.attributes.set(self.FUNCTION_NAME_KEY, function_name) # type: ignore[attr-defined]

@property
def function_namespace(self):
def function_namespace(self) -> str | None:
"""Return the function namespace of the wrapped function.

:returns: the function namespace or None
"""
return self.base.attributes.get(self.FUNCTION_NAMESPACE_KEY, None)
return self.base.attributes.get(self.FUNCTION_NAMESPACE_KEY, None) # type: ignore[attr-defined]

def _set_function_namespace(self, function_namespace):
def _set_function_namespace(self, function_namespace: str) -> None:
"""Set the function namespace of the wrapped function.

:param function_namespace: the function namespace
"""
self.base.attributes.set(self.FUNCTION_NAMESPACE_KEY, function_namespace)
self.base.attributes.set(self.FUNCTION_NAMESPACE_KEY, function_namespace) # type: ignore[attr-defined]

@property
def function_starting_line_number(self):
def function_starting_line_number(self) -> int | None:
"""Return the starting line number of the wrapped function in its source file.

:returns: the starting line number or None
"""
return self.base.attributes.get(self.FUNCTION_STARTING_LINE_KEY, None)
return self.base.attributes.get(self.FUNCTION_STARTING_LINE_KEY, None) # type: ignore[attr-defined]

def _set_function_starting_line_number(self, function_starting_line_number):
def _set_function_starting_line_number(self, function_starting_line_number: int) -> None:
"""Set the starting line number of the wrapped function in its source file.

:param function_starting_line_number: the starting line number
"""
self.base.attributes.set(self.FUNCTION_STARTING_LINE_KEY, function_starting_line_number)
self.base.attributes.set( # type: ignore[attr-defined]
self.FUNCTION_STARTING_LINE_KEY, function_starting_line_number
)

@property
def function_number_of_lines(self) -> int | None:
"""Return the number of lines of the wrapped function in its source file.

:returns: the number of lines or None
"""
return self.base.attributes.get(self.FUNCTION_NUMBER_OF_LINES_KEY, None) # type: ignore[attr-defined]

def _set_function_number_of_lines(self, function_number_of_lines: int) -> None:
"""Set the number of lines of the wrapped function in its source file.

:param function_number_of_lines: the number of lines
"""
unkcpz marked this conversation as resolved.
Show resolved Hide resolved
sphuber marked this conversation as resolved.
Show resolved Hide resolved
type_check(function_number_of_lines, int)
self.base.attributes.set( # type: ignore[attr-defined]
self.FUNCTION_NUMBER_OF_LINES_KEY, function_number_of_lines
)

def get_function_source_code(self) -> str | None:
"""Return the source code of the function stored in the repository.
Expand All @@ -113,13 +139,43 @@ def get_function_source_code(self) -> str | None:
function was defined in an interactive shell in which case ``store_source_info`` will have failed to retrieve
the source code using ``inspect.getsourcefile``.

:returns: The source code of the function or ``None`` if it could not be determined when storing the node.
"""
warn_deprecation('This method will be removed, use `get_source_code_file` instead.', version=3)

return self.get_source_code_file()

def get_source_code_file(self) -> str | None:
"""Return the source code of the file in which the process function was defined.

If the source code file does not exist, this will return ``None`` instead. This can happen for example when the
function was defined in an interactive shell in which case ``store_source_info`` will have failed to retrieve
the source code using ``inspect.getsourcefile``.

:returns: The source code of the function or ``None`` if it could not be determined when storing the node.
"""
try:
return self.base.repository.get_object_content(self.FUNCTION_SOURCE_FILE_PATH)
return self.base.repository.get_object_content(self.FUNCTION_SOURCE_FILE_PATH) # type: ignore[attr-defined]
except FileNotFoundError:
return None

def get_source_code_function(self) -> str | None:
"""Return the source code of the function including the decorator.

:returns: The source code of the function or ``None`` if not available.
"""
source_code = self.get_source_code_file()

if source_code is None or self.function_number_of_lines is None or self.function_starting_line_number is None:
return None

content_list = source_code.splitlines()
start_line = self.function_starting_line_number
end_line = start_line + self.function_number_of_lines

# Start at ``start_line - 1`` to include the decorator
return '\n'.join(content_list[start_line - 1:end_line])


class Sealable:
"""Mixin to mark a Node as `sealable`."""
Expand All @@ -128,21 +184,21 @@ class Sealable:
SEALED_KEY = 'sealed'

@classproperty
def _updatable_attributes(cls): # pylint: disable=no-self-argument
def _updatable_attributes(cls) -> tuple[str]: # pylint: disable=no-self-argument
return (cls.SEALED_KEY,)

@property
def is_sealed(self):
def is_sealed(self) -> bool:
"""Returns whether the node is sealed, i.e. whether the sealed attribute has been set to True."""
return self.base.attributes.get(self.SEALED_KEY, False)
return self.base.attributes.get(self.SEALED_KEY, False) # type: ignore[attr-defined]

def seal(self):
def seal(self) -> None:
"""Seal the node by setting the sealed attribute to True."""
if not self.is_sealed:
self.base.attributes.set(self.SEALED_KEY, True)
self.base.attributes.set(self.SEALED_KEY, True) # type: ignore[attr-defined]

@override
def _check_mutability_attributes(self, keys: Optional[List[str]] = None) -> None: # pylint: disable=unused-argument
def _check_mutability_attributes(self, keys: list[str] | None = None) -> None: # pylint: disable=unused-argument
"""Check if the entity is mutable and raise an exception if not.

This is called from `NodeAttributes` methods that modify the attributes.
Expand All @@ -152,7 +208,7 @@ def _check_mutability_attributes(self, keys: Optional[List[str]] = None) -> None
if self.is_sealed:
raise exceptions.ModificationNotAllowed('attributes of a sealed node are immutable')

if self.is_stored:
if self.is_stored: # type: ignore[attr-defined]
# here we are more lenient than the base class, since we allow the modification of some attributes
if keys is None:
raise exceptions.ModificationNotAllowed('Cannot bulk modify attributes of a stored+unsealed node')
Expand Down
14 changes: 11 additions & 3 deletions tests/cmdline/commands/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_run_workfunction(self):
"""
from aiida.orm import WorkFunctionNode, load_node

script_content = textwrap.dedent(
source_file = textwrap.dedent(
"""\
#!/usr/bin/env python
from aiida.engine import workfunction
Expand All @@ -50,12 +50,19 @@ def wf():
print(node.pk)
"""
)
source_function = textwrap.dedent(
"""\
@workfunction
def wf():
pass
"""
)

# If `verdi run` is not setup correctly, the script above when run with `verdi run` will fail, because when
# the engine will try to create the node for the workfunction and create a copy of its sourcefile, namely the
# script itself, it will use `inspect.getsourcefile` which will return None
with tempfile.NamedTemporaryFile(mode='w+') as fhandle:
fhandle.write(script_content)
fhandle.write(source_file)
fhandle.flush()

options = [fhandle.name]
Expand All @@ -68,7 +75,8 @@ def wf():
# Verify that the node has the correct function name and content
assert isinstance(node, WorkFunctionNode)
assert node.function_name == 'wf'
assert node.get_function_source_code() == script_content
assert node.get_source_code_file() == source_file
assert node.get_source_code_function() == source_function


class TestAutoGroups:
Expand Down
5 changes: 5 additions & 0 deletions tests/engine/test_process_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ def test_exit_status():
@pytest.mark.usefixtures('aiida_profile_clean')
def test_source_code_attributes():
"""Verify function properties are properly introspected and stored in the nodes attributes and repository."""
import inspect

function_name = 'test_process_function'

@calcfunction
Expand All @@ -179,6 +181,9 @@ def test_process_function(data):
assert node.function_name == function_name
assert isinstance(node.function_starting_line_number, int)

# Check the source code of the function is stored
assert node.get_source_code_function() == inspect.getsource(test_process_function)

# Check that first line number is correct. Note that the first line should correspond
# to the `@workfunction` directive, but since the list is zero-indexed we actually get the
# following line, which should correspond to the function name i.e. `def test_process_function(data)`
Expand Down