Skip to content

Commit

Permalink
Add supported inference and incremental training configs (aws#4637)
Browse files Browse the repository at this point in the history
* supported inference configs

* add tests

* format

* tests

* tests

* address comments

* format and address comments

* updates

* formt

* format
  • Loading branch information
Captainia authored and benieric committed May 14, 2024
1 parent 92219f2 commit edc98b9
Show file tree
Hide file tree
Showing 8 changed files with 359 additions and 52 deletions.
21 changes: 16 additions & 5 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def __init__(
enable_session_tag_chaining (bool or PipelineVariable): Optional.
Specifies whether SessionTagChaining is enabled for the training job
config_name (Optional[str]):
Name of the JumpStart Model config to apply. (Default: None).
Name of the training configuration to apply to the Estimator. (Default: None).
Raises:
ValueError: If the model ID is not recognized by JumpStart.
Expand Down Expand Up @@ -690,6 +690,7 @@ def attach(
model_version: Optional[str] = None,
sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_channel_name: str = "model",
config_name: Optional[str] = None,
) -> "JumpStartEstimator":
"""Attach to an existing training job.
Expand Down Expand Up @@ -725,6 +726,8 @@ def attach(
model data will be downloaded (default: 'model'). If no channel
with the same name exists in the training job, this option will
be ignored.
config_name (str): Optional. Name of the training configuration to use
when attaching to the training job. (Default: None).
Returns:
Instance of the calling ``JumpStartEstimator`` Class with the attached
Expand All @@ -736,7 +739,6 @@ def attach(
"""
config_name = None
if model_id is None:

model_id, model_version, _, config_name = get_model_info_from_training_job(
training_job_name=training_job_name, sagemaker_session=sagemaker_session
)
Expand All @@ -750,6 +752,9 @@ def attach(
"tolerate_deprecated_model": True, # model is already trained
}

if config_name:
additional_kwargs.update({"config_name": config_name})

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
Expand Down Expand Up @@ -808,6 +813,7 @@ def deploy(
dependencies: Optional[List[str]] = None,
git_config: Optional[Dict[str, str]] = None,
use_compiled_model: bool = False,
inference_config_name: Optional[str] = None,
) -> PredictorBase:
"""Creates endpoint from training job.
Expand Down Expand Up @@ -1043,6 +1049,8 @@ def deploy(
(Default: None).
use_compiled_model (bool): Flag to select whether to use compiled
(optimized) model. (Default: False).
inference_config_name (Optional[str]): Name of the inference configuration to
be used in the model. (Default: None).
"""
self.orig_predictor_cls = predictor_cls

Expand Down Expand Up @@ -1095,7 +1103,8 @@ def deploy(
git_config=git_config,
use_compiled_model=use_compiled_model,
training_instance_type=self.instance_type,
config_name=self.config_name,
training_config_name=self.config_name,
inference_config_name=inference_config_name,
)

predictor = super(JumpStartEstimator, self).deploy(
Expand All @@ -1112,7 +1121,7 @@ def deploy(
tolerate_deprecated_model=self.tolerate_deprecated_model,
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
sagemaker_session=self.sagemaker_session,
config_name=self.config_name,
config_name=estimator_deploy_kwargs.config_name,
)

# If a predictor class was passed, do not mutate predictor
Expand Down Expand Up @@ -1144,7 +1153,9 @@ def set_training_config(self, config_name: str) -> None:
config_name (str): The name of the config.
"""
self.__init__(
model_id=self.model_id, model_version=self.model_version, config_name=config_name
model_id=self.model_id,
model_version=self.model_version,
config_name=config_name,
)

def __str__(self) -> str:
Expand Down
35 changes: 31 additions & 4 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def get_init_kwargs(
estimator_init_kwargs = _add_role_to_kwargs(estimator_init_kwargs)
estimator_init_kwargs = _add_env_to_kwargs(estimator_init_kwargs)
estimator_init_kwargs = _add_tags_to_kwargs(estimator_init_kwargs)
estimator_init_kwargs = _add_config_name_to_kwargs(estimator_init_kwargs)

return estimator_init_kwargs

Expand Down Expand Up @@ -293,7 +294,8 @@ def get_deploy_kwargs(
use_compiled_model: Optional[bool] = None,
model_name: Optional[str] = None,
training_instance_type: Optional[str] = None,
config_name: Optional[str] = None,
training_config_name: Optional[str] = None,
inference_config_name: Optional[str] = None,
) -> JumpStartEstimatorDeployKwargs:
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object."""

Expand Down Expand Up @@ -321,7 +323,8 @@ def get_deploy_kwargs(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
training_config_name=training_config_name,
config_name=inference_config_name,
)

model_init_kwargs: JumpStartModelInitKwargs = model.get_init_kwargs(
Expand Down Expand Up @@ -350,7 +353,7 @@ def get_deploy_kwargs(
tolerate_deprecated_model=tolerate_deprecated_model,
training_instance_type=training_instance_type,
disable_instance_type_logging=True,
config_name=config_name,
config_name=model_deploy_kwargs.config_name,
)

estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs(
Expand Down Expand Up @@ -395,7 +398,7 @@ def get_deploy_kwargs(
tolerate_vulnerable_model=model_init_kwargs.tolerate_vulnerable_model,
tolerate_deprecated_model=model_init_kwargs.tolerate_deprecated_model,
use_compiled_model=use_compiled_model,
config_name=config_name,
config_name=model_deploy_kwargs.config_name,
)

return estimator_deploy_kwargs
Expand Down Expand Up @@ -795,3 +798,27 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim
setattr(kwargs, key, value)

return kwargs


def _add_config_name_to_kwargs(
kwargs: JumpStartEstimatorInitKwargs,
) -> JumpStartEstimatorInitKwargs:
"""Sets tags in kwargs based on default or override, returns full kwargs."""

specs = verify_model_region_and_return_specs(
model_id=kwargs.model_id,
version=kwargs.model_version,
scope=JumpStartScriptScope.TRAINING,
region=kwargs.region,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
sagemaker_session=kwargs.sagemaker_session,
config_name=kwargs.config_name,
)

if specs.training_configs and specs.training_configs.get_top_config_from_ranking():
kwargs.config_name = (
kwargs.config_name or specs.training_configs.get_top_config_from_ranking().config_name
)

return kwargs
77 changes: 68 additions & 9 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
JumpStartModelDeployKwargs,
JumpStartModelInitKwargs,
JumpStartModelRegisterKwargs,
JumpStartModelSpecs,
)
from sagemaker.jumpstart.utils import (
add_jumpstart_model_info_tags,
Expand Down Expand Up @@ -548,7 +549,27 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
return kwargs


def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
def _select_inference_config_from_training_config(
specs: JumpStartModelSpecs, training_config_name: str
) -> Optional[str]:
"""Selects the inference config from the training config.
Args:
specs (JumpStartModelSpecs): The specs for the model.
training_config_name (str): The name of the training config.
Returns:
str: The name of the inference config.
"""
if specs.training_configs:
resolved_training_config = specs.training_configs.configs.get(training_config_name)
if resolved_training_config:
return resolved_training_config.default_inference_config

return None


def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
"""Sets default config name to the kwargs. Returns full kwargs.
Raises:
Expand All @@ -566,13 +587,9 @@ def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod
model_type=kwargs.model_type,
config_name=kwargs.config_name,
)
if (
specs.inference_configs
and specs.inference_configs.get_top_config_from_ranking().config_name
):
kwargs.config_name = (
kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name
)
if specs.inference_configs:
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
kwargs.config_name = kwargs.config_name or default_config_name

if not kwargs.config_name:
return kwargs
Expand All @@ -593,6 +610,42 @@ def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod
return kwargs


def _add_config_name_to_deploy_kwargs(
kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None
) -> JumpStartModelInitKwargs:
"""Sets default config name to the kwargs. Returns full kwargs.
If a training_config_name is passed, then choose the inference config
based on the supported inference configs in that training config.
Raises:
ValueError: If the instance_type is not supported with the current config.
"""

specs = verify_model_region_and_return_specs(
model_id=kwargs.model_id,
version=kwargs.model_version,
scope=JumpStartScriptScope.INFERENCE,
region=kwargs.region,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
sagemaker_session=kwargs.sagemaker_session,
model_type=kwargs.model_type,
config_name=kwargs.config_name,
)

if training_config_name:
kwargs.config_name = _select_inference_config_from_training_config(
specs=specs, training_config_name=training_config_name
)

if specs.inference_configs:
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
kwargs.config_name = kwargs.config_name or default_config_name

return kwargs


def get_deploy_kwargs(
model_id: str,
model_version: Optional[str] = None,
Expand Down Expand Up @@ -623,6 +676,7 @@ def get_deploy_kwargs(
resources: Optional[ResourceRequirements] = None,
managed_instance_scaling: Optional[str] = None,
endpoint_type: Optional[EndpointType] = None,
training_config_name: Optional[str] = None,
config_name: Optional[str] = None,
) -> JumpStartModelDeployKwargs:
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""
Expand Down Expand Up @@ -664,6 +718,10 @@ def get_deploy_kwargs(

deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)

deploy_kwargs = _add_config_name_to_deploy_kwargs(
kwargs=deploy_kwargs, training_config_name=training_config_name
)

deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)

deploy_kwargs.initial_instance_count = initial_instance_count or 1
Expand Down Expand Up @@ -858,6 +916,7 @@ def get_init_kwargs(
model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs)

model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
model_init_kwargs = _add_config_name_to_kwargs(kwargs=model_init_kwargs)

model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)

return model_init_kwargs
48 changes: 28 additions & 20 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,30 +1077,52 @@ class JumpStartMetadataConfig(JumpStartDataHolderType):
"config_components",
"resolved_metadata_config",
"config_name",
"default_inference_config",
"default_incremental_trainig_config",
"supported_inference_configs",
"supported_incremental_training_configs",
]

def __init__(
self,
config_name: str,
config: Dict[str, Any],
base_fields: Dict[str, Any],
config_components: Dict[str, JumpStartConfigComponent],
benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]],
):
"""Initializes a JumpStartMetadataConfig object from its json representation.
Args:
config_name (str): Name of the config,
config (Dict[str, Any]):
Dictionary representation of the config.
base_fields (Dict[str, Any]):
The default base fields that are used to construct the final resolved config.
config_components (Dict[str, JumpStartConfigComponent]):
The list of components that are used to construct the resolved config.
benchmark_metrics (Dict[str, List[JumpStartBenchmarkStat]]):
The dictionary of benchmark metrics with name being the key.
"""
self.base_fields = base_fields
self.config_components: Dict[str, JumpStartConfigComponent] = config_components
self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = benchmark_metrics
self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = (
{
stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
for stat_name, stats in config.get("benchmark_metrics").items()
}
if config and config.get("benchmark_metrics")
else None
)
self.resolved_metadata_config: Optional[Dict[str, Any]] = None
self.config_name: Optional[str] = config_name
self.default_inference_config: Optional[str] = config.get("default_inference_config")
self.default_incremental_trainig_config: Optional[str] = config.get(
"default_incremental_training_config"
)
self.supported_inference_configs: Optional[List[str]] = config.get(
"supported_inference_configs"
)
self.supported_incremental_training_configs: Optional[List[str]] = config.get(
"supported_incremental_training_configs"
)

def to_json(self) -> Dict[str, Any]:
"""Returns json representation of JumpStartMetadataConfig object."""
Expand Down Expand Up @@ -1255,6 +1277,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
{
alias: JumpStartMetadataConfig(
alias,
config,
json_obj,
(
{
Expand All @@ -1264,14 +1287,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
if config and config.get("component_names")
else None
),
(
{
stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
for stat_name, stats in config.get("benchmark_metrics").items()
}
if config and config.get("benchmark_metrics")
else None
),
)
for alias, config in json_obj["inference_configs"].items()
}
Expand Down Expand Up @@ -1308,6 +1323,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
{
alias: JumpStartMetadataConfig(
alias,
config,
json_obj,
(
{
Expand All @@ -1317,14 +1333,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
if config and config.get("component_names")
else None
),
(
{
stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
for stat_name, stats in config.get("benchmark_metrics").items()
}
if config and config.get("benchmark_metrics")
else None
),
)
for alias, config in json_obj["training_configs"].items()
}
Expand Down
Loading

0 comments on commit edc98b9

Please sign in to comment.