Skip to content

Commit

Permalink
fix: populate default config name to model (#4617)
Browse files Browse the repository at this point in the history
* fix: populate default config name to model

* update condition

* fix

* format

* flake8

* fix tests

* fix coverage

* temporarily skip integ test vulnerability

* fix tolerate attach method

* format

* fix predictor

* format
  • Loading branch information
Captainia committed Apr 26, 2024
1 parent 8a37956 commit 2b49e05
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 10 deletions.
7 changes: 6 additions & 1 deletion src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,12 @@ def attach(

model_version = model_version or "*"

additional_kwargs = {"model_id": model_id, "model_version": model_version}
additional_kwargs = {
"model_id": model_id,
"model_version": model_version,
"tolerate_vulnerable_model": True, # model is already trained
"tolerate_deprecated_model": True, # model is already trained
}

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
26 changes: 26 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,31 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
return kwargs


def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Sets default config name to the kwargs. Returns full kwargs.

    If the caller did not supply a ``config_name``, fall back to the
    top-ranked inference config for the model (when the model metadata
    declares inference configs at all). An explicitly provided
    ``config_name`` is never overwritten.
    """

    # Also validates region/vulnerability/deprecation constraints for the
    # model before we look at its configs.
    specs = verify_model_region_and_return_specs(
        model_id=kwargs.model_id,
        version=kwargs.model_version,
        scope=JumpStartScriptScope.INFERENCE,
        region=kwargs.region,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        sagemaker_session=kwargs.sagemaker_session,
        model_type=kwargs.model_type,
        config_name=kwargs.config_name,
    )

    # Resolve the top-ranked config once, instead of calling the ranking
    # helper twice (once in the condition, once in the assignment).
    default_config_name = None
    if specs.inference_configs:
        default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name

    if default_config_name:
        # Caller-provided config name takes precedence over the default.
        kwargs.config_name = kwargs.config_name or default_config_name

    return kwargs


def get_deploy_kwargs(
model_id: str,
model_version: Optional[str] = None,
Expand Down Expand Up @@ -808,5 +833,6 @@ def get_init_kwargs(
model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs)

model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
model_init_kwargs = _add_config_name_to_kwargs(kwargs=model_init_kwargs)

return model_init_kwargs
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def _validate_model_id_and_type():
self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model
self.region = model_init_kwargs.region
self.sagemaker_session = model_init_kwargs.sagemaker_session
self.config_name = config_name
self.config_name = model_init_kwargs.config_name

if self.model_type == JumpStartModelType.PROPRIETARY:
self.log_subscription_warning()
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,10 +1076,12 @@ class JumpStartMetadataConfig(JumpStartDataHolderType):
"benchmark_metrics",
"config_components",
"resolved_metadata_config",
"config_name",
]

def __init__(
self,
config_name: str,
base_fields: Dict[str, Any],
config_components: Dict[str, JumpStartConfigComponent],
benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]],
Expand All @@ -1098,6 +1100,7 @@ def __init__(
self.config_components: Dict[str, JumpStartConfigComponent] = config_components
self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = benchmark_metrics
self.resolved_metadata_config: Optional[Dict[str, Any]] = None
self.config_name: Optional[str] = config_name

def to_json(self) -> Dict[str, Any]:
"""Returns json representation of JumpStartMetadataConfig object."""
Expand Down Expand Up @@ -1251,6 +1254,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
{
alias: JumpStartMetadataConfig(
alias,
json_obj,
(
{
Expand Down Expand Up @@ -1303,6 +1307,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
{
alias: JumpStartMetadataConfig(
alias,
json_obj,
(
{
Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def retrieve_default(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
) -> Predictor:
"""Retrieves the default predictor for the model matching the given arguments.
Expand All @@ -65,6 +66,8 @@ def retrieve_default(
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
(exception not raised). False if these models should raise an exception.
(Default: False).
config_name (Optional[str]): The name of the configuration to use for the
predictor. (Default: None)
Returns:
Predictor: The default predictor to use for the model.
Expand All @@ -91,10 +94,9 @@ def retrieve_default(
model_id = inferred_model_id
model_version = model_version or inferred_model_version or "*"
inference_component_name = inference_component_name or inferred_inference_component_name
config_name = inferred_config_name or None
config_name = config_name or inferred_config_name or None
else:
model_version = model_version or "*"
config_name = None

predictor = Predictor(
endpoint_name=endpoint_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def test_gated_model_training_v2(setup):
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
environment={"accept_eula": "true"},
max_run=259200, # avoid exceeding resource limits
tolerate_vulnerable_model=True, # TODO: remove once vulnerbility is patched
)

# uses ml.g5.12xlarge instance
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,8 @@ def test_jumpstart_estimator_attach_eula_model(
additional_kwargs={
"model_id": "gemma-model",
"model_version": "*",
"tolerate_vulnerable_model": True,
"tolerate_deprecated_model": True,
"environment": {"accept_eula": "true"},
},
)
Expand Down Expand Up @@ -1056,6 +1058,8 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case(
additional_kwargs={
"model_id": "js-trainable-model-prepacked",
"model_version": "1.0.0",
"tolerate_vulnerable_model": True,
"tolerate_deprecated_model": True,
},
)

Expand Down
21 changes: 18 additions & 3 deletions tests/unit/sagemaker/jumpstart/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1552,6 +1552,8 @@ def test_model_initialization_with_config_name(

model = JumpStartModel(model_id=model_id, config_name="neuron-inference")

assert model.config_name == "neuron-inference"

model.deploy()

mock_model_deploy.assert_called_once_with(
Expand Down Expand Up @@ -1594,6 +1596,8 @@ def test_model_set_deployment_config(

model = JumpStartModel(model_id=model_id)

assert model.config_name is None

model.deploy()

mock_model_deploy.assert_called_once_with(
Expand All @@ -1612,6 +1616,8 @@ def test_model_set_deployment_config(
mock_get_model_specs.side_effect = get_prototype_spec_with_configs
model.set_deployment_config("neuron-inference")

assert model.config_name == "neuron-inference"

model.deploy()

mock_model_deploy.assert_called_once_with(
Expand Down Expand Up @@ -1654,6 +1660,8 @@ def test_model_unset_deployment_config(

model = JumpStartModel(model_id=model_id, config_name="neuron-inference")

assert model.config_name == "neuron-inference"

model.deploy()

mock_model_deploy.assert_called_once_with(
Expand Down Expand Up @@ -1789,7 +1797,6 @@ def test_model_retrieve_deployment_config(
):
model_id, _ = "pytorch-eqa-bert-base-cased", "*"

mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)
mock_verify_model_region_and_return_specs.side_effect = (
lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks()
)
Expand All @@ -1804,15 +1811,23 @@ def test_model_retrieve_deployment_config(
)
mock_model_deploy.return_value = default_predictor

expected = get_base_deployment_configs()[0]
config_name = expected.get("DeploymentConfigName")
mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(
model_id, config_name
)

mock_session.return_value = sagemaker_session

model = JumpStartModel(model_id=model_id)

expected = get_base_deployment_configs()[0]
model.set_deployment_config(expected.get("DeploymentConfigName"))
model.set_deployment_config(config_name)

self.assertEqual(model.deployment_config, expected)

mock_get_init_kwargs.reset_mock()
mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)

# Unset
model.set_deployment_config(None)
self.assertIsNone(model.deployment_config)
Expand Down
9 changes: 6 additions & 3 deletions tests/unit/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import copy
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional
import boto3

from sagemaker.compute_resource_requirements import ResourceRequirements
Expand Down Expand Up @@ -237,7 +237,7 @@ def get_base_spec_with_prototype_configs_with_missing_benchmarks(
copy_inference_configs = copy.deepcopy(INFERENCE_CONFIGS)
copy_inference_configs["inference_configs"]["neuron-inference"]["benchmark_metrics"] = None

inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS}
inference_configs = {**copy_inference_configs, **INFERENCE_CONFIG_RANKINGS}
training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS}

spec.update(inference_configs)
Expand Down Expand Up @@ -335,7 +335,9 @@ def get_base_deployment_configs_with_acceleration_configs() -> List[Dict[str, An
return configs


def get_mock_init_kwargs(model_id) -> JumpStartModelInitKwargs:
def get_mock_init_kwargs(
model_id: str, config_name: Optional[str] = None
) -> JumpStartModelInitKwargs:
return JumpStartModelInitKwargs(
model_id=model_id,
model_type=JumpStartModelType.OPEN_WEIGHTS,
Expand All @@ -344,4 +346,5 @@ def get_mock_init_kwargs(model_id) -> JumpStartModelInitKwargs:
instance_type=INIT_KWARGS.get("instance_type"),
env=INIT_KWARGS.get("env"),
resources=ResourceRequirements(),
config_name=config_name,
)

0 comments on commit 2b49e05

Please sign in to comment.