Skip to content

Commit

Permalink
added tests for output_file_template/output_field_name and various un…
Browse files Browse the repository at this point in the history
…covered cases in shell_task decorator
  • Loading branch information
tclose committed May 16, 2023
1 parent 9247daf commit 1b3f72e
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 21 deletions.
27 changes: 6 additions & 21 deletions pydra/mark/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ def convert_to_attrs(fields: dict[str, dict[str, ty.Any]], attrs_func):
)
try:
Inputs = klass.Inputs
except KeyError:
raise AttributeError(
except AttributeError:
raise RuntimeError(
"Classes decorated by `shell_task` should contain an `Inputs` class "
"attribute specifying the inputs to the shell tool"
)
Expand Down Expand Up @@ -198,16 +198,6 @@ def convert_to_attrs(fields: dict[str, dict[str, ty.Any]], attrs_func):
"Classes generated by `shell_task` should contain an `executable` "
"attribute specifying the shell tool to run"
)
if not hasattr(task_klass, "Inputs"):
raise RuntimeError(
"Classes generated by `shell_task` should contain an `Inputs` class "
"attribute specifying the inputs to the shell tool"
)
if not hasattr(task_klass, "Outputs"):
raise RuntimeError(
"Classes generated by `shell_task` should contain an `Outputs` class "
"attribute specifying the outputs to the shell tool"
)

task_klass.input_spec = pydra.engine.specs.SpecInfo(
name=f"{name}Inputs", fields=[], bases=(task_klass.Inputs,)
Expand Down Expand Up @@ -381,25 +371,20 @@ def shell_out(
)


def _gen_output_template_fields(Inputs: type, Outputs: type) -> tuple[dict, dict]:
def _gen_output_template_fields(Inputs: type, Outputs: type) -> dict:
"""Auto-generates output fields for inputs that specify an 'output_file_template'
Parameters
----------
Inputs : type
Input specification class
Inputs specification class
Outputs : type
Output specification class
Outputs specification class
Returns
-------
template_fields: dict[str, attrs._CountingAttribute]
template_fields: dict[str, attrs._make_CountingAttribute]
the template fields to add to the output spec
Raises
------
RuntimeError
_description_
"""
annotations = {}
template_fields = {"__annotations__": annotations}
Expand Down
108 changes: 108 additions & 0 deletions pydra/mark/tests/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,53 @@ class Inputs:
return A


def test_shell_output_file_template(A):
assert "y" in [a.name for a in attrs.fields(A.Outputs)]


def test_shell_output_field_name_static():
@shell_task
class A:
executable = "cp"

class Inputs:
x: os.PathLike = shell_arg(
help_string="an input file", argstr="", position=0
)
y: str = shell_arg(
help_string="path of output file",
output_file_template="{x}_out",
output_field_name="y_out",
argstr="",
)

assert "y_out" in [a.name for a in attrs.fields(A.Outputs)]


def test_shell_output_field_name_dynamic():
A = shell_task(
"A",
executable="cp",
input_fields={
"x": {
"type": os.PathLike,
"help_string": "an input file",
"argstr": "",
"position": 0,
},
"y": {
"type": str,
"help_string": "path of output file",
"argstr": "",
"output_field_name": "y_out",
"output_file_template": "{x}_out",
},
},
)

assert "y_out" in [a.name for a in attrs.fields(A.Outputs)]


def get_file_size(y: Path):
result = os.stat(y)
return result.st_size
Expand Down Expand Up @@ -357,3 +404,64 @@ class Inputs(A.Inputs):

result = b()
assert result.output.entries == [".", "..", ".hidden"]


def test_shell_missing_executable_static():
with pytest.raises(RuntimeError, match="should contain an `executable`"):

@shell_task
class A:
class Inputs:
directory: os.PathLike = shell_arg(
help_string="input directory", argstr="", position=-1
)

class Outputs:
entries: list = shell_out(
help_string="list of entries returned by ls command",
callable=list_entries,
)


def test_shell_missing_executable_dynamic():
with pytest.raises(RuntimeError, match="should contain an `executable`"):
A = shell_task(
"A",
executable=None,
input_fields={
"directory": {
"type": os.PathLike,
"help_string": "input directory",
"argstr": "",
"position": -1,
}
},
output_fields={
"entries": {
"type": list,
"help_string": "list of entries returned by ls command",
"callable": list_entries,
}
},
)


def test_shell_missing_inputs_static():
with pytest.raises(RuntimeError, match="should contain an `Inputs`"):

@shell_task
class A:
executable = "ls"

class Outputs:
entries: list = shell_out(
help_string="list of entries returned by ls command",
callable=list_entries,
)


def test_shell_decorator_misuse(A):
with pytest.raises(
RuntimeError, match=("`shell_task` should not be provided any other arguments")
):
shell_task(A, executable="cp")

0 comments on commit 1b3f72e

Please sign in to comment.