Skip to content

Commit

Permalink
SDK - Added support for raw artifact values to ContainerOp
Browse files Browse the repository at this point in the history
  • Loading branch information
Ark-kun committed May 14, 2019
1 parent a2fd16e commit 7bd74cf
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 4 deletions.
33 changes: 29 additions & 4 deletions sdk/python/kfp/compiler/_op_to_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,33 @@ def _parameters_to_json(params: List[dsl.PipelineParam]):
return params


# TODO: artifacts?
def _inputs_to_json(inputs_params: List[dsl.PipelineParam], _artifacts=None):
def _inputs_to_json(
    inputs_params: List[dsl.PipelineParam],
    input_artifact_paths: Dict[str, str] = None,
    input_artifact_arguments: Dict[str, str] = None,
):
    """Converts pipeline parameters and input artifact specs into an argo `inputs` JSON obj.

    Args:
        inputs_params: Pipeline parameters to expose as argo input parameters.
        input_artifact_paths: Maps artifact input names to local file paths.
        input_artifact_arguments: Maps artifact input names to constant string
            values, compiled as argo raw artifact data.

    Returns:
        A dict with 'parameters' and/or 'artifacts' keys; empty (falsy) dict
        when there are no inputs at all.
    """
    parameters = _parameters_to_json(inputs_params)

    # Building the input artifacts section.
    # Only constant arguments are supported for now.
    # Constant arguments will be compiled as Argo's input artifact default
    # values, not as real arguments, until artifact passing is implemented.
    artifacts = []
    arguments = input_artifact_arguments or {}
    for name, path in (input_artifact_paths or {}).items():
        artifact = {'name': name, 'path': path}
        argument = arguments.get(name)
        # Explicit None check: an empty-string argument is still a valid raw
        # artifact value (a plain truthiness test would silently drop it).
        if argument is not None:
            artifact['raw'] = {'data': str(argument)}
        artifacts.append(artifact)
    artifacts.sort(key=lambda x: x['name'])  # Stabilizing the input artifact ordering

    inputs_dict = {}
    if parameters:
        inputs_dict['parameters'] = parameters
    if artifacts:
        inputs_dict['artifacts'] = artifacts
    return inputs_dict


def _outputs_to_json(op: BaseOp,
Expand Down Expand Up @@ -199,7 +221,10 @@ def _op_to_template(op: BaseOp):
}

# inputs
inputs = _inputs_to_json(processed_op.inputs)
if isinstance(op, dsl.ContainerOp):
inputs = _inputs_to_json(processed_op.inputs, processed_op.input_artifact_paths, processed_op.input_artifact_arguments)
elif isinstance(op, dsl.ResourceOp):
inputs = _inputs_to_json(processed_op.inputs)
if inputs:
template['inputs'] = inputs

Expand Down
12 changes: 12 additions & 0 deletions sdk/python/kfp/dsl/_container_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,8 @@ def __init__(self,
arguments: StringOrStringList = None,
sidecars: List[Sidecar] = None,
container_kwargs: Dict = None,
input_artifact_paths: Dict[str, str] = None,
input_artifact_arguments: Dict[str, str] = None,
file_outputs: Dict[str, str] = None,
output_artifact_paths : Dict[str, str]=None,
is_exit_handler=False,
Expand All @@ -885,6 +887,10 @@ def __init__(self,
together with the `main` container.
container_kwargs: the dict of additional keyword arguments to pass to the
op's `Container` definition.
input_artifact_paths: Maps artifact input names to local file paths.
At pipeline run time, the value of the input artifact argument is saved to this local file.
input_artifact_arguments: Maps artifact input names to the artifact values.
At pipeline run time, the value of the input artifact argument is saved to a local file specified in the input_artifact_paths map.
file_outputs: Maps output labels to local file paths. At pipeline run time,
the value of a PipelineParam is saved to its corresponding local file. It's
one way for outside world to receive outputs of the container.
Expand Down Expand Up @@ -939,9 +945,15 @@ def _decorated(*args, **kwargs):
setattr(self, attr_to_proxy, _proxy(attr_to_proxy))

# attributes specific to `ContainerOp`
self.input_artifact_paths = input_artifact_paths or {}
self.input_artifact_arguments = input_artifact_arguments or {}
self.file_outputs = file_outputs
self.output_artifact_paths = output_artifact_paths or {}

for artifact_name, artifact_argument in self.input_artifact_arguments.items():
if not isinstance(artifact_argument, str):
raise TypeError('Argument "{}" was passed to the artifact input "{}", but only constant strings are supported at this moment.'.format(str(artifact_argument), artifact_name))

self._metadata = None

self.outputs = {}
Expand Down
4 changes: 4 additions & 0 deletions sdk/python/tests/compiler/compiler_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,3 +460,7 @@ def test_tolerations(self):
value='run'))

self._test_op_to_template_yaml(op1, file_base_name='tolerations')

def test_py_input_artifact_raw_value(self):
    """Test compiling a pipeline that passes raw constant values to artifact inputs."""
    # Compiles testdata/input_artifact_raw_value.py and compares the output
    # against the golden input_artifact_raw_value.yaml file.
    self._test_py_compile_yaml('input_artifact_raw_value')
60 changes: 60 additions & 0 deletions sdk/python/tests/compiler/testdata/input_artifact_raw_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import sys
from pathlib import Path

sys.path.insert(0, __file__ + '/../../../../')

import kfp.dsl as dsl

def component_with_input_artifact(text):
    """Returns a ContainerOp that receives the given text through an input artifact."""
    input_path = '/inputs/text/data'
    op = dsl.ContainerOp(
        name='component_with_input_artifact',
        image='alpine',
        command=['cat', input_path],
        input_artifact_paths={'text': input_path},
        input_artifact_arguments={'text': text},
    )
    return op

def component_with_hardcoded_input_artifact_value():
    """Returns a ContainerOp whose artifact input is a baked-in constant string."""
    constant_text = 'hard-coded artifact value'
    return component_with_input_artifact(constant_text)


def component_with_input_artifact_value_from_file(file_path):
    """Reads the file at file_path and passes its contents as the input artifact."""
    contents = Path(file_path).read_text()
    return component_with_input_artifact(contents)


# Text file (next to this script) whose contents are fed to the pipeline as a raw artifact.
file_path = str(Path(__file__).parent / 'input_artifact_raw_value.txt')


@dsl.pipeline(
    name='Pipeline with artifact input raw argument value.',
    description='Pipeline shows how to define artifact inputs and pass raw artifacts to them.'
)
def retry_sample_pipeline():
    """Instantiates three tasks, each receiving its artifact input as a raw constant value."""
    # NOTE(review): the function name looks copy-pasted from the retry sample;
    # consider renaming to match the artifact-input purpose (the compiler call
    # at the bottom of this file would need updating too).
    component_with_input_artifact('Constant artifact value')
    component_with_hardcoded_input_artifact_value()
    component_with_input_artifact_value_from_file(file_path)

if __name__ == '__main__':
    # Compile the pipeline into an Argo workflow archive next to this script.
    from kfp.compiler import Compiler
    Compiler().compile(retry_sample_pipeline, __file__ + '.tar.gz')
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Text from a file with hard-coded artifact value
79 changes: 79 additions & 0 deletions sdk/python/tests/compiler/testdata/input_artifact_raw_value.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: pipeline-with-artifact-input-raw-argument-value-
spec:
arguments:
parameters: []
entrypoint: pipeline-with-artifact-input-raw-argument-value
serviceAccountName: pipeline-runner
templates:
- container:
command:
- cat
- /inputs/text/data
image: alpine
inputs:
artifacts:
- name: text
path: /inputs/text/data
raw:
data: Constant artifact value
name: component-with-input-artifact
outputs:
artifacts:
- name: mlpipeline-ui-metadata
path: /mlpipeline-ui-metadata.json
optional: true
- name: mlpipeline-metrics
path: /mlpipeline-metrics.json
optional: true
- container:
command:
- cat
- /inputs/text/data
image: alpine
inputs:
artifacts:
- name: text
path: /inputs/text/data
raw:
data: hard-coded artifact value
name: component-with-input-artifact-2
outputs:
artifacts:
- name: mlpipeline-ui-metadata
path: /mlpipeline-ui-metadata.json
optional: true
- name: mlpipeline-metrics
path: /mlpipeline-metrics.json
optional: true
- container:
command:
- cat
- /inputs/text/data
image: alpine
inputs:
artifacts:
- name: text
path: /inputs/text/data
raw:
data: Text from a file with hard-coded artifact value
name: component-with-input-artifact-3
outputs:
artifacts:
- name: mlpipeline-ui-metadata
path: /mlpipeline-ui-metadata.json
optional: true
- name: mlpipeline-metrics
path: /mlpipeline-metrics.json
optional: true
- dag:
tasks:
- name: component-with-input-artifact
template: component-with-input-artifact
- name: component-with-input-artifact-2
template: component-with-input-artifact-2
- name: component-with-input-artifact-3
template: component-with-input-artifact-3
name: pipeline-with-artifact-input-raw-argument-value

0 comments on commit 7bd74cf

Please sign in to comment.