Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SageMaker on Flyte: TrainingJob for training with built-in algorithms and basic HPOJob support [Alpha] #120

Merged
merged 88 commits into from
Jul 31, 2020
Merged
Show file tree
Hide file tree
Changes from 75 commits
Commits
Show all changes
88 commits
Select commit Hold shift + click to select a range
f80f9e7
adding trainingjob model and sagemaker task
bnsblue Jun 3, 2020
db31d1c
adding models for sagemaker proto messages
bnsblue Jun 4, 2020
55420b8
add new line at eof
bnsblue Jun 4, 2020
0edef9c
adding common trainingjob task
bnsblue Jun 9, 2020
0e98520
redo flytekit changes to comply with new interface and proto definition
bnsblue Jun 22, 2020
5fa5d40
Fix a logic bug in training job model. Adding SdkSimpleTrainingJobTas…
bnsblue Jun 22, 2020
d588e3e
Add a comment
bnsblue Jun 22, 2020
54c5d2e
Add SdkSimpleHPOJobTask
bnsblue Jun 22, 2020
27a9ba4
Remove the embedding of the underlying trainingjob's output from the …
bnsblue Jun 22, 2020
e40e033
fix a typo
bnsblue Jun 22, 2020
56ba534
add new line at eof
bnsblue Jun 22, 2020
a38657d
adding custom training job sdk type
bnsblue Jun 22, 2020
a081fa0
add code for tranlating an enum in hpo_job model; fix hpo_job_task sd…
bnsblue Jun 23, 2020
e38674c
missing a colon
bnsblue Jun 23, 2020
18172ea
add the missing input stopping_condition for training job tasks
bnsblue Jun 24, 2020
f08be91
bump flyteidl version
bnsblue Jun 24, 2020
1df9b63
bump to a beta version
bnsblue Jun 25, 2020
635f19c
merge with master and bump version accordingly
bnsblue Jun 25, 2020
7e8ff64
fixing unit tests
bnsblue Jun 25, 2020
643708a
fixing unit tests
bnsblue Jun 25, 2020
fcb3dee
replacing interface types
bnsblue Jun 25, 2020
669ff39
change
Jun 25, 2020
186e977
fixed training job unit test
bnsblue Jun 25, 2020
f65d97f
fix hpo job task interface and hide task type from users
bnsblue Jun 25, 2020
babefc1
Merge branch 'add-sagemaker-trainingjob-hpojob' of github.com:lyft/fl…
bnsblue Jun 25, 2020
f0d37cd
fix hpo job task interface
bnsblue Jun 25, 2020
20b9809
fix hpo models
bnsblue Jun 25, 2020
008cf38
fix serialization of the underlying trainingjob of a hpo job
bnsblue Jun 25, 2020
2f274d6
Expose training job as a parameter
Jun 25, 2020
ac15490
Working!
Jun 25, 2020
a7d4cdd
replacing hyphens with underscores
bnsblue Jun 26, 2020
43a42a0
updated
Jun 26, 2020
5204e9d
bug fix
Jun 26, 2020
aff9055
Sagemaker nb
Jun 26, 2020
4f3a3c1
Sagemaker HPO
Jun 29, 2020
edfd2a6
remove .demo directory
Jun 29, 2020
b0bbca2
Merge branch 'master' into add-sagemaker-trainingjob-hpojob
Jun 30, 2020
ba40a9b
Merge branch 'master' into add-sagemaker-trainingjob-hpojob
Jun 30, 2020
345e057
register and launch standalone trainingjob task
bnsblue Jul 9, 2020
9d5f243
Merge
EngHabu Jul 9, 2020
5ba0483
Complete the examples in sagemaker-hpo notebook and add text descript…
bnsblue Jul 14, 2020
acab1a9
update notebook
bnsblue Jul 14, 2020
70a4954
all hands demo notebook added
bnsblue Jul 15, 2020
7f28f26
update a notebook
bnsblue Jul 15, 2020
aaa2056
update the demo notebook
bnsblue Jul 16, 2020
044f7b5
adding unit test for SdkSimpleHPOJobTask
bnsblue Jul 21, 2020
e45ccf7
failing the unit test
bnsblue Jul 22, 2020
1aa3a7c
failing the unit test
bnsblue Jul 22, 2020
4e8f923
wip for custom training job
bnsblue Jul 23, 2020
e82e242
Revert "wip for custom training job"
bnsblue Jul 23, 2020
c725a1d
Revert "failing the unit test"
bnsblue Jul 23, 2020
dcc33cd
Revert "failing the unit test"
bnsblue Jul 23, 2020
9982aa7
fixing unit tests
bnsblue Jul 23, 2020
70472ac
bump minor version
bnsblue Jul 23, 2020
a081557
preventing installing numpy==1.19.0 which introduces a breaking chang…
bnsblue Jul 23, 2020
6f5003f
fix semver
bnsblue Jul 23, 2020
adcc3b5
Merge
EngHabu Jul 24, 2020
a463f3f
make changes corresponding to flyteidl changes (renaming hpo to hyper…
bnsblue Jul 24, 2020
58b1d5e
bump beta version
bnsblue Jul 24, 2020
c32045b
Merge branch 'add-sagemaker-trainingjob-hpojob' of github.com:lyft/fl…
bnsblue Jul 24, 2020
6d14fe6
Delete config.yaml
EngHabu Jul 24, 2020
6f29de6
sagemaker-hpo notebook update
bnsblue Jul 24, 2020
f659a47
Merge branch 'add-sagemaker-trainingjob-hpojob' of github.com:lyft/fl…
bnsblue Jul 24, 2020
06c254e
make changes to reflect changes in flyteidl
bnsblue Jul 27, 2020
da9fc56
make task name consistent
bnsblue Jul 28, 2020
cd70862
add missing properties for hyperparameter models
bnsblue Jul 28, 2020
998aac0
add missing type hints and remove unused imports
bnsblue Jul 28, 2020
585d5ab
remove unused sdk sagemaker dir
bnsblue Jul 28, 2020
f77d7f1
remove unused test file
bnsblue Jul 28, 2020
5b912bb
revert numpy semver
bnsblue Jul 28, 2020
02800df
removing notebooks
bnsblue Jul 28, 2020
16867fb
remove type hints for self because CI is using python 3.6.3 while __f…
bnsblue Jul 28, 2020
7c6461e
complete docstrings for hpo job task
bnsblue Jul 28, 2020
5429d3d
merging with master and resolve conflict
bnsblue Jul 28, 2020
e6345ae
fix unit test
bnsblue Jul 28, 2020
d4008d8
adding input_file_type (wip)
bnsblue Jul 29, 2020
af64b52
add input file type support
bnsblue Jul 29, 2020
363beff
add docs
bnsblue Jul 29, 2020
21f3a5d
reflecting the renamed type and field
bnsblue Jul 30, 2020
8b0c1f4
reflecting remove of libsvm content type
bnsblue Jul 30, 2020
d09f37f
reflecting remove of libsvm content type
bnsblue Jul 30, 2020
76b2f0f
Give metric_definitions a None as the default value because built-in …
bnsblue Jul 31, 2020
0eee0b3
nix a print statement
bnsblue Jul 31, 2020
84a6972
nix custom training job for the current release
bnsblue Jul 31, 2020
4cd76d3
rename SdkSimpleTrainingJobTask to SdkBuiltinAlgorithmTrainingJobTask
bnsblue Jul 31, 2020
050f50c
merge with master and bump minor version
bnsblue Jul 31, 2020
62e6af6
revert setup.py dependency
bnsblue Jul 31, 2020
7dbac57
add back existing notebooks
bnsblue Jul 31, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ build/
dist
*.iml
.eggs
.demo
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

import flytekit.plugins

__version__ = "0.10.11"
__version__ = '0.11.0b1'
2 changes: 2 additions & 0 deletions flytekit/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class SdkTaskType(object):
PYTORCH_TASK = "pytorch"
# Raw container task is just a name, it defaults to using the regular container task (like python etc), but sets the data_config in the container
RAW_CONTAINER_TASK = "raw-container"
SAGEMAKER_TRAINING_JOB_TASK = "sagemaker_training_job_task"
SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK = "sagemaker_hyperparameter_tuning_job_task"

GLOBAL_INPUT_NODE_ID = ''

Expand Down
93 changes: 93 additions & 0 deletions flytekit/common/tasks/sagemaker/hpo_job_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import absolute_import
import datetime as _datetime

from google.protobuf.json_format import MessageToDict

from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job
from flytekit import __version__
from flytekit.common.constants import SdkTaskType
from flytekit.common.tasks import task as _sdk_task
from flytekit.common import interface as _interface
from flytekit.common.tasks.sagemaker.training_job_task import SdkSimpleTrainingJobTask
from flytekit.models import task as _task_models
from flytekit.models import interface as _interface_model
from flytekit.models.sagemaker import hpo_job as _hpo_job_model
from flytekit.models import literals as _literal_models
from flytekit.models import types as _types_models
from flytekit.models.core import types as _core_types
from flytekit.sdk import types as _sdk_types


class SdkSimpleHyperparameterTuningJobTask(_sdk_task.SdkTask):

def __init__(
self,
max_number_of_training_jobs: int,
max_parallel_training_jobs: int,
training_job: SdkSimpleTrainingJobTask,
retries: int = 0,
cacheable: bool = False,
cache_version: str = "",
):
"""

:param max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
hyperparameter tuning job
:param max_parallel_training_jobs: The maximum number of training jobs that can launched by this hyperparameter
tuning job in parallel
:param training_job: The reference to the training job definition
:param retries: Number of retries to attempt
:param cacheable: The flag to set if the user wants the output of the task execution to be cached
:param cache_version: String describing the caching version for task discovery purposes
"""
# Use the training job model as a measure of type checking
hpo_job = _hpo_job_model.HyperparameterTuningJob(
max_number_of_training_jobs=max_number_of_training_jobs,
max_parallel_training_jobs=max_parallel_training_jobs,
training_job=training_job.training_job_model,
).to_flyte_idl()

# Setting flyte-level timeout to 0, and let SageMaker respect the StoppingCondition of
# the underlying training job
# TODO: Discuss whether this is a viable interface or contract
timeout = _datetime.timedelta(seconds=0)

inputs = {
"hyperparameter_tuning_job_config": _interface_model.Variable(
_sdk_types.Types.Proto(
_pb2_hpo_job.HyperparameterTuningJobConfig).to_flyte_literal_type(), ""
),
}
inputs.update(training_job.interface.inputs)

super(SdkSimpleHyperparameterTuningJobTask, self).__init__(
type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
metadata=_task_models.TaskMetadata(
runtime=_task_models.RuntimeMetadata(
type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
version=__version__,
flavor='sagemaker'
),
discoverable=cacheable,
timeout=timeout,
retries=_literal_models.RetryStrategy(retries=retries),
interruptible=False,
discovery_version=cache_version,
deprecated_error_message="",
),
interface=_interface.TypedInterface(
inputs=inputs,
outputs={
"model": _interface_model.Variable(
type=_types_models.LiteralType(
blob=_core_types.BlobType(
format="",
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
),
description=""
)
}
),
custom=MessageToDict(hpo_job),
)
199 changes: 199 additions & 0 deletions flytekit/common/tasks/sagemaker/training_job_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
from __future__ import absolute_import

from typing import Dict, Callable
import datetime as _datetime

from flytekit import __version__
from flytekit.common.tasks import task as _sdk_task, sdk_runnable as _sdk_runnable
from flytekit.models import task as _task_models
from flytekit.models import interface as _interface_model
from flytekit.common import interface as _interface
from flytekit.models.sagemaker import training_job as _training_job_models
from google.protobuf.json_format import MessageToDict
from flytekit.models import types as _idl_types
from flytekit.models.core import types as _core_types
from flytekit.models import literals as _literal_models
from flytekit.common.constants import SdkTaskType


class SdkSimpleTrainingJobTask(_sdk_task.SdkTask):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does Simple mean?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this meant to be built in?

Copy link
Contributor Author

@bnsblue bnsblue Jul 29, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes -- meaning using built-in algorithm mode where users don't write his/her own decorated function.

def __init__(
self,
training_job_resource_config: _training_job_models.TrainingJobResourceConfig,
algorithm_specification: _training_job_models.AlgorithmSpecification,
retries: int = 0,
cacheable: bool = False,
cache_version: str = "",
):
"""

:param training_job_resource_config: The options to configure the training job
:param algorithm_specification: The options to configure the target algorithm of the training
:param retries: Number of retries to attempt
:param cacheable: The flag to set if the user wants the output of the task execution to be cached
:param cache_version: String describing the caching version for task discovery purposes
"""
# Use the training job model as a measure of type checking
self._training_job_model = _training_job_models.TrainingJob(
algorithm_specification=algorithm_specification,
training_job_resource_config=training_job_resource_config,
)

# Setting flyte-level timeout to 0, and let SageMaker takes the StoppingCondition and terminate the training
# job gracefully
timeout = _datetime.timedelta(seconds=0)

super(SdkSimpleTrainingJobTask, self).__init__(
type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
metadata=_task_models.TaskMetadata(
runtime=_task_models.RuntimeMetadata(
type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
version=__version__,
flavor='sagemaker'
),
discoverable=cacheable,
timeout=timeout,
retries=_literal_models.RetryStrategy(retries=retries),
interruptible=False,
discovery_version=cache_version,
deprecated_error_message="",
),
interface=_interface.TypedInterface(
inputs={
"static_hyperparameters": _interface_model.Variable(
type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT),
description="",
),
"train": _interface_model.Variable(
type=_idl_types.LiteralType(
blob=_core_types.BlobType(
format="csv",
bnsblue marked this conversation as resolved.
Show resolved Hide resolved
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
),
),
description="",
),
"validation": _interface_model.Variable(
type=_idl_types.LiteralType(
blob=_core_types.BlobType(
format="csv",
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
),
),
description="",
),
},
outputs={
"model": _interface_model.Variable(
type=_idl_types.LiteralType(
blob=_core_types.BlobType(
format="",
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
),
description=""
)
}
),
custom=MessageToDict(self._training_job_model.to_flyte_idl()),
)

@property
def training_job_model(self) -> _training_job_models.TrainingJob:
return self._training_job_model


class SdkCustomTrainingJobTask(_sdk_runnable.SdkRunnableTask):
def __init__(
self,
task_function: Callable,
training_job_resource_config: _training_job_models.TrainingJobResourceConfig,
algorithm_specification: _training_job_models.AlgorithmSpecification,
cache_version: str,
retries: int = 0,
# interruptible: bool = False,
deprecated: bool = False,
cacheable: bool = False,
# environment: Dict[str, str] = None,
):
"""

:param task_function:

:param training_job_resource_config: The options to configure the training job
:param algorithm_specification: The options to configure the target algorithm of the training
:param cache_version: String describing the caching version for task discovery purposes
:param retries: Number of retries to attempt
:param deprecated: This string can be used to mark the task as deprecated. Consumers of the task will
receive deprecation warnings.
:param cacheable: The flag to set if the user wants the output of the task execution to be cached
"""
# Use the training job model as a measure of type checking
training_job = _training_job_models.TrainingJob(
algorithm_specification=algorithm_specification,
training_job_resource_config=training_job_resource_config,
).to_flyte_idl()

# Setting flyte-level timeout to 0, and let SageMaker takes the StoppingCondition and terminate the training
# job gracefully
timeout = _datetime.timedelta(seconds=0)

super(SdkCustomTrainingJobTask, self).__init__(
task_function=task_function,
task_type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
discovery_version=cache_version,
retries=retries,
interruptible=False,
deprecated=deprecated,
storage_request="",
cpu_request="",
gpu_request="",
memory_request="",
storage_limit="",
cpu_limit="",
gpu_limit="",
memory_limit="",
discoverable=cacheable,
timeout=timeout,
environment={},
custom=MessageToDict(training_job),
),
self.add_inputs(
{
"static_hyperparameters": _interface_model.Variable(
type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT),
description="",
),
"train": _interface_model.Variable(
type=_idl_types.LiteralType(
blob=_core_types.BlobType(
format="csv",
bnsblue marked this conversation as resolved.
Show resolved Hide resolved
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
),
),
description="",
),
"validation": _interface_model.Variable(
type=_idl_types.LiteralType(
blob=_core_types.BlobType(
format="csv",
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
),
),
description="",
),
},
)
self.add_outputs(
{
"model": _interface_model.Variable(
type=_idl_types.LiteralType(
blob=_core_types.BlobType(
format="",
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
),
description=""
)
}
)
Loading