Skip to content

Commit

Permalink
fix(sdk): Allow keyword-only arguments in pipeline function signature (
Browse files Browse the repository at this point in the history
…#4544)

* add test for keyword-only arguments in pipeline func

* fix: kwargs-only argument for pipeline func

* test: kwargs generate same yaml as args

* remove whole metadata

* assert -> self.assertEqual

* programmatic example --> fixed example

* same name for both

Co-authored-by: Alexey Volkov <[email protected]>
  • Loading branch information
Udiknedormin and Ark-kun authored Jan 30, 2021
1 parent cad02dc commit ce985bc
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
11 changes: 8 additions & 3 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,17 +842,22 @@ def _create_workflow(self,
raise ValueError('Either specify pipeline params in the pipeline function, or in "params_list", but not both.')

args_list = []
kwargs_dict = dict()
signature = inspect.signature(pipeline_func)
for arg_name in signature.parameters:
for arg_name, arg in signature.parameters.items():
arg_type = None
for input in pipeline_meta.inputs or []:
if arg_name == input.name:
arg_type = input.type
break
args_list.append(dsl.PipelineParam(sanitize_k8s_name(arg_name, True), param_type=arg_type))
param = dsl.PipelineParam(sanitize_k8s_name(arg_name, True), param_type=arg_type)
if arg.kind == inspect.Parameter.KEYWORD_ONLY:
kwargs_dict[arg_name] = param
else:
args_list.append(param)

with dsl.Pipeline(pipeline_name) as dsl_pipeline:
pipeline_func(*args_list)
pipeline_func(*args_list, **kwargs_dict)

pipeline_conf = pipeline_conf or dsl_pipeline.conf # Configuration passed to the compiler is overriding. Unfortunately, it's not trivial to detect whether the dsl_pipeline.conf was ever modified.

Expand Down
37 changes: 37 additions & 0 deletions sdk/python/tests/compiler/compiler_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import kfp
import kfp.compiler as compiler
Expand Down Expand Up @@ -1108,3 +1109,39 @@ def test__resolve_task_pipeline_param(self):

def test_uri_artifact_passing(self):
self._test_py_compile_yaml('uri_artifacts')

def test_keyword_only_argument_for_pipeline_func(self):
def some_pipeline(casual_argument: str, *, keyword_only_argument: str):
pass
kfp.compiler.Compiler()._create_workflow(some_pipeline)

def test_keyword_only_argument_for_pipeline_func_identity(self):
test_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
sys.path.append(test_data_dir)

# `@pipeline` is needed to make name the same for both functions

@pipeline(name="pipeline_func")
def pipeline_func_arg(foo_arg: str, bar_arg: str):
dsl.ContainerOp(
name='foo',
image='foo',
command=['bar'],
arguments=[foo_arg, ' and ', bar_arg]
)

@pipeline(name="pipeline_func")
def pipeline_func_kwarg(foo_arg: str, *, bar_arg: str):
return pipeline_func_arg(foo_arg, bar_arg)

pipeline_yaml_arg = kfp.compiler.Compiler()._create_workflow(pipeline_func_arg)
pipeline_yaml_kwarg = kfp.compiler.Compiler()._create_workflow(pipeline_func_kwarg)

# the yamls may differ in metadata
def remove_metadata(yaml) -> None:
del yaml['metadata']
remove_metadata(pipeline_yaml_arg)
remove_metadata(pipeline_yaml_kwarg)

# compare
self.assertEqual(pipeline_yaml_arg, pipeline_yaml_kwarg)

0 comments on commit ce985bc

Please sign in to comment.