Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: support session tag chaining for training job #4596

Merged
merged 4 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def __init__(
container_arguments: Optional[List[str]] = None,
disable_output_compression: bool = False,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``EstimatorBase`` instance.
Expand Down Expand Up @@ -544,7 +545,9 @@ def __init__(
enable_infra_check (bool or PipelineVariable): Optional.
Specifies whether it is running Sagemaker built-in infra check jobs.
enable_remote_debug (bool or PipelineVariable): Optional.
Specifies whether RemoteDebug is enabled for the training job
Specifies whether RemoteDebug is enabled for the training job.
enable_session_tag_chaining (bool or PipelineVariable): Optional.
Specifies whether SessionTagChaining is enabled for the training job.
"""
instance_count = renamed_kwargs(
"train_instance_count", "instance_count", instance_count, kwargs
Expand Down Expand Up @@ -785,6 +788,8 @@ def __init__(

self._enable_remote_debug = enable_remote_debug

self._enable_session_tag_chaining = enable_session_tag_chaining

@abstractmethod
def training_image_uri(self):
"""Return the Docker image to use for training.
Expand Down Expand Up @@ -2318,6 +2323,14 @@ def get_remote_debug_config(self):
else {"EnableRemoteDebug": self._enable_remote_debug}
)

def get_session_chaining_config(self):
    """Build the SessionChaining configuration for the training job.

    Returns:
        dict: ``{"EnableSessionTagChaining": <value>}`` holding the value
        stored in ``_enable_session_tag_chaining``, or ``None`` when that
        value is ``None``.
    """
    session_tag_chaining = self._enable_session_tag_chaining
    if session_tag_chaining is None:
        return None
    return {"EnableSessionTagChaining": session_tag_chaining}

def enable_remote_debug(self):
"""Enable remote debug for a training job."""
self._update_remote_debug(True)
Expand Down Expand Up @@ -2574,6 +2587,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
if estimator.get_remote_debug_config() is not None:
train_args["remote_debug_config"] = estimator.get_remote_debug_config()

if estimator.get_session_chaining_config() is not None:
train_args["session_chaining_config"] = estimator.get_session_chaining_config()

return train_args

@classmethod
Expand Down Expand Up @@ -2766,6 +2782,7 @@ def __init__(
disable_output_compression: bool = False,
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``Estimator`` instance.
Expand Down Expand Up @@ -3129,6 +3146,8 @@ def __init__(
Specifies whether it is running Sagemaker built-in infra check jobs.
enable_remote_debug (bool or PipelineVariable): Optional.
Specifies whether RemoteDebug is enabled for the training job
enable_session_tag_chaining (bool or PipelineVariable): Optional.
Specifies whether SessionTagChaining is enabled for the training job.
"""
self.image_uri = image_uri
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
Expand Down Expand Up @@ -3181,6 +3200,7 @@ def __init__(
container_arguments=container_arguments,
disable_output_compression=disable_output_compression,
enable_remote_debug=enable_remote_debug,
enable_session_tag_chaining=enable_session_tag_chaining,
**kwargs,
)

Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(
container_arguments: Optional[List[str]] = None,
disable_output_compression: Optional[bool] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
):
"""Initializes a ``JumpStartEstimator``.

Expand Down Expand Up @@ -500,6 +501,8 @@ def __init__(
to Amazon S3 without compression after training finishes.
enable_remote_debug (bool or PipelineVariable): Optional.
Specifies whether RemoteDebug is enabled for the training job
enable_session_tag_chaining (bool or PipelineVariable): Optional.
Specifies whether SessionTagChaining is enabled for the training job.

Raises:
ValueError: If the model ID is not recognized by JumpStart.
Expand Down Expand Up @@ -578,6 +581,7 @@ def _validate_model_id_and_get_type_hook():
disable_output_compression=disable_output_compression,
enable_infra_check=enable_infra_check,
enable_remote_debug=enable_remote_debug,
enable_session_tag_chaining=enable_session_tag_chaining,
)

self.model_id = estimator_init_kwargs.model_id
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def get_init_kwargs(
disable_output_compression: Optional[bool] = None,
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
) -> JumpStartEstimatorInitKwargs:
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""

Expand Down Expand Up @@ -188,6 +189,7 @@ def get_init_kwargs(
disable_output_compression=disable_output_compression,
enable_infra_check=enable_infra_check,
enable_remote_debug=enable_remote_debug,
enable_session_tag_chaining=enable_session_tag_chaining,
)

estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1723,6 +1723,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
"disable_output_compression",
"enable_infra_check",
"enable_remote_debug",
"enable_session_tag_chaining",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -1790,6 +1791,7 @@ def __init__(
disable_output_compression: Optional[bool] = None,
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
) -> None:
"""Instantiates JumpStartEstimatorInitKwargs object."""

Expand Down Expand Up @@ -1849,6 +1851,7 @@ def __init__(
self.disable_output_compression = disable_output_compression
self.enable_infra_check = enable_infra_check
self.enable_remote_debug = enable_remote_debug
self.enable_session_tag_chaining = enable_session_tag_chaining


class JumpStartEstimatorFitKwargs(JumpStartKwargs):
Expand Down
24 changes: 24 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@ def train( # noqa: C901
environment: Optional[Dict[str, str]] = None,
retry_strategy=None,
remote_debug_config=None,
session_chaining_config=None,
):
"""Create an Amazon SageMaker training job.

Expand Down Expand Up @@ -877,6 +878,15 @@ def train( # noqa: C901
remote_debug_config = {
"EnableRemoteDebug": True,
}
session_chaining_config (dict): Configuration for SessionChaining. (default: ``None``)
The dict can contain 'EnableSessionTagChaining' (bool).
For example,

.. code:: python

session_chaining_config = {
"EnableSessionTagChaining": True,
}
environment (dict[str, str]) : Environment variables to be set for
use during training job (default: ``None``)
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
Expand Down Expand Up @@ -970,6 +980,7 @@ def train( # noqa: C901
profiler_rule_configs=profiler_rule_configs,
profiler_config=inferred_profiler_config,
remote_debug_config=remote_debug_config,
session_chaining_config=session_chaining_config,
environment=environment,
retry_strategy=retry_strategy,
)
Expand Down Expand Up @@ -1013,6 +1024,7 @@ def _get_train_request( # noqa: C901
profiler_rule_configs=None,
profiler_config=None,
remote_debug_config=None,
session_chaining_config=None,
environment=None,
retry_strategy=None,
):
Expand Down Expand Up @@ -1133,6 +1145,15 @@ def _get_train_request( # noqa: C901
remote_debug_config = {
"EnableRemoteDebug": True,
}
session_chaining_config (dict): Configuration for SessionChaining. (default: ``None``)
The dict can contain 'EnableSessionTagChaining' (bool).
For example,

.. code:: python

session_chaining_config = {
"EnableSessionTagChaining": True,
}
environment (dict[str, str]) : Environment variables to be set for
use during training job (default: ``None``)
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
Expand Down Expand Up @@ -1239,6 +1260,9 @@ def _get_train_request( # noqa: C901
if remote_debug_config is not None:
train_request["RemoteDebugConfig"] = remote_debug_config

if session_chaining_config is not None:
train_request["SessionChainingConfig"] = session_chaining_config

if retry_strategy is not None:
train_request["RetryStrategy"] = retry_strategy

Expand Down
35 changes: 35 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2089,6 +2089,41 @@ def test_framework_disable_remote_debug(sagemaker_session):
assert len(args) == 2


def test_framework_with_session_chaining_config(sagemaker_session):
    """Enabling session tag chaining forwards the config to ``train``."""
    instance_groups = [
        InstanceGroup("group1", "ml.c4.xlarge", 1),
        InstanceGroup("group2", "ml.m4.xlarge", 2),
    ]
    estimator = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_groups=instance_groups,
        enable_session_tag_chaining=True,
    )

    estimator.fit("s3://mydata")

    sagemaker_session.train.assert_called_once()
    # call_args unpacks to (positional args, keyword args); only kwargs matter here.
    _positional, train_kwargs = sagemaker_session.train.call_args
    forwarded_config = train_kwargs["session_chaining_config"]
    assert forwarded_config["EnableSessionTagChaining"]
    estimator_config = estimator.get_session_chaining_config()
    assert estimator_config["EnableSessionTagChaining"]


def test_framework_without_session_chaining_config(sagemaker_session):
    """Leaving session tag chaining unset must not forward any config to ``train``."""
    f = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_groups=[
            InstanceGroup("group1", "ml.c4.xlarge", 1),
            InstanceGroup("group2", "ml.m4.xlarge", 2),
        ],
    )
    f.fit("s3://mydata")
    sagemaker_session.train.assert_called_once()
    _, args = sagemaker_session.train.call_args
    # The train kwarg is "session_chaining_config"; checking any other key
    # would pass vacuously even if the config were forwarded by mistake.
    assert args.get("session_chaining_config") is None
    # Assert on the session-chaining getter, not the remote-debug one.
    assert f.get_session_chaining_config() is None


@patch("time.strftime", return_value=TIMESTAMP)
def test_custom_code_bucket(time, sagemaker_session):
code_bucket = "codebucket"
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2197,6 +2197,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
CONTAINER_ENTRY_POINT = ["bin/bash", "test.sh"]
CONTAINER_ARGUMENTS = ["--arg1", "value1", "--arg2", "value2"]
remote_debug_config = {"EnableRemoteDebug": True}
session_chaining_config = {"EnableSessionTagChaining": True}

sagemaker_session.train(
image_uri=IMAGE,
Expand All @@ -2222,6 +2223,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
container_entry_point=CONTAINER_ENTRY_POINT,
container_arguments=CONTAINER_ARGUMENTS,
remote_debug_config=remote_debug_config,
session_chaining_config=session_chaining_config,
)

_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
Expand All @@ -2245,6 +2247,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
)
assert actual_train_args["AlgorithmSpecification"]["ContainerArguments"] == CONTAINER_ARGUMENTS
assert actual_train_args["RemoteDebugConfig"]["EnableRemoteDebug"]
assert actual_train_args["SessionChainingConfig"]["EnableSessionTagChaining"]


def test_create_transform_job_with_sagemaker_config_injection(sagemaker_session):
Expand Down