diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 402862be74..2588248488 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -12,7 +12,7 @@ from flytekit.core import launch_plan as _annotated_launchplan from flytekit.core import workflow as _annotated_workflow from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin -from flytekit.core.interface import transform_function_to_interface +from flytekit.core.interface import Interface, output_name_generator, transform_function_to_interface from flytekit.core.pod_template import PodTemplate from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, TaskReference @@ -416,3 +416,44 @@ def wrapper(fn) -> ReferenceTask: return ReferenceTask(project, domain, name, version, interface.inputs, interface.outputs) return wrapper + + +class Echo(PythonTask): + _TASK_TYPE = "echo" + + def __init__(self, name: str, inputs: Optional[Dict[str, Type]] = None, **kwargs): + """ + A task that simply echoes the inputs back to the user. + The task's inputs and outputs interface are the same. + + FlytePropeller uses echo plugin to handle this task, and it won't create a pod for this task. + It will simply pass the inputs to the outputs. + https://github.com/flyteorg/flyte/blob/master/flyteplugins/go/tasks/plugins/testing/echo.go + + Note: Make sure to enable the echo plugin in the propeller config to use this task. + ``` + task-plugins: + enabled-plugins: + - echo + ``` + + :param name: The name of the task. + :param inputs: Name and type of inputs specified as a dictionary. + e.g. {"a": int, "b": str}. + :param kwargs: All other args required by the parent type - PythonTask. + + """ + outputs = dict(zip(output_name_generator(len(inputs)), inputs.values())) if inputs else None + super().__init__( + task_type=self._TASK_TYPE, + name=name, + interface=Interface(inputs=inputs, outputs=outputs), + **kwargs, + ) + + def execute(self, **kwargs) -> Any: + values = list(kwargs.values()) + if len(values) == 1: + return values[0] + else: + return tuple(values) diff --git a/tests/flytekit/unit/core/test_conditions.py b/tests/flytekit/unit/core/test_conditions.py index b3bf0c5eab..53a924d697 100644 --- a/tests/flytekit/unit/core/test_conditions.py +++ b/tests/flytekit/unit/core/test_conditions.py @@ -9,6 +9,7 @@ from flytekit import task, workflow from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.condition import conditional +from flytekit.core.task import Echo from flytekit.models.core.workflow import Node from flytekit.tools.translator import get_serializable @@ -495,3 +496,41 @@ def multiplier_2(my_input: float) -> float: res = multiplier_2(my_input=10.0) assert res == 20 + + +def test_echo_in_condition(): + echo1 = Echo(name="echo", inputs={"a": typing.Optional[float]}) + + @task() + def t1(radius: float) -> typing.Optional[float]: + return 2 * 3.14 * radius + + @workflow + def wf1(radius: float) -> typing.Optional[float]: + return ( + conditional("shape_properties_with_multiple_branches") + .if_((radius >= 0.1) & (radius < 1.0)) + .then(t1(radius=radius)) + .else_() + .then(echo1(a=radius)) + ) + + assert wf1(radius=1.8) == 1.8 + + echo2 = Echo(name="echo", inputs={"a": float, "b": float}) + + @task() + def t2(radius: float) -> typing.Tuple[float, float]: + return 2 * 3.14 * radius, 2 * 3.14 * radius + + @workflow + def wf2(radius1: float, radius2: float) -> typing.Tuple[float, float]: + return ( + conditional("shape_properties_with_multiple_branches") + .if_((radius1 >= 0.1) & (radius1 < 1.0)) + .then(t2(radius=radius2)) + .else_() + .then(echo2(a=radius1, b=radius2)) + ) + + assert wf2(radius1=1.8, radius2=1.8) == (1.8, 1.8)