Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add supported inference and incremental training configs #4637

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def __init__(
enable_remote_debug (bool or PipelineVariable): Optional.
Specifies whether RemoteDebug is enabled for the training job
config_name (Optional[str]):
Name of the JumpStart Model config to apply. (Default: None).
Name of the training configuration to apply to the Estimator. (Default: None).

Raises:
ValueError: If the model ID is not recognized by JumpStart.
Expand Down Expand Up @@ -686,6 +686,7 @@ def attach(
model_version: Optional[str] = None,
sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_channel_name: str = "model",
config_name: Optional[str] = None,
) -> "JumpStartEstimator":
"""Attach to an existing training job.

Expand Down Expand Up @@ -721,6 +722,8 @@ def attach(
model data will be downloaded (default: 'model'). If no channel
with the same name exists in the training job, this option will
be ignored.
config_name (str): Optional. Name of the training configuration to use
when attaching to the training job. (Default: None).

Returns:
Instance of the calling ``JumpStartEstimator`` Class with the attached
Expand All @@ -732,7 +735,6 @@ def attach(
"""
config_name = None
if model_id is None:

model_id, model_version, _, config_name = get_model_info_from_training_job(
training_job_name=training_job_name, sagemaker_session=sagemaker_session
)
Expand All @@ -746,6 +748,9 @@ def attach(
"tolerate_deprecated_model": True, # model is already trained
}

if config_name:
additional_kwargs.update({"config_name": config_name})

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
Expand Down Expand Up @@ -804,6 +809,7 @@ def deploy(
dependencies: Optional[List[str]] = None,
git_config: Optional[Dict[str, str]] = None,
use_compiled_model: bool = False,
inference_config_name: Optional[str] = None,
) -> PredictorBase:
"""Creates endpoint from training job.

Expand Down Expand Up @@ -1039,6 +1045,8 @@ def deploy(
(Default: None).
use_compiled_model (bool): Flag to select whether to use compiled
(optimized) model. (Default: False).
inference_config_name (Optional[str]): Name of the inference configuration to
be used in the model. (Default: None).
"""
self.orig_predictor_cls = predictor_cls

Expand Down Expand Up @@ -1091,7 +1099,8 @@ def deploy(
git_config=git_config,
use_compiled_model=use_compiled_model,
training_instance_type=self.instance_type,
config_name=self.config_name,
training_config_name=self.config_name,
inference_config_name=inference_config_name,
)

predictor = super(JumpStartEstimator, self).deploy(
Expand All @@ -1108,7 +1117,7 @@ def deploy(
tolerate_deprecated_model=self.tolerate_deprecated_model,
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
sagemaker_session=self.sagemaker_session,
config_name=self.config_name,
config_name=estimator_deploy_kwargs.config_name,
)

# If a predictor class was passed, do not mutate predictor
Expand Down Expand Up @@ -1140,7 +1149,9 @@ def set_training_config(self, config_name: str) -> None:
config_name (str): The name of the config.
"""
self.__init__(
model_id=self.model_id, model_version=self.model_version, config_name=config_name
model_id=self.model_id,
model_version=self.model_version,
config_name=config_name,
)

def __str__(self) -> str:
Expand Down
35 changes: 31 additions & 4 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def get_init_kwargs(
estimator_init_kwargs = _add_role_to_kwargs(estimator_init_kwargs)
estimator_init_kwargs = _add_env_to_kwargs(estimator_init_kwargs)
estimator_init_kwargs = _add_tags_to_kwargs(estimator_init_kwargs)
estimator_init_kwargs = _add_config_name_to_kwargs(estimator_init_kwargs)

return estimator_init_kwargs

Expand Down Expand Up @@ -291,7 +292,8 @@ def get_deploy_kwargs(
use_compiled_model: Optional[bool] = None,
model_name: Optional[str] = None,
training_instance_type: Optional[str] = None,
config_name: Optional[str] = None,
training_config_name: Optional[str] = None,
inference_config_name: Optional[str] = None,
) -> JumpStartEstimatorDeployKwargs:
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object."""

Expand Down Expand Up @@ -319,7 +321,8 @@ def get_deploy_kwargs(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
training_config_name=training_config_name,
config_name=inference_config_name,
)

model_init_kwargs: JumpStartModelInitKwargs = model.get_init_kwargs(
Expand Down Expand Up @@ -348,7 +351,7 @@ def get_deploy_kwargs(
tolerate_deprecated_model=tolerate_deprecated_model,
training_instance_type=training_instance_type,
disable_instance_type_logging=True,
config_name=config_name,
config_name=model_deploy_kwargs.config_name,
)

estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs(
Expand Down Expand Up @@ -393,7 +396,7 @@ def get_deploy_kwargs(
tolerate_vulnerable_model=model_init_kwargs.tolerate_vulnerable_model,
tolerate_deprecated_model=model_init_kwargs.tolerate_deprecated_model,
use_compiled_model=use_compiled_model,
config_name=config_name,
config_name=model_deploy_kwargs.config_name,
)

return estimator_deploy_kwargs
Expand Down Expand Up @@ -793,3 +796,27 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim
setattr(kwargs, key, value)

return kwargs


def _add_config_name_to_kwargs(
    kwargs: JumpStartEstimatorInitKwargs,
) -> JumpStartEstimatorInitKwargs:
    """Sets the training config name in kwargs based on default or override.

    When the caller did not supply ``kwargs.config_name`` and the model specs
    expose a top-ranked training config, that config's name is used.
    An explicitly supplied ``config_name`` is always kept.

    Args:
        kwargs (JumpStartEstimatorInitKwargs): Estimator init kwargs to update in place.

    Returns:
        JumpStartEstimatorInitKwargs: The same ``kwargs`` object, with
        ``config_name`` filled in when a default could be resolved.
    """

    specs = verify_model_region_and_return_specs(
        model_id=kwargs.model_id,
        version=kwargs.model_version,
        scope=JumpStartScriptScope.TRAINING,
        region=kwargs.region,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        sagemaker_session=kwargs.sagemaker_session,
        config_name=kwargs.config_name,
    )

    # Resolve the ranking once; it may be absent for models without training configs.
    top_config = (
        specs.training_configs.get_top_config_from_ranking() if specs.training_configs else None
    )
    if top_config:
        kwargs.config_name = kwargs.config_name or top_config.config_name

    return kwargs
77 changes: 68 additions & 9 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
JumpStartModelDeployKwargs,
JumpStartModelInitKwargs,
JumpStartModelRegisterKwargs,
JumpStartModelSpecs,
)
from sagemaker.jumpstart.utils import (
add_jumpstart_model_info_tags,
Expand Down Expand Up @@ -548,7 +549,27 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
return kwargs


def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
def _select_inference_config_from_training_config(
    specs: JumpStartModelSpecs, training_config_name: str
) -> Optional[str]:
    """Looks up the default inference config attached to a training config.

    Args:
        specs (JumpStartModelSpecs): The specs for the model.
        training_config_name (str): The name of the training config to look up.

    Returns:
        Optional[str]: The ``default_inference_config`` of the named training
        config, or None when the specs carry no training configs or the name
        does not resolve to one.
    """
    training_configs = specs.training_configs
    if not training_configs:
        return None

    matched_config = training_configs.configs.get(training_config_name)
    if not matched_config:
        return None

    return matched_config.default_inference_config


def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
"""Sets default config name to the kwargs. Returns full kwargs.

Raises:
Expand All @@ -566,13 +587,9 @@ def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod
model_type=kwargs.model_type,
config_name=kwargs.config_name,
)
if (
specs.inference_configs
and specs.inference_configs.get_top_config_from_ranking().config_name
):
kwargs.config_name = (
kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name
)
if specs.inference_configs:
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
kwargs.config_name = kwargs.config_name or default_config_name

if not kwargs.config_name:
return kwargs
Expand All @@ -593,6 +610,42 @@ def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod
return kwargs


def _add_config_name_to_deploy_kwargs(
    kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None
) -> JumpStartModelDeployKwargs:
    """Sets default config name to the deploy kwargs. Returns full kwargs.

    Resolution order for ``kwargs.config_name``:

    1. A config name already present in ``kwargs`` (explicitly requested by
       the caller) is kept as is.
    2. Otherwise, if ``training_config_name`` is passed, the inference config
       is chosen from that training config's default inference config.
    3. Otherwise, the top-ranked inference config from the model specs is
       used, when one exists.

    Args:
        kwargs (JumpStartModelDeployKwargs): Deploy kwargs to update in place.
        training_config_name (Optional[str]): Name of the training config the
            model was trained with. (Default: None).

    Returns:
        JumpStartModelDeployKwargs: The same ``kwargs`` object.

    Raises:
        Whatever ``verify_model_region_and_return_specs`` raises for an
        unrecognized model or config; nothing is raised directly here.
    """

    specs = verify_model_region_and_return_specs(
        model_id=kwargs.model_id,
        version=kwargs.model_version,
        scope=JumpStartScriptScope.INFERENCE,
        region=kwargs.region,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        sagemaker_session=kwargs.sagemaker_session,
        model_type=kwargs.model_type,
        config_name=kwargs.config_name,
    )

    # Only derive the inference config from the training config when the
    # caller did not request one; otherwise the explicit choice would be lost.
    if training_config_name and not kwargs.config_name:
        kwargs.config_name = _select_inference_config_from_training_config(
            specs=specs, training_config_name=training_config_name
        )

    if not kwargs.config_name and specs.inference_configs:
        # The ranking may yield no config; guard before reading its name.
        top_config = specs.inference_configs.get_top_config_from_ranking()
        if top_config:
            kwargs.config_name = top_config.config_name

    return kwargs


def get_deploy_kwargs(
model_id: str,
model_version: Optional[str] = None,
Expand Down Expand Up @@ -623,6 +676,7 @@ def get_deploy_kwargs(
resources: Optional[ResourceRequirements] = None,
managed_instance_scaling: Optional[str] = None,
endpoint_type: Optional[EndpointType] = None,
training_config_name: Optional[str] = None,
config_name: Optional[str] = None,
) -> JumpStartModelDeployKwargs:
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""
Expand Down Expand Up @@ -664,6 +718,10 @@ def get_deploy_kwargs(

deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)

deploy_kwargs = _add_config_name_to_deploy_kwargs(
kwargs=deploy_kwargs, training_config_name=training_config_name
)

deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)

deploy_kwargs.initial_instance_count = initial_instance_count or 1
Expand Down Expand Up @@ -858,6 +916,7 @@ def get_init_kwargs(
model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs)

model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
model_init_kwargs = _add_config_name_to_kwargs(kwargs=model_init_kwargs)

model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)

return model_init_kwargs
48 changes: 28 additions & 20 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,30 +1077,52 @@ class JumpStartMetadataConfig(JumpStartDataHolderType):
"config_components",
"resolved_metadata_config",
"config_name",
"default_inference_config",
"default_incremental_trainig_config",
"supported_inference_configs",
"supported_incremental_training_configs",
]

def __init__(
    self,
    config_name: str,
    config: Dict[str, Any],
    base_fields: Dict[str, Any],
    config_components: Dict[str, JumpStartConfigComponent],
    benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]] = None,
):
    """Initializes a JumpStartMetadataConfig object from its json representation.

    Args:
        config_name (str): Name of the config,
        config (Dict[str, Any]):
            Dictionary representation of the config. A missing (None/empty)
            config is treated as an empty dictionary.
        base_fields (Dict[str, Any]):
            The default base fields that are used to construct the final resolved config.
        config_components (Dict[str, JumpStartConfigComponent]):
            The list of components that are used to construct the resolved config.
        benchmark_metrics (Optional[Dict[str, List[JumpStartBenchmarkStat]]]):
            Accepted only so existing five-argument callers keep working.
            The stored metrics are always rebuilt from ``config``, so this
            value is not used.
    """
    # The metrics branch already tolerated a falsy config; normalize once so
    # the remaining lookups cannot fail with AttributeError on None.
    config = config or {}
    raw_metrics = config.get("benchmark_metrics")

    self.base_fields = base_fields
    self.config_components: Dict[str, JumpStartConfigComponent] = config_components
    self.benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]] = (
        {
            stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
            for stat_name, stats in raw_metrics.items()
        }
        if raw_metrics
        else None
    )
    self.resolved_metadata_config: Optional[Dict[str, Any]] = None
    self.config_name: Optional[str] = config_name
    self.default_inference_config: Optional[str] = config.get("default_inference_config")
    # NOTE: the attribute name is misspelled ("trainig") to match the entry
    # declared in __slots__; renaming it here alone would break assignment.
    self.default_incremental_trainig_config: Optional[str] = config.get(
        "default_incremental_training_config"
    )
    self.supported_inference_configs: Optional[List[str]] = config.get(
        "supported_inference_configs"
    )
    self.supported_incremental_training_configs: Optional[List[str]] = config.get(
        "supported_incremental_training_configs"
    )

def to_json(self) -> Dict[str, Any]:
"""Returns json representation of JumpStartMetadataConfig object."""
Expand Down Expand Up @@ -1255,6 +1277,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
{
alias: JumpStartMetadataConfig(
alias,
config,
json_obj,
(
{
Expand All @@ -1264,14 +1287,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
if config and config.get("component_names")
else None
),
(
{
stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
for stat_name, stats in config.get("benchmark_metrics").items()
}
if config and config.get("benchmark_metrics")
else None
),
)
for alias, config in json_obj["inference_configs"].items()
}
Expand Down Expand Up @@ -1308,6 +1323,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
{
alias: JumpStartMetadataConfig(
alias,
config,
json_obj,
(
{
Expand All @@ -1317,14 +1333,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
if config and config.get("component_names")
else None
),
(
{
stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
for stat_name, stats in config.get("benchmark_metrics").items()
}
if config and config.get("benchmark_metrics")
else None
),
)
for alias, config in json_obj["training_configs"].items()
}
Expand Down
Loading