From a69ed93c5ed03056c14e72e053a897f8814feefb Mon Sep 17 00:00:00 2001 From: Chi-Sheng Liu Date: Thu, 20 Jun 2024 15:57:55 +0800 Subject: [PATCH] feat: Support positional arguments Resolves: flyteorg/flyte#5320 Signed-off-by: Chi-Sheng Liu --- flytekit/core/base_sql_task.py | 9 +++---- flytekit/core/container_task.py | 2 +- flytekit/core/gate.py | 15 ++++-------- flytekit/core/interface.py | 28 +++++++++++----------- flytekit/core/launch_plan.py | 8 ++++++- flytekit/core/promise.py | 40 +++++++++++++++---------------- flytekit/core/reference.py | 6 ++--- flytekit/core/reference_entity.py | 6 ++--- flytekit/core/task.py | 6 ++--- flytekit/core/workflow.py | 8 ++++++- 10 files changed, 67 insertions(+), 61 deletions(-) diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py index 30b73223a9f..534b5288822 100644 --- a/flytekit/core/base_sql_task.py +++ b/flytekit/core/base_sql_task.py @@ -1,5 +1,6 @@ import re -from typing import Any, Dict, Optional, Tuple, Type, TypeVar +from collections import OrderedDict +from typing import Any, Optional, Tuple, Type, TypeVar from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.interface import Interface @@ -24,9 +25,9 @@ def __init__( query_template: str, task_config: Optional[T] = None, task_type="sql_task", - inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, + inputs: Optional[OrderedDict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, - outputs: Optional[Dict[str, Type]] = None, + outputs: Optional[OrderedDict[str, Type]] = None, **kwargs, ): """ @@ -36,7 +37,7 @@ def __init__( super().__init__( task_type=task_type, name=name, - interface=Interface(inputs=inputs or {}, outputs=outputs or {}), + interface=Interface(inputs=inputs or OrderedDict(), outputs=outputs or OrderedDict()), metadata=metadata, task_config=task_config, **kwargs, diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 66fe522c070..adb5003ea12 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -49,7 +49,7 @@ def __init__( inputs: Optional[OrderedDict[str, Type]] = None, metadata: Optional[TaskMetadata] = None, arguments: Optional[List[str]] = None, - outputs: Optional[Dict[str, Type]] = None, + outputs: Optional[OrderedDict[str, Type]] = None, requests: Optional[Resources] = None, limits: Optional[Resources] = None, input_data_dir: Optional[str] = None, diff --git a/flytekit/core/gate.py b/flytekit/core/gate.py index 76851657434..c147b8134a4 100644 --- a/flytekit/core/gate.py +++ b/flytekit/core/gate.py @@ -2,6 +2,7 @@ import datetime import typing +from collections import OrderedDict from typing import Tuple, Union import click @@ -49,11 +50,7 @@ def __init__( self._python_interface = flyte_interface.Interface() elif input_type: # Waiting for user input, so the output of the node is whatever input the user provides. - self._python_interface = flyte_interface.Interface( - outputs={ - "o0": self.input_type, - } - ) + self._python_interface = flyte_interface.Interface(outputs=OrderedDict([("o0", self.input_type)])) else: # We don't know how to find the python interface here, approve() sets it below, See the code. self._python_interface = None # type: ignore @@ -205,12 +202,8 @@ def approve(upstream_item: Union[Tuple[Promise], Promise, VoidPromise], name: st # In either case, we need a python interface g._python_interface = flyte_interface.Interface( - inputs={ - io_var_name: io_type, - }, - outputs={ - io_var_name: io_type, - }, + inputs=OrderedDict([(io_var_name, io_type)]), + outputs=OrderedDict([(io_var_name, io_type)]), ) kwargs = {io_var_name: upstream_item} diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 13b6af2d4bd..d90078c2f8b 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -52,8 +52,8 @@ class Interface(object): def __init__( self, - inputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Tuple[Type, Any]]]] = None, - outputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Optional[Type]]]] = None, + inputs: Union[OrderedDict[str, Type], OrderedDict[str, Tuple[Type, Any]], None] = None, + outputs: Union[Dict[str, Type], Dict[str, Optional[Type]], None] = None, output_tuple_name: Optional[str] = None, docstring: Optional[Docstring] = None, ): @@ -67,14 +67,14 @@ def __init__( primarily used when handling one-element NamedTuples. :param docstring: Docstring of the annotated @task or @workflow from which the interface derives from. """ - self._inputs: Union[Dict[str, Tuple[Type, Any]], Dict[str, Type]] = {} # type: ignore + self._inputs: Union[OrderedDict[str, Tuple[Type, Any]], OrderedDict[str, Type]] = OrderedDict() # type: ignore if inputs: for k, v in inputs.items(): if type(v) is tuple and len(cast(Tuple, v)) > 1: self._inputs[k] = v # type: ignore else: self._inputs[k] = (v, None) # type: ignore - self._outputs = outputs if outputs else {} # type: ignore + self._outputs = outputs if outputs else OrderedDict() # type: ignore self._output_tuple_name = output_tuple_name if outputs: @@ -123,8 +123,8 @@ def output_tuple_name(self) -> Optional[str]: return self._output_tuple_name @property - def inputs(self) -> Dict[str, type]: - r = {} + def inputs(self) -> OrderedDict[str, type]: + r = OrderedDict() for k, v in self._inputs.items(): r[k] = v[0] return r @@ -144,7 +144,7 @@ def default_inputs_as_kwargs(self) -> Dict[str, Any]: return {k: v[1] for k, v in self._inputs.items() if v[1] is not None} @property - def outputs(self) -> typing.Dict[str, type]: + def outputs(self) -> OrderedDict[str, type]: return self._outputs # type: ignore @property @@ -313,8 +313,8 @@ def verify_outputs_artifact_bindings( def transform_types_to_list_of_type( - m: Dict[str, type], bound_inputs: typing.Set[str], list_as_optional: bool = False -) -> Dict[str, type]: + m: OrderedDict[str, type], bound_inputs: typing.Set[str], list_as_optional: bool = False +) -> OrderedDict[str, type]: """ Converts unbound inputs into the equivalent (optional) collections. This is useful for array jobs / map style code. It will create a collection of types even if any one these types is not a collection type. @@ -375,7 +375,7 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc outputs = extract_return_annotation(return_annotation) for k, v in outputs.items(): outputs[k] = v # type: ignore - inputs: Dict[str, Tuple[Type, Any]] = OrderedDict() + inputs: OrderedDict[str, Tuple[Type, Any]] = OrderedDict() for k, v in signature.parameters.items(): # type: ignore annotation = type_hints.get(k, None) default = v.default if v.default is not inspect.Parameter.empty else None @@ -446,7 +446,7 @@ def output_name_generator(length: int) -> Generator[str, None, None]: yield default_output_name(x) -def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Dict[str, Type]: +def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> OrderedDict[str, Type]: """ The purpose of this function is to sort out whether a function is returning one thing, or multiple things, and to name the outputs accordingly, either by using our default name function, or from a typing.NamedTuple. @@ -481,7 +481,7 @@ def t(a: int, b: str) -> Dict[str, int]: ... # Handle Option 6 # We can think about whether we should add a default output name with type None in the future. if return_annotation in (None, type(None), inspect.Signature.empty): - return {} + return OrderedDict() # This statement results in true for typing.Namedtuple, single and void return types, so this # handles Options 1, 2. Even though NamedTuple for us is multi-valued, it's a single value for Python @@ -491,7 +491,7 @@ def t(a: int, b: str) -> Dict[str, int]: ... bases = return_annotation.__bases__ # type: ignore if len(bases) == 1 and bases[0] == tuple and hasattr(return_annotation, "_fields"): logger.debug(f"Task returns named tuple {return_annotation}") - return dict(get_type_hints(cast(Type, return_annotation), include_extras=True)) + return OrderedDict(get_type_hints(cast(Type, return_annotation), include_extras=True)) if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple: # type: ignore # Handle option 3 @@ -511,7 +511,7 @@ def t(a: int, b: str) -> Dict[str, int]: ... else: # Handle all other single return types logger.debug(f"Task returns unnamed native tuple {return_annotation}") - return {default_output_name(): cast(Type, return_annotation)} + return OrderedDict([(default_output_name(), cast(Type, return_annotation))]) def remap_shared_output_descriptions(output_descriptions: Dict[str, str], outputs: Dict[str, Type]) -> Dict[str, str]: diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index 90181848370..4d091b4a0e6 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -484,7 +484,13 @@ class ReferenceLaunchPlan(ReferenceEntity, LaunchPlan): """ def __init__( - self, project: str, domain: str, name: str, version: str, inputs: Dict[str, Type], outputs: Dict[str, Type] + self, + project: str, + domain: str, + name: str, + version: str, + inputs: typing.OrderedDict[str, Type], + outputs: typing.OrderedDict[str, Type], ): super().__init__(LaunchPlanReference(project, domain, name, version), inputs, outputs) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 557d621dd40..c4f71eb2d6a 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1202,19 +1202,22 @@ def flyte_entity_call_handler( #. Start a local execution - This means that we're not already in a local workflow execution, which means that we should expect inputs to be native Python values and that we should return Python native values. """ - # Sanity checks - # Only keyword args allowed - if len(args) > 0: - raise _user_exceptions.FlyteAssertion( - f"When calling tasks, only keyword args are supported. " - f"Aborting execution as detected {len(args)} positional args {args}" - ) # Make sure arguments are part of interface for k, v in kwargs.items(): - if k not in cast(SupportsNodeCreation, entity).python_interface.inputs: - raise AssertionError( - f"Received unexpected keyword argument '{k}' in function '{cast(SupportsNodeCreation, entity).name}'" - ) + if k not in entity.python_interface.inputs: + raise AssertionError(f"Received unexpected keyword argument '{k}' in function '{entity.name}'") + + # Check if we have more arguments than expected + if len(args) > len(entity.python_interface.inputs): + raise AssertionError( + f"Received more arguments than expected in function '{entity.name}'. Expected {len(entity.python_interface.inputs)} but got {len(args)}" + ) + + # Convert args to kwargs + for arg, input_name in zip(args, entity.python_interface.inputs.keys()): + if input_name in kwargs: + raise AssertionError(f"Got multiple values for argument '{input_name}' in function '{entity.name}'") + kwargs[input_name] = arg ctx = FlyteContextManager.current_context() if ctx.execution_state and ( @@ -1234,15 +1237,12 @@ def flyte_entity_call_handler( child_ctx.execution_state and child_ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED ): - if ( - len(cast(SupportsNodeCreation, entity).python_interface.inputs) > 0 - or len(cast(SupportsNodeCreation, entity).python_interface.outputs) > 0 - ): - output_names = list(cast(SupportsNodeCreation, entity).python_interface.outputs.keys()) + if len(entity.python_interface.inputs) > 0 or len(entity.python_interface.outputs) > 0: + output_names = list(entity.python_interface.outputs.keys()) if len(output_names) == 0: return VoidPromise(entity.name) vals = [Promise(var, None) for var in output_names] - return create_task_output(vals, cast(SupportsNodeCreation, entity).python_interface) + return create_task_output(vals, entity.python_interface) else: return None return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs) @@ -1255,7 +1255,7 @@ def flyte_entity_call_handler( cast(ExecutionParameters, child_ctx.user_space_params)._decks = [] result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs) - expected_outputs = len(cast(SupportsNodeCreation, entity).python_interface.outputs) + expected_outputs = len(entity.python_interface.outputs) if expected_outputs == 0: if result is None or isinstance(result, VoidPromise): return None @@ -1268,10 +1268,10 @@ def flyte_entity_call_handler( if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or ( result is not None and expected_outputs == 1 ): - return create_native_named_tuple(ctx, result, cast(SupportsNodeCreation, entity).python_interface) + return create_native_named_tuple(ctx, result, entity.python_interface) raise AssertionError( f"Expected outputs and actual outputs do not match." f"Result {result}. " - f"Python interface: {cast(SupportsNodeCreation, entity).python_interface}" + f"Python interface: {entity.python_interface}" ) diff --git a/flytekit/core/reference.py b/flytekit/core/reference.py index 6a88549c43a..1c4e04cb214 100644 --- a/flytekit/core/reference.py +++ b/flytekit/core/reference.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, Type +from typing import OrderedDict, Type from flytekit.core.launch_plan import ReferenceLaunchPlan from flytekit.core.task import ReferenceTask @@ -15,8 +15,8 @@ def get_reference_entity( domain: str, name: str, version: str, - inputs: Dict[str, Type], - outputs: Dict[str, Type], + inputs: OrderedDict[str, Type], + outputs: OrderedDict[str, Type], ): """ See the documentation for :py:class:`flytekit.reference_task` and :py:class:`flytekit.reference_workflow` as well. diff --git a/flytekit/core/reference_entity.py b/flytekit/core/reference_entity.py index b54c4d67f6b..113157de574 100644 --- a/flytekit/core/reference_entity.py +++ b/flytekit/core/reference_entity.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, Optional, OrderedDict, Tuple, Type, Union from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext from flytekit.core.interface import Interface, transform_interface_to_typed_interface @@ -71,8 +71,8 @@ class ReferenceEntity(object): def __init__( self, reference: Union[WorkflowReference, TaskReference, LaunchPlanReference], - inputs: Dict[str, Type], - outputs: Dict[str, Type], + inputs: OrderedDict[str, Type], + outputs: OrderedDict[str, Type], ): if ( not isinstance(reference, WorkflowReference) diff --git a/flytekit/core/task.py b/flytekit/core/task.py index d30947509d3..0e0c035e54d 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -2,7 +2,7 @@ import datetime from functools import update_wrapper -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload +from typing import Any, Callable, Dict, Iterable, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, overload from flytekit.core import launch_plan as _annotated_launchplan from flytekit.core import workflow as _annotated_workflow @@ -366,8 +366,8 @@ def __init__( domain: str, name: str, version: str, - inputs: Dict[str, type], - outputs: Dict[str, Type], + inputs: OrderedDict[str, type], + outputs: OrderedDict[str, Type], ): super().__init__(TaskReference(project, domain, name, version), inputs, outputs) diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 58f81579831..c407ee81f2f 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -886,7 +886,13 @@ class ReferenceWorkflow(ReferenceEntity, PythonFunctionWorkflow): # type: ignor """ def __init__( - self, project: str, domain: str, name: str, version: str, inputs: Dict[str, Type], outputs: Dict[str, Type] + self, + project: str, + domain: str, + name: str, + version: str, + inputs: typing.OrderedDict[str, Type], + outputs: typing.OrderedDict[str, Type], ): super().__init__(WorkflowReference(project, domain, name, version), inputs, outputs)