[Sample] Add new TFX::OSS sample (#2319)
* init.
* Patched Ajay's sample
* Clean up the sample and add preload config.
* Fix default value
* Remove old file
* Add compiled tfx sample
* Add compiled pipeline and move tfx sample to contrib to prevent dependency issue.
* Add readme and remove redundant params
* Add inline comments.
* Add description
* Add sample test
* fix test name.
* fix test dir
* fix data path.
* Fix pipeline_root
1 parent ca17faa · commit 361fbee
Showing 8 changed files with 232 additions and 7 deletions.
# Parameterized TFX pipeline sample

[Tensorflow Extended (TFX)](https://github.com/tensorflow/tfx) is a Google-production-scale machine
learning platform based on TensorFlow. It provides a configuration framework to express ML pipelines
consisting of TFX components. Kubeflow Pipelines can be used as the orchestrator supporting the
execution of a TFX pipeline.

This sample demonstrates how to author an ML pipeline in TFX and run it on a KFP deployment.
Please refer to the inline comments for the purpose of each step.

In order to successfully compile this sample, you'll need a TFX installation at HEAD.
First, clone the TFX repo and point the version tag in `tfx/version.py` to TFX's latest nightly
build image of version `0.15.0dev` (e.g., `0.15.0dev20191007`; the list of available images can be
found [here](https://hub.docker.com/r/tensorflow/tfx/tags)). Then, run `python setup.py install`
from `tfx/tfx`. After that, running `chicago_taxi_pipeline_simple.py` compiles the TFX pipeline into
a KFP pipeline package.
This pipeline requires Google Cloud Storage permissions to run.
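Once compiled, the package (typically a `.tar.gz` archive named after the pipeline, e.g.
`parameterized_tfx_oss.tar.gz`) can be submitted to your KFP deployment. Below is a minimal sketch
using the `kfp` SDK; the host URL, experiment name, and package path are placeholders rather than
part of this sample:

```python
import kfp

# Placeholder endpoint; replace with the URL of your KFP deployment.
client = kfp.Client(host='https://<your-kfp-endpoint>')

experiment = client.create_experiment('parameterized-tfx-oss-demo')

client.run_pipeline(
    experiment_id=experiment.id,
    job_name='parameterized-tfx-oss-run',
    pipeline_package_path='parameterized_tfx_oss.tar.gz',
    params={
        # These names match the dsl.PipelineParam definitions in the sample;
        # 'pipeline-root' is also exposed as a parameter by default (see Caveats).
        'module-file': 'gs://ml-pipeline-playground/tfx_taxi_simple/modules/taxi_utils.py',
        'data-root': 'gs://ml-pipeline-playground/tfx_taxi_simple/data',
    })
```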

## Caveats

This sample uses pipeline parameters in a TFX pipeline, which is not yet fully supported.
See [here](https://github.com/tensorflow/tfx/issues/362) for more details. In this sample,
the path to the module file and the path to the data are parameterized. This is achieved by
specifying those values as `dsl.PipelineParam` objects and appending them to
`KubeflowDagRunner._params`. `KubeflowDagRunner` can then identify those pipeline parameters and
interpret them as Argo placeholders at compilation time. However, this parameterization approach
is a hack, and we do not plan to support it long-term. Instead, we're working with the TFX team to
support pipeline parameterization using their
[RuntimeParameter](https://github.com/tensorflow/tfx/blob/46bb4f975c36ea1defde4b3c33553e088b3dc5b8/tfx/orchestration/data_types.py#L108).
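A condensed sketch of that mechanism (adapted from `parameterized_tfx_oss.py` below; only one
parameter is shown):

```python
from kfp import dsl
from tfx.orchestration.kubeflow import kubeflow_dag_runner

# Declare the value to parameterize as a KFP pipeline parameter.
data_root = dsl.PipelineParam(
    name='data-root',
    value='gs://ml-pipeline-playground/tfx_taxi_simple/data')

runner = kubeflow_dag_runner.KubeflowDagRunner()
# Register the parameter with the runner so that compilation emits an Argo
# placeholder instead of the literal default value.
runner._params.append(data_root)
```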

### Known issues
* This approach only works for string-typed quantities. For example, you cannot parameterize
  `num_steps` of `Trainer` in this way.
* Parameter names must be unique.
* By default, the pipeline root is always parameterized.
samples/contrib/parameterized_tfx_oss/parameterized_tfx_oss.py (147 additions, 0 deletions)
#!/usr/bin/env python3
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import tensorflow as tf

from typing import Text

import kfp
from kfp import dsl
from tfx.components.evaluator.component import Evaluator
from tfx.components.example_gen.csv_example_gen.component import CsvExampleGen
from tfx.components.example_validator.component import ExampleValidator
from tfx.components.model_validator.component import ModelValidator
from tfx.components.pusher.component import Pusher
from tfx.components.schema_gen.component import SchemaGen
from tfx.components.statistics_gen.component import StatisticsGen
from tfx.components.trainer.component import Trainer
from tfx.components.transform.component import Transform
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.kubeflow import kubeflow_dag_runner
from tfx.proto import evaluator_pb2
from tfx.utils.dsl_utils import csv_input
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.extensions.google_cloud_ai_platform.trainer import executor as ai_platform_trainer_executor
from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration.kubeflow.proto import kubeflow_pb2

# Define pipeline params used for pipeline execution.
# Path to the module file; should be a GCS path.
_taxi_module_file_param = dsl.PipelineParam(
    name='module-file',
    value='gs://ml-pipeline-playground/tfx_taxi_simple/modules/taxi_utils.py')

# Path to the CSV data directory, under which there should be a data.csv file.
_data_root_param = dsl.PipelineParam(
    name='data-root',
    value='gs://ml-pipeline-playground/tfx_taxi_simple/data')

# Path of the pipeline root; should be a GCS path.
_pipeline_root_param = dsl.PipelineParam(
    name='pipeline-root',
    value=os.path.join('gs://your-bucket', 'tfx_taxi_simple'))

def _create_test_pipeline(pipeline_root: Text, csv_input_location: Text,
                          taxi_module_file: Text, enable_cache: bool):
  """Creates a simple Kubeflow-based Chicago Taxi TFX pipeline.

  Args:
    pipeline_root: The root of the pipeline output.
    csv_input_location: The location of the input data directory.
    taxi_module_file: The location of the module file for Transform/Trainer.
    enable_cache: Whether to enable cache or not.

  Returns:
    A logical TFX pipeline.Pipeline object.
  """
  examples = csv_input(csv_input_location)

  example_gen = CsvExampleGen(input_base=examples)
  statistics_gen = StatisticsGen(input_data=example_gen.outputs.examples)
  infer_schema = SchemaGen(
      stats=statistics_gen.outputs.output, infer_feature_shape=False)
  validate_stats = ExampleValidator(
      stats=statistics_gen.outputs.output, schema=infer_schema.outputs.output)
  transform = Transform(
      input_data=example_gen.outputs.examples,
      schema=infer_schema.outputs.output,
      module_file=taxi_module_file)
  trainer = Trainer(
      module_file=taxi_module_file,
      transformed_examples=transform.outputs.transformed_examples,
      schema=infer_schema.outputs.output,
      transform_output=transform.outputs.transform_output,
      train_args=trainer_pb2.TrainArgs(num_steps=10),
      eval_args=trainer_pb2.EvalArgs(num_steps=5))
  model_analyzer = Evaluator(
      examples=example_gen.outputs.examples,
      model_exports=trainer.outputs.output,
      feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[
          evaluator_pb2.SingleSlicingSpec(
              column_for_slicing=['trip_start_hour'])
      ]))
  model_validator = ModelValidator(
      examples=example_gen.outputs.examples, model=trainer.outputs.output)
  pusher = Pusher(
      model_export=trainer.outputs.output,
      model_blessing=model_validator.outputs.blessing,
      push_destination=pusher_pb2.PushDestination(
          filesystem=pusher_pb2.PushDestination.Filesystem(
              base_directory=os.path.join(pipeline_root, 'model_serving'))))

  return pipeline.Pipeline(
      pipeline_name='parameterized_tfx_oss',
      pipeline_root=pipeline_root,
      components=[
          example_gen, statistics_gen, infer_schema, validate_stats, transform,
          trainer, model_analyzer, model_validator, pusher
      ],
      enable_cache=enable_cache,
  )


def _get_kubeflow_metadata_config() -> kubeflow_pb2.KubeflowMetadataConfig:
  """Returns an ML Metadata config pointing at the KFP cluster's MySQL instance."""
  config = kubeflow_pb2.KubeflowMetadataConfig()
  config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
  config.mysql_db_service_port.environment_variable = 'MYSQL_SERVICE_PORT'
  config.mysql_db_name.value = 'metadb'
  config.mysql_db_user.value = 'root'
  config.mysql_db_password.value = ''
  return config


if __name__ == '__main__':

  enable_cache = True

  pipeline = _create_test_pipeline(
      str(_pipeline_root_param),
      str(_data_root_param),
      str(_taxi_module_file_param),
      enable_cache=enable_cache)

  config = kubeflow_dag_runner.KubeflowDagRunnerConfig(
      kubeflow_metadata_config=_get_kubeflow_metadata_config())

  kfp_runner = kubeflow_dag_runner.KubeflowDagRunner(config=config)
  # Make sure kfp_runner recognizes those parameters.
  kfp_runner._params.extend([_data_root_param, _taxi_module_file_param])

  kfp_runner.run(pipeline)
Binary file not shown.
test/sample-test/configs/parameterized_tfx_oss.config.yaml (18 additions, 0 deletions)
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

test_name: parameterized_tfx_oss
arguments:
  output:
run_pipeline: True