Skip to content

Commit

Permalink
fixed up inheritance of Inputs and Outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed May 16, 2023
1 parent bd3fefc commit 9247daf
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 21 deletions.
44 changes: 33 additions & 11 deletions pydra/mark/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def shell_task(
dct = {"__annotations__": annotations}

if isinstance(klass_or_name, str):
# Dynamically created classes using shell_task as a function
name = klass_or_name

if executable is not None:
Expand All @@ -77,6 +78,19 @@ def ensure_base_included(base_class: type, bases_list: list[type]):
if not any(issubclass(b, base_class) for b in bases_list):
bases_list.append(base_class)

# Get inputs and outputs bases from base class if not explicitly provided
for base in bases:
if not inputs_bases:
try:
inputs_bases = [base.Inputs]
except AttributeError:
pass
if not outputs_bases:
try:
outputs_bases = [base.Outputs]
except AttributeError:
pass

# Ensure bases are lists and can be modified
ensure_base_included(pydra.engine.task.ShellCommandTask, bases)
ensure_base_included(pydra.engine.specs.ShellSpec, inputs_bases)
Expand Down Expand Up @@ -108,6 +122,7 @@ def convert_to_attrs(fields: dict[str, dict[str, ty.Any]], attrs_func):
)

else:
# Statically defined classes using shell_task as decorator
if (
executable,
input_fields,
Expand Down Expand Up @@ -147,8 +162,12 @@ def convert_to_attrs(fields: dict[str, dict[str, ty.Any]], attrs_func):
except AttributeError:
Outputs = type("Outputs", (pydra.engine.specs.ShellOutSpec,), {})

Inputs = attrs.define(kw_only=True, slots=False)(Inputs)
Outputs = attrs.define(kw_only=True, slots=False)(Outputs)
# Pass Inputs and Outputs in attrs.define if they are present in klass (i.e.
# not in a base class)
if "Inputs" in klass.__dict__:
Inputs = attrs.define(kw_only=True, slots=False)(Inputs)
if "Outputs" in klass.__dict__:
Outputs = attrs.define(kw_only=True, slots=False)(Outputs)

if not issubclass(Inputs, pydra.engine.specs.ShellSpec):
Inputs = attrs.define(kw_only=True, slots=False)(
Expand All @@ -159,12 +178,12 @@ def convert_to_attrs(fields: dict[str, dict[str, ty.Any]], attrs_func):

if not issubclass(Outputs, pydra.engine.specs.ShellOutSpec):
outputs_bases = (Outputs, pydra.engine.specs.ShellOutSpec)
wrap_output = True
add_base_class = True
else:
outputs_bases = (Outputs,)
wrap_output = False
add_base_class = False

if wrap_output or template_fields:
if add_base_class or template_fields:
Outputs = attrs.define(kw_only=True, slots=False)(
type("Outputs", outputs_bases, template_fields)
)
Expand All @@ -173,12 +192,7 @@ def convert_to_attrs(fields: dict[str, dict[str, ty.Any]], attrs_func):
dct["Outputs"] = Outputs

task_klass = type(name, tuple(bases), dct)
task_klass.input_spec = pydra.engine.specs.SpecInfo(
name=f"{name}Inputs", fields=[], bases=(task_klass.Inputs,)
)
task_klass.output_spec = pydra.engine.specs.SpecInfo(
name=f"{name}Outputs", fields=[], bases=(task_klass.Outputs,)
)

if not hasattr(task_klass, "executable"):
raise RuntimeError(
"Classes generated by `shell_task` should contain an `executable` "
Expand All @@ -194,6 +208,14 @@ def convert_to_attrs(fields: dict[str, dict[str, ty.Any]], attrs_func):
"Classes generated by `shell_task` should contain an `Outputs` class "
"attribute specifying the outputs to the shell tool"
)

task_klass.input_spec = pydra.engine.specs.SpecInfo(
name=f"{name}Inputs", fields=[], bases=(task_klass.Inputs,)
)
task_klass.output_spec = pydra.engine.specs.SpecInfo(
name=f"{name}Outputs", fields=[], bases=(task_klass.Outputs,)
)

return task_klass


Expand Down
77 changes: 67 additions & 10 deletions pydra/mark/tests/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class Outputs:
return Ls


def test_shell_task_fields(Ls):
def test_shell_fields(Ls):
assert [a.name for a in attrs.fields(Ls.Inputs)] == [
"executable",
"args",
Expand All @@ -152,7 +152,7 @@ def test_shell_task_fields(Ls):
]


def test_shell_task_pickle_roundtrip(Ls, tmpdir):
def test_shell_pickle_roundtrip(Ls, tmpdir):
pkl_file = tmpdir / "ls.pkl"
with open(pkl_file, "wb") as f:
cp.dump(Ls, f)
Expand All @@ -163,7 +163,7 @@ def test_shell_task_pickle_roundtrip(Ls, tmpdir):
assert RereadLs is Ls


def test_shell_task_run(Ls, tmpdir):
def test_shell_run(Ls, tmpdir):
Path.touch(tmpdir / "a")
Path.touch(tmpdir / "b")
Path.touch(tmpdir / "c")
Expand Down Expand Up @@ -196,7 +196,7 @@ class Inputs:
help_string="an input file", argstr="", position=0
)
y: str = shell_arg(
help_string="an input file",
help_string="path of output file",
output_file_template="{x}_out",
argstr="",
)
Expand All @@ -214,7 +214,7 @@ class Inputs:
},
"y": {
"type": str,
"help_string": "an output file",
"help_string": "path of output file",
"argstr": "",
"output_file_template": "{x}_out",
},
Expand All @@ -231,7 +231,7 @@ def get_file_size(y: Path):
return result.st_size


def test_shell_task_bases_dynamic(A, tmpdir):
def test_shell_bases_dynamic(A, tmpdir):
B = shell_task(
"B",
output_fields={
Expand All @@ -256,7 +256,7 @@ def test_shell_task_bases_dynamic(A, tmpdir):
assert result.output.y == str(ypath)


def test_shell_task_bases_static(A, tmpdir):
def test_shell_bases_static(A, tmpdir):
@shell_task
class B(A):
class Outputs:
Expand All @@ -276,12 +276,24 @@ class Outputs:
assert result.output.y == str(ypath)


def test_shell_task_dynamic_inputs_bases(tmpdir):
def test_shell_inputs_outputs_bases_dynamic(tmpdir):
A = shell_task(
"A",
"ls",
input_fields={
"directory": {"type": os.PathLike, "help_string": "input directory"}
"directory": {
"type": os.PathLike,
"help_string": "input directory",
"argstr": "",
"position": -1,
}
},
output_fields={
"entries": {
"type": list,
"help_string": "list of entries returned by ls command",
"callable": list_entries,
}
},
)
B = shell_task(
Expand All @@ -290,13 +302,58 @@ def test_shell_task_dynamic_inputs_bases(tmpdir):
input_fields={
"hidden": {
"type": bool,
"argstr": "-a",
"help_string": "show hidden files",
"default": False,
}
},
bases=[A],
inputs_bases=[A.Inputs],
)

b = B(directory=tmpdir)
Path.touch(tmpdir / ".hidden")

b = B(directory=tmpdir, hidden=True)

assert b.inputs.directory == tmpdir
assert b.inputs.hidden
assert b.cmdline == f"ls -a {tmpdir}"

result = b()
assert result.output.entries == [".", "..", ".hidden"]


def test_shell_inputs_outputs_bases_static(tmpdir):
@shell_task
class A:
executable = "ls"

class Inputs:
directory: os.PathLike = shell_arg(
help_string="input directory", argstr="", position=-1
)

class Outputs:
entries: list = shell_out(
help_string="list of entries returned by ls command",
callable=list_entries,
)

@shell_task
class B(A):
class Inputs(A.Inputs):
hidden: bool = shell_arg(
help_string="show hidden files",
argstr="-a",
default=False,
)

Path.touch(tmpdir / ".hidden")

b = B(directory=tmpdir, hidden=True)

assert b.inputs.directory == tmpdir
assert b.inputs.hidden

result = b()
assert result.output.entries == [".", "..", ".hidden"]

0 comments on commit 9247daf

Please sign in to comment.