Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sdk): Allow keyword-only arguments in pipeline function signature #4544

Merged
11 changes: 8 additions & 3 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,17 +802,22 @@ def _create_workflow(self,
raise ValueError('Either specify pipeline params in the pipeline function, or in "params_list", but not both.')

args_list = []
kwargs_dict = dict()
signature = inspect.signature(pipeline_func)
for arg_name in signature.parameters:
for arg_name, arg in signature.parameters.items():
arg_type = None
for input in pipeline_meta.inputs or []:
if arg_name == input.name:
arg_type = input.type
break
args_list.append(dsl.PipelineParam(sanitize_k8s_name(arg_name, True), param_type=arg_type))
param = dsl.PipelineParam(sanitize_k8s_name(arg_name, True), param_type=arg_type)
if arg.kind == inspect.Parameter.KEYWORD_ONLY:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can put all arguments in a dictionary and then just use signature.bind which will sort out what goes where?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it would work with positional-only arguments:

def foo(a: str, /, b: str, *, c: str):
    print(a, b, c)

import inspect
sig = inspect.signature(foo)

kwargs = {"a": "A", "b": "B", "c": "C"}
bound_arguments = sig.bind(**kwargs) # Error: "a" is positional-only

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Positional-only arguments can't be passed as keyword arguments to signature.bind:

def foo(a: str, /, b: str, *, c: str): ...
    
import inspect
sig = inspect.signature(foo)

kwargs = {"a": "A", "b": "B", "c": "C"}
bound_arguments = sig.bind(**kwargs) # Error: "a" is positional-only

args = ["A"]
kwargs = {"b": "B", "c": "C"}
bound_arguments = sig.bind(*args, **kwargs) # ok, no error

So it seems to me that the easiest approach is to bind only the keyword-only arguments through kwargs and pass all the remaining arguments positionally through args.

kwargs_dict[arg_name] = param
else:
args_list.append(param)

with dsl.Pipeline(pipeline_name) as dsl_pipeline:
pipeline_func(*args_list)
pipeline_func(*args_list, **kwargs_dict)

pipeline_conf = pipeline_conf or dsl_pipeline.conf # Configuration passed to the compiler is overriding. Unfortunately, it's not trivial to detect whether the dsl_pipeline.conf was ever modified.

Expand Down
43 changes: 43 additions & 0 deletions sdk/python/tests/compiler/compiler_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import kfp
import kfp.compiler as compiler
Expand Down Expand Up @@ -1043,3 +1044,45 @@ def some_pipeline():
parameter_arguments_json = template['metadata']['annotations']['pipelines.kubeflow.org/arguments.parameters']
parameter_arguments = json.loads(parameter_arguments_json)
self.assertEqual(set(parameter_arguments.keys()), {'Input 1'})

def test_keyword_only_argument_for_pipeline_func(self):
    # Smoke test: there are no assertions, so this passes as long as
    # building the workflow does not raise for a pipeline function that
    # declares a keyword-only parameter (one placed after the bare ``*``).
    def some_pipeline(casual_argument: str, *, keyword_only_argument: str):
        pass

    pipeline_compiler = kfp.compiler.Compiler()
    pipeline_compiler._create_workflow(some_pipeline)

def test_keyword_only_argument_for_pipeline_func_identity(self):
test_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
sys.path.append(test_data_dir)
import basic

pipeline_func_arg = basic.save_most_frequent_word
Udiknedormin marked this conversation as resolved.
Show resolved Hide resolved

# clone name and description
@dsl.pipeline(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps @pipeline is not needed here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The wrapper should have the same name and description as save_most_frequent_word, so it seems to me that the @dsl.pipeline decorator is needed here.

name = pipeline_func_arg._component_human_name,
description = pipeline_func_arg._component_description
)
def pipeline_func_kwarg(*args, **kwargs):
return basic.save_most_frequent_word(*args, **kwargs)

# clone signature, but changing all arguments to keyword-only
import inspect
sig = inspect.signature(pipeline_func_arg)
new_parameters = [
param.replace(kind = inspect.Parameter.KEYWORD_ONLY)
for param in sig.parameters.values()
]
new_sig = sig.replace(parameters = new_parameters)
pipeline_func_kwarg.__signature__ = new_sig

pipeline_yaml_arg = kfp.compiler.Compiler()._create_workflow(pipeline_func_arg)
pipeline_yaml_kwarg = kfp.compiler.Compiler()._create_workflow(pipeline_func_kwarg)

# the yamls may differ in creation time, remove it
def remove_creation_time(yaml) -> None:
del yaml['metadata']['annotations']['pipelines.kubeflow.org/pipeline_compilation_time']
Udiknedormin marked this conversation as resolved.
Show resolved Hide resolved
remove_creation_time(pipeline_yaml_arg)
remove_creation_time(pipeline_yaml_kwarg)

# compare
assert pipeline_yaml_arg == pipeline_yaml_kwarg
Udiknedormin marked this conversation as resolved.
Show resolved Hide resolved