Use separate tags for inference and training configs (aws#4635)
* Use separate tags for inference and training

* format

* format

* format

* format
Captainia authored and benieric committed May 14, 2024
1 parent 7b5ef04 commit 92219f2
Showing 14 changed files with 417 additions and 180 deletions.
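For context, the tagging behavior this commit introduces might look like the following sketch. The tag keys come from the enum changes below; the model ID, version, and config-name values are made-up placeholders.

# Tags a JumpStart endpoint might carry after this change (illustrative values only).
inference_endpoint_tags = [
    {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "meta-textgeneration-llama-2-7b"},
    {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "3.0.0"},
    {"Key": "sagemaker-sdk:jumpstart-inference-config-name", "Value": "tgi-inference"},
]

# Tags a JumpStart training job might carry (illustrative values only).
training_job_tags = [
    {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "meta-textgeneration-llama-2-7b"},
    {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "3.0.0"},
    {"Key": "sagemaker-sdk:jumpstart-training-config-name", "Value": "lora-training"},
]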
4 changes: 3 additions & 1 deletion src/sagemaker/jumpstart/enums.py
@@ -92,7 +92,9 @@ class JumpStartTag(str, Enum):
MODEL_ID = "sagemaker-sdk:jumpstart-model-id"
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type"
MODEL_CONFIG_NAME = "sagemaker-sdk:jumpstart-model-config-name"

INFERENCE_CONFIG_NAME = "sagemaker-sdk:jumpstart-inference-config-name"
TRAINING_CONFIG_NAME = "sagemaker-sdk:jumpstart-training-config-name"


class SerializerType(str, Enum):
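Since JumpStartTag subclasses str, the two new members compare equal to their literal tag keys; a minimal check based only on the enum values above:

from sagemaker.jumpstart.enums import JumpStartTag

# The new members behave as plain strings when matched against resource tag keys.
assert JumpStartTag.INFERENCE_CONFIG_NAME == "sagemaker-sdk:jumpstart-inference-config-name"
assert JumpStartTag.TRAINING_CONFIG_NAME == "sagemaker-sdk:jumpstart-training-config-name"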
6 changes: 4 additions & 2 deletions src/sagemaker/jumpstart/estimator.py
@@ -737,7 +737,7 @@ def attach(
config_name = None
if model_id is None:

model_id, model_version, config_name = get_model_info_from_training_job(
model_id, model_version, _, config_name = get_model_info_from_training_job(
training_job_name=training_job_name, sagemaker_session=sagemaker_session
)

@@ -1143,7 +1143,9 @@ def set_training_config(self, config_name: str) -> None:
Args:
config_name (str): The name of the config.
"""
self.__init__(**self.init_kwargs, config_name=config_name)
self.__init__(
model_id=self.model_id, model_version=self.model_version, config_name=config_name
)

def __str__(self) -> str:
"""Overriding str(*) method to make more human-readable."""
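A usage sketch of the re-initialization path that set_training_config now takes; the model ID and config name below are hypothetical placeholders, not values from this commit.

from sagemaker.jumpstart.estimator import JumpStartEstimator

# Hypothetical model ID and config name, for illustration only.
estimator = JumpStartEstimator(model_id="meta-textgeneration-llama-2-7b")
# Re-initializes the estimator with the same model ID/version plus the chosen training config.
estimator.set_training_config(config_name="lora-training")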
5 changes: 3 additions & 2 deletions src/sagemaker/jumpstart/factory/estimator.py
@@ -61,7 +61,7 @@
JumpStartModelInitKwargs,
)
from sagemaker.jumpstart.utils import (
add_jumpstart_model_id_version_tags,
add_jumpstart_model_info_tags,
get_eula_message,
update_dict_if_key_not_present,
resolve_estimator_sagemaker_config_field,
@@ -479,11 +479,12 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima
).version

if kwargs.sagemaker_session.settings.include_jumpstart_tags:
kwargs.tags = add_jumpstart_model_id_version_tags(
kwargs.tags = add_jumpstart_model_info_tags(
kwargs.tags,
kwargs.model_id,
full_model_version,
config_name=kwargs.config_name,
scope=JumpStartScriptScope.TRAINING,
)
return kwargs

11 changes: 8 additions & 3 deletions src/sagemaker/jumpstart/factory/model.py
@@ -44,7 +44,7 @@
JumpStartModelRegisterKwargs,
)
from sagemaker.jumpstart.utils import (
add_jumpstart_model_id_version_tags,
add_jumpstart_model_info_tags,
update_dict_if_key_not_present,
resolve_model_sagemaker_config_field,
verify_model_region_and_return_specs,
@@ -495,8 +495,13 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
).version

if kwargs.sagemaker_session.settings.include_jumpstart_tags:
kwargs.tags = add_jumpstart_model_id_version_tags(
kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type, kwargs.config_name
kwargs.tags = add_jumpstart_model_info_tags(
kwargs.tags,
kwargs.model_id,
full_model_version,
kwargs.model_type,
config_name=kwargs.config_name,
scope=JumpStartScriptScope.INFERENCE,
)

return kwargs
58 changes: 39 additions & 19 deletions src/sagemaker/jumpstart/session_utils.py
@@ -17,7 +17,7 @@
from typing import Optional, Tuple
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION

from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn
from sagemaker.jumpstart.utils import get_jumpstart_model_info_from_resource_arn
from sagemaker.session import Session
from sagemaker.utils import aws_partition

@@ -26,7 +26,7 @@ def get_model_info_from_endpoint(
endpoint_name: str,
inference_component_name: Optional[str] = None,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[str, str, Optional[str], Optional[str]]:
) -> Tuple[str, str, Optional[str], Optional[str], Optional[str]]:
"""Optionally inference component names, return the model ID, version and config name.
Infers the model ID and version based on the resource tags. Returns a tuple of the model ID
@@ -46,7 +46,8 @@ def get_model_info_from_endpoint(
(
model_id,
model_version,
config_name,
inference_config_name,
training_config_name,
) = _get_model_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301
inference_component_name, sagemaker_session
)
@@ -55,17 +56,29 @@
(
model_id,
model_version,
config_name,
inference_config_name,
training_config_name,
inference_component_name,
) = _get_model_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301
endpoint_name, sagemaker_session
)

else:
model_id, model_version, config_name = _get_model_info_from_model_based_endpoint(
(
model_id,
model_version,
inference_config_name,
training_config_name,
) = _get_model_info_from_model_based_endpoint(
endpoint_name, inference_component_name, sagemaker_session
)
return model_id, model_version, inference_component_name, config_name
return (
model_id,
model_version,
inference_component_name,
inference_config_name,
training_config_name,
)

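As a consumption sketch (the endpoint name is a placeholder), callers of get_model_info_from_endpoint now unpack five values instead of four:

from sagemaker.jumpstart.session_utils import get_model_info_from_endpoint

(
    model_id,
    model_version,
    inference_component_name,
    inference_config_name,
    training_config_name,
) = get_model_info_from_endpoint(endpoint_name="my-jumpstart-endpoint")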

def _get_model_info_from_inference_component_endpoint_without_inference_component_name(
Expand Down Expand Up @@ -125,9 +138,12 @@ def _get_model_info_from_inference_component_endpoint_with_inference_component_n
f"inference-component/{inference_component_name}"
)

model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
inference_component_arn, sagemaker_session
)
(
model_id,
model_version,
inference_config_name,
training_config_name,
) = get_jumpstart_model_info_from_resource_arn(inference_component_arn, sagemaker_session)

if not model_id:
raise ValueError(
@@ -136,14 +152,14 @@ def _get_model_info_from_inference_component_endpoint_with_inference_component_n
"when retrieving default predictor for this inference component."
)

return model_id, model_version, config_name
return model_id, model_version, inference_config_name, training_config_name


def _get_model_info_from_model_based_endpoint(
endpoint_name: str,
inference_component_name: Optional[str],
sagemaker_session: Session,
) -> Tuple[str, str, Optional[str]]:
) -> Tuple[str, str, Optional[str], Optional[str]]:
"""Returns the model ID, version and config name inferred from a model-based endpoint.
Raises:
@@ -163,9 +179,12 @@

endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}"

model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
endpoint_arn, sagemaker_session
)
(
model_id,
model_version,
inference_config_name,
training_config_name,
) = get_jumpstart_model_info_from_resource_arn(endpoint_arn, sagemaker_session)

if not model_id:
raise ValueError(
@@ -174,13 +193,13 @@
"predictor for this endpoint."
)

return model_id, model_version, config_name
return model_id, model_version, inference_config_name, training_config_name


def get_model_info_from_training_job(
training_job_name: str,
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[str, str, Optional[str]]:
) -> Tuple[str, str, Optional[str], Optional[str]]:
"""Returns the model ID and version and config name inferred from a training job.
Raises:
@@ -199,8 +218,9 @@
(
model_id,
inferred_model_version,
config_name,
) = get_jumpstart_model_id_version_from_resource_arn(training_job_arn, sagemaker_session)
inference_config_name,
training_config_name,
) = get_jumpstart_model_info_from_resource_arn(training_job_arn, sagemaker_session)

model_version = inferred_model_version or None

@@ -211,4 +231,4 @@
"for this training job."
)

return model_id, model_version, config_name
return model_id, model_version, inference_config_name, training_config_name
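Similarly, get_model_info_from_training_job now yields four values; training-side callers such as JumpStartEstimator.attach discard the inference config name, as in this sketch (the training job name is a placeholder):

from sagemaker.jumpstart.session_utils import get_model_info_from_training_job

model_id, model_version, _, config_name = get_model_info_from_training_job(
    training_job_name="my-jumpstart-training-job"
)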
39 changes: 28 additions & 11 deletions src/sagemaker/jumpstart/utils.py
@@ -320,7 +320,8 @@ def add_single_jumpstart_tag(
tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags)
or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags)
or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags)
or tag_key_in_array(enums.JumpStartTag.MODEL_CONFIG_NAME, curr_tags)
or tag_key_in_array(enums.JumpStartTag.INFERENCE_CONFIG_NAME, curr_tags)
or tag_key_in_array(enums.JumpStartTag.TRAINING_CONFIG_NAME, curr_tags)
)
if is_uri
else False
@@ -351,12 +352,13 @@ def get_jumpstart_base_name_if_jumpstart_model(
return None


def add_jumpstart_model_id_version_tags(
def add_jumpstart_model_info_tags(
tags: Optional[List[TagsDict]],
model_id: str,
model_version: str,
model_type: Optional[enums.JumpStartModelType] = None,
config_name: Optional[str] = None,
scope: enums.JumpStartScriptScope = None,
) -> List[TagsDict]:
"""Add custom model ID and version tags to JumpStart related resources."""
if model_id is None or model_version is None:
@@ -380,10 +382,17 @@ def add_jumpstart_model_id_version_tags(
tags,
is_uri=False,
)
if config_name:
if config_name and scope == enums.JumpStartScriptScope.INFERENCE:
tags = add_single_jumpstart_tag(
config_name,
enums.JumpStartTag.MODEL_CONFIG_NAME,
enums.JumpStartTag.INFERENCE_CONFIG_NAME,
tags,
is_uri=False,
)
if config_name and scope == enums.JumpStartScriptScope.TRAINING:
tags = add_single_jumpstart_tag(
config_name,
enums.JumpStartTag.TRAINING_CONFIG_NAME,
tags,
is_uri=False,
)
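A minimal sketch of the scope-driven behavior of add_jumpstart_model_info_tags, assuming hypothetical model and config values; only the tag matching the given scope is attached:

from sagemaker.jumpstart import enums
from sagemaker.jumpstart.utils import add_jumpstart_model_info_tags

tags = add_jumpstart_model_info_tags(
    tags=None,
    model_id="meta-textgeneration-llama-2-7b",  # hypothetical
    model_version="3.0.0",  # hypothetical
    config_name="lora-training",  # hypothetical
    scope=enums.JumpStartScriptScope.TRAINING,
)
# tags now carries the model ID/version tags plus
# "sagemaker-sdk:jumpstart-training-config-name"; the inference config tag is not added.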
@@ -840,10 +849,10 @@ def _extract_value_from_list_of_tags(
return resolved_value


def get_jumpstart_model_id_version_from_resource_arn(
def get_jumpstart_model_info_from_resource_arn(
resource_arn: str,
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
"""Returns the JumpStart model ID, version and config name if in resource tags.
Returns 'None' if model ID or version or config name cannot be inferred from tags.
@@ -853,7 +862,8 @@ def get_jumpstart_model_id_version_from_resource_arn(

model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS]
model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS]
model_config_name_keys = [enums.JumpStartTag.MODEL_CONFIG_NAME]
inference_config_name_keys = [enums.JumpStartTag.INFERENCE_CONFIG_NAME]
training_config_name_keys = [enums.JumpStartTag.TRAINING_CONFIG_NAME]

model_id: Optional[str] = _extract_value_from_list_of_tags(
tag_keys=model_id_keys,
@@ -869,14 +879,21 @@ def get_jumpstart_model_id_version_from_resource_arn(
resource_arn=resource_arn,
)

config_name: Optional[str] = _extract_value_from_list_of_tags(
tag_keys=model_config_name_keys,
inference_config_name: Optional[str] = _extract_value_from_list_of_tags(
tag_keys=inference_config_name_keys,
list_tags_result=list_tags_result,
resource_name="inference config name",
resource_arn=resource_arn,
)

training_config_name: Optional[str] = _extract_value_from_list_of_tags(
tag_keys=training_config_name_keys,
list_tags_result=list_tags_result,
resource_name="model config name",
resource_name="training config name",
resource_arn=resource_arn,
)

return model_id, model_version, config_name
return model_id, model_version, inference_config_name, training_config_name


def get_region_fallback(
1 change: 1 addition & 0 deletions src/sagemaker/predictor.py
@@ -82,6 +82,7 @@ def retrieve_default(
inferred_model_version,
inferred_inference_component_name,
inferred_config_name,
_,
) = get_model_info_from_endpoint(endpoint_name, inference_component_name, sagemaker_session)

if not inferred_model_id: