-
Notifications
You must be signed in to change notification settings - Fork 296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SageMaker on Flyte: TrainingJob for training with built-in algorithms and basic HPOJob support [Alpha] #120
Merged
Merged
Changes from 75 commits
Commits
Show all changes
88 commits
Select commit
Hold shift + click to select a range
f80f9e7
adding trainingjob model and sagemaker task
bnsblue db31d1c
adding models for sagemaker proto messages
bnsblue 55420b8
add new line at eof
bnsblue 0edef9c
adding common trainingjob task
bnsblue 0e98520
redo flytekit changes to comply with new interface and proto definition
bnsblue 5fa5d40
Fix a logic bug in training job model. Adding SdkSimpleTrainingJobTas…
bnsblue d588e3e
Add a comment
bnsblue 54c5d2e
Add SdkSimpleHPOJobTask
bnsblue 27a9ba4
Remove the embedding of the underlying trainingjob's output from the …
bnsblue e40e033
fix a typo
bnsblue 56ba534
add new line at eof
bnsblue a38657d
adding custom training job sdk type
bnsblue a081fa0
add code for tranlating an enum in hpo_job model; fix hpo_job_task sd…
bnsblue e38674c
missing a colon
bnsblue 18172ea
add the missing input stopping_condition for training job tasks
bnsblue f08be91
bump flyteidl version
bnsblue 1df9b63
bump to a beta version
bnsblue 635f19c
merge with master and bump version accordingly
bnsblue 7e8ff64
fixing unit tests
bnsblue 643708a
fixing unit tests
bnsblue fcb3dee
replacing interface types
bnsblue 669ff39
change
186e977
fixed training job unit test
bnsblue f65d97f
fix hpo job task interface and hide task type from users
bnsblue babefc1
Merge branch 'add-sagemaker-trainingjob-hpojob' of github.com:lyft/fl…
bnsblue f0d37cd
fix hpo job task interface
bnsblue 20b9809
fix hpo models
bnsblue 008cf38
fix serialization of the underlying trainingjob of a hpo job
bnsblue 2f274d6
Expose training job as a parameter
ac15490
Working!
a7d4cdd
replacing hyphens with underscores
bnsblue 43a42a0
updated
5204e9d
bug fix
aff9055
Sagemaker nb
4f3a3c1
Sagemaker HPO
edfd2a6
remove .demo directory
b0bbca2
Merge branch 'master' into add-sagemaker-trainingjob-hpojob
ba40a9b
Merge branch 'master' into add-sagemaker-trainingjob-hpojob
345e057
register and launch standalone trainingjob task
bnsblue 9d5f243
Merge
EngHabu 5ba0483
Complete the examples in sagemaker-hpo notebook and add text descript…
bnsblue acab1a9
update notebook
bnsblue 70a4954
all hands demo notebook added
bnsblue 7f28f26
update a notebook
bnsblue aaa2056
update the demo notebook
bnsblue 044f7b5
adding unit test for SdkSimpleHPOJobTask
bnsblue e45ccf7
failing the unit test
bnsblue 1aa3a7c
failing the unit test
bnsblue 4e8f923
wip for custom training job
bnsblue e82e242
Revert "wip for custom training job"
bnsblue c725a1d
Revert "failing the unit test"
bnsblue dcc33cd
Revert "failing the unit test"
bnsblue 9982aa7
fixing unit tests
bnsblue 70472ac
bump minor version
bnsblue a081557
preventing installing numpy==1.19.0 which introduces a breaking chang…
bnsblue 6f5003f
fix semver
bnsblue adcc3b5
Merge
EngHabu a463f3f
make changes corresponding to flyteidl changes (renaming hpo to hyper…
bnsblue 58b1d5e
bump beta version
bnsblue c32045b
Merge branch 'add-sagemaker-trainingjob-hpojob' of github.com:lyft/fl…
bnsblue 6d14fe6
Delete config.yaml
EngHabu 6f29de6
sagemaker-hpo notebook update
bnsblue f659a47
Merge branch 'add-sagemaker-trainingjob-hpojob' of github.com:lyft/fl…
bnsblue 06c254e
make changes to reflect changes in flyteidl
bnsblue da9fc56
make task name consistent
bnsblue cd70862
add missing properties for hyperparameter models
bnsblue 998aac0
add missing type hints and remove unused imports
bnsblue 585d5ab
remove unused sdk sagemaker dir
bnsblue f77d7f1
remove unused test file
bnsblue 5b912bb
revert numpy semver
bnsblue 02800df
removing notebooks
bnsblue 16867fb
remove type hints for self because CI is using python 3.6.3 while __f…
bnsblue 7c6461e
complete docstrings for hpo job task
bnsblue 5429d3d
merging with master and resolve conflict
bnsblue e6345ae
fix unit test
bnsblue d4008d8
adding input_file_type (wip)
bnsblue af64b52
add input file type support
bnsblue 363beff
add docs
bnsblue 21f3a5d
reflecting the renamed type and field
bnsblue 8b0c1f4
reflecting remove of libsvm content type
bnsblue d09f37f
reflecting remove of libsvm content type
bnsblue 76b2f0f
Give metric_definitions a None as the default value because built-in …
bnsblue 0eee0b3
nix a print statement
bnsblue 84a6972
nix custom training job for the current release
bnsblue 4cd76d3
rename SdkSimpleTrainingJobTask to SdkBuiltinAlgorithmTrainingJobTask
bnsblue 050f50c
merge with master and bump minor version
bnsblue 62e6af6
revert setup.py dependency
bnsblue 7dbac57
add back existing notebooks
bnsblue File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,3 +18,4 @@ build/ | |
dist | ||
*.iml | ||
.eggs | ||
.demo |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,4 +2,4 @@ | |
|
||
import flytekit.plugins | ||
|
||
__version__ = "0.10.11" | ||
__version__ = '0.11.0b1' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from __future__ import absolute_import | ||
import datetime as _datetime | ||
|
||
from google.protobuf.json_format import MessageToDict | ||
|
||
from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job | ||
from flytekit import __version__ | ||
from flytekit.common.constants import SdkTaskType | ||
from flytekit.common.tasks import task as _sdk_task | ||
from flytekit.common import interface as _interface | ||
from flytekit.common.tasks.sagemaker.training_job_task import SdkSimpleTrainingJobTask | ||
from flytekit.models import task as _task_models | ||
from flytekit.models import interface as _interface_model | ||
from flytekit.models.sagemaker import hpo_job as _hpo_job_model | ||
from flytekit.models import literals as _literal_models | ||
from flytekit.models import types as _types_models | ||
from flytekit.models.core import types as _core_types | ||
from flytekit.sdk import types as _sdk_types | ||
|
||
|
||
class SdkSimpleHyperparameterTuningJobTask(_sdk_task.SdkTask): | ||
|
||
def __init__( | ||
self, | ||
max_number_of_training_jobs: int, | ||
max_parallel_training_jobs: int, | ||
training_job: SdkSimpleTrainingJobTask, | ||
retries: int = 0, | ||
cacheable: bool = False, | ||
cache_version: str = "", | ||
): | ||
""" | ||
|
||
:param max_number_of_training_jobs: The maximum number of training jobs that can be launched by this | ||
hyperparameter tuning job | ||
:param max_parallel_training_jobs: The maximum number of training jobs that can launched by this hyperparameter | ||
tuning job in parallel | ||
:param training_job: The reference to the training job definition | ||
:param retries: Number of retries to attempt | ||
:param cacheable: The flag to set if the user wants the output of the task execution to be cached | ||
:param cache_version: String describing the caching version for task discovery purposes | ||
""" | ||
# Use the training job model as a measure of type checking | ||
hpo_job = _hpo_job_model.HyperparameterTuningJob( | ||
max_number_of_training_jobs=max_number_of_training_jobs, | ||
max_parallel_training_jobs=max_parallel_training_jobs, | ||
training_job=training_job.training_job_model, | ||
).to_flyte_idl() | ||
|
||
# Setting flyte-level timeout to 0, and let SageMaker respect the StoppingCondition of | ||
# the underlying training job | ||
# TODO: Discuss whether this is a viable interface or contract | ||
timeout = _datetime.timedelta(seconds=0) | ||
|
||
inputs = { | ||
"hyperparameter_tuning_job_config": _interface_model.Variable( | ||
_sdk_types.Types.Proto( | ||
_pb2_hpo_job.HyperparameterTuningJobConfig).to_flyte_literal_type(), "" | ||
), | ||
} | ||
inputs.update(training_job.interface.inputs) | ||
|
||
super(SdkSimpleHyperparameterTuningJobTask, self).__init__( | ||
type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK, | ||
metadata=_task_models.TaskMetadata( | ||
runtime=_task_models.RuntimeMetadata( | ||
type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, | ||
version=__version__, | ||
flavor='sagemaker' | ||
), | ||
discoverable=cacheable, | ||
timeout=timeout, | ||
retries=_literal_models.RetryStrategy(retries=retries), | ||
interruptible=False, | ||
discovery_version=cache_version, | ||
deprecated_error_message="", | ||
), | ||
interface=_interface.TypedInterface( | ||
inputs=inputs, | ||
outputs={ | ||
"model": _interface_model.Variable( | ||
type=_types_models.LiteralType( | ||
blob=_core_types.BlobType( | ||
format="", | ||
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE | ||
) | ||
), | ||
description="" | ||
) | ||
} | ||
), | ||
custom=MessageToDict(hpo_job), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
from __future__ import absolute_import | ||
|
||
from typing import Dict, Callable | ||
import datetime as _datetime | ||
|
||
from flytekit import __version__ | ||
from flytekit.common.tasks import task as _sdk_task, sdk_runnable as _sdk_runnable | ||
from flytekit.models import task as _task_models | ||
from flytekit.models import interface as _interface_model | ||
from flytekit.common import interface as _interface | ||
from flytekit.models.sagemaker import training_job as _training_job_models | ||
from google.protobuf.json_format import MessageToDict | ||
from flytekit.models import types as _idl_types | ||
from flytekit.models.core import types as _core_types | ||
from flytekit.models import literals as _literal_models | ||
from flytekit.common.constants import SdkTaskType | ||
|
||
|
||
class SdkSimpleTrainingJobTask(_sdk_task.SdkTask): | ||
def __init__( | ||
self, | ||
training_job_resource_config: _training_job_models.TrainingJobResourceConfig, | ||
algorithm_specification: _training_job_models.AlgorithmSpecification, | ||
retries: int = 0, | ||
cacheable: bool = False, | ||
cache_version: str = "", | ||
): | ||
""" | ||
|
||
:param training_job_resource_config: The options to configure the training job | ||
:param algorithm_specification: The options to configure the target algorithm of the training | ||
:param retries: Number of retries to attempt | ||
:param cacheable: The flag to set if the user wants the output of the task execution to be cached | ||
:param cache_version: String describing the caching version for task discovery purposes | ||
""" | ||
# Use the training job model as a measure of type checking | ||
self._training_job_model = _training_job_models.TrainingJob( | ||
algorithm_specification=algorithm_specification, | ||
training_job_resource_config=training_job_resource_config, | ||
) | ||
|
||
# Setting flyte-level timeout to 0, and let SageMaker takes the StoppingCondition and terminate the training | ||
# job gracefully | ||
timeout = _datetime.timedelta(seconds=0) | ||
|
||
super(SdkSimpleTrainingJobTask, self).__init__( | ||
type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK, | ||
metadata=_task_models.TaskMetadata( | ||
runtime=_task_models.RuntimeMetadata( | ||
type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, | ||
version=__version__, | ||
flavor='sagemaker' | ||
), | ||
discoverable=cacheable, | ||
timeout=timeout, | ||
retries=_literal_models.RetryStrategy(retries=retries), | ||
interruptible=False, | ||
discovery_version=cache_version, | ||
deprecated_error_message="", | ||
), | ||
interface=_interface.TypedInterface( | ||
inputs={ | ||
"static_hyperparameters": _interface_model.Variable( | ||
type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT), | ||
description="", | ||
), | ||
"train": _interface_model.Variable( | ||
type=_idl_types.LiteralType( | ||
blob=_core_types.BlobType( | ||
format="csv", | ||
bnsblue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART | ||
), | ||
), | ||
description="", | ||
), | ||
"validation": _interface_model.Variable( | ||
type=_idl_types.LiteralType( | ||
blob=_core_types.BlobType( | ||
format="csv", | ||
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART | ||
), | ||
), | ||
description="", | ||
), | ||
}, | ||
outputs={ | ||
"model": _interface_model.Variable( | ||
type=_idl_types.LiteralType( | ||
blob=_core_types.BlobType( | ||
format="", | ||
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE | ||
) | ||
), | ||
description="" | ||
) | ||
} | ||
), | ||
custom=MessageToDict(self._training_job_model.to_flyte_idl()), | ||
) | ||
|
||
@property | ||
def training_job_model(self) -> _training_job_models.TrainingJob: | ||
return self._training_job_model | ||
|
||
|
||
class SdkCustomTrainingJobTask(_sdk_runnable.SdkRunnableTask): | ||
def __init__( | ||
self, | ||
task_function: Callable, | ||
training_job_resource_config: _training_job_models.TrainingJobResourceConfig, | ||
algorithm_specification: _training_job_models.AlgorithmSpecification, | ||
cache_version: str, | ||
retries: int = 0, | ||
# interruptible: bool = False, | ||
deprecated: bool = False, | ||
cacheable: bool = False, | ||
# environment: Dict[str, str] = None, | ||
): | ||
""" | ||
|
||
:param task_function: | ||
|
||
:param training_job_resource_config: The options to configure the training job | ||
:param algorithm_specification: The options to configure the target algorithm of the training | ||
:param cache_version: String describing the caching version for task discovery purposes | ||
:param retries: Number of retries to attempt | ||
:param deprecated: This string can be used to mark the task as deprecated. Consumers of the task will | ||
receive deprecation warnings. | ||
:param cacheable: The flag to set if the user wants the output of the task execution to be cached | ||
""" | ||
# Use the training job model as a measure of type checking | ||
training_job = _training_job_models.TrainingJob( | ||
algorithm_specification=algorithm_specification, | ||
training_job_resource_config=training_job_resource_config, | ||
).to_flyte_idl() | ||
|
||
# Setting flyte-level timeout to 0, and let SageMaker takes the StoppingCondition and terminate the training | ||
# job gracefully | ||
timeout = _datetime.timedelta(seconds=0) | ||
|
||
super(SdkCustomTrainingJobTask, self).__init__( | ||
task_function=task_function, | ||
task_type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK, | ||
discovery_version=cache_version, | ||
retries=retries, | ||
interruptible=False, | ||
deprecated=deprecated, | ||
storage_request="", | ||
cpu_request="", | ||
gpu_request="", | ||
memory_request="", | ||
storage_limit="", | ||
cpu_limit="", | ||
gpu_limit="", | ||
memory_limit="", | ||
discoverable=cacheable, | ||
timeout=timeout, | ||
environment={}, | ||
custom=MessageToDict(training_job), | ||
), | ||
self.add_inputs( | ||
{ | ||
"static_hyperparameters": _interface_model.Variable( | ||
type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT), | ||
description="", | ||
), | ||
"train": _interface_model.Variable( | ||
type=_idl_types.LiteralType( | ||
blob=_core_types.BlobType( | ||
format="csv", | ||
bnsblue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART | ||
), | ||
), | ||
description="", | ||
), | ||
"validation": _interface_model.Variable( | ||
type=_idl_types.LiteralType( | ||
blob=_core_types.BlobType( | ||
format="csv", | ||
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART | ||
), | ||
), | ||
description="", | ||
), | ||
}, | ||
) | ||
self.add_outputs( | ||
{ | ||
"model": _interface_model.Variable( | ||
type=_idl_types.LiteralType( | ||
blob=_core_types.BlobType( | ||
format="", | ||
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE | ||
) | ||
), | ||
description="" | ||
) | ||
} | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does Simple mean?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this meant to be built in?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes -- meaning using built-in algorithm mode where users don't write his/her own decorated function.