Skip to content

Commit

Permalink
fix: Image URI should take precedence for HF models (#4684)
Browse files Browse the repository at this point in the history
* Fix: Image URI should take precedence for HF models

* Fix formatting

* Fix formatting

* Fix formatting

* Increase coverage -  UT pass
  • Loading branch information
samruds committed May 15, 2024
1 parent d4f3c91 commit 65cc586
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/sagemaker/serve/builder/transformers_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,17 +132,20 @@ def _create_transformers_model(self) -> Type[Model]:
vpc_config=self.vpc_config,
)

if self.mode == Mode.LOCAL_CONTAINER:
if not self.image_uri and self.mode == Mode.LOCAL_CONTAINER:
self.image_uri = pysdk_model.serving_image_uri(
self.sagemaker_session.boto_region_name, "local"
)
else:
elif not self.image_uri:
self.image_uri = pysdk_model.serving_image_uri(
self.sagemaker_session.boto_region_name, self.instance_type
)

logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)

if not pysdk_model.image_uri:
pysdk_model.image_uri = self.image_uri

self._original_deploy = pysdk_model.deploy
pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper
return pysdk_model
Expand Down Expand Up @@ -251,13 +254,14 @@ def _set_instance(self, **kwargs):
if self.mode == Mode.SAGEMAKER_ENDPOINT:
if self.nb_instance_type and "instance_type" not in kwargs:
kwargs.update({"instance_type": self.nb_instance_type})
logger.info("Setting instance type to %s", self.nb_instance_type)
elif self.instance_type and "instance_type" not in kwargs:
kwargs.update({"instance_type": self.instance_type})
logger.info("Setting instance type to %s", self.instance_type)
else:
raise ValueError(
"Instance type must be provided when deploying to SageMaker Endpoint mode."
)
logger.info("Setting instance type to %s", self.instance_type)

def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
"""Uses the hugging face json config to pick supported versions"""
Expand Down
44 changes: 44 additions & 0 deletions tests/unit/sagemaker/serve/builder/test_transformers_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@
mock_schema_builder = MagicMock()
mock_schema_builder.sample_input = mock_sample_input
mock_schema_builder.sample_output = mock_sample_output
MOCK_IMAGE_CONFIG = (
"763104351884.dkr.ecr.us-west-2.amazonaws.com/"
"huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0"
)


class TestTransformersBuilder(unittest.TestCase):
Expand Down Expand Up @@ -100,3 +104,43 @@ def test_build_deploy_for_transformers_local_container_and_remote_container(

with self.assertRaises(ValueError) as _:
model.deploy(mode=Mode.IN_PROCESS)

@patch(
"sagemaker.serve.builder.transformers_builder._get_nb_instance",
return_value="ml.g5.24xlarge",
)
@patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None)
def test_image_uri(
self,
mock_get_nb_instance,
mock_telemetry,
):
builder = ModelBuilder(
model=mock_model_id,
schema_builder=mock_schema_builder,
mode=Mode.LOCAL_CONTAINER,
image_uri=MOCK_IMAGE_CONFIG,
)

builder._prepare_for_mode = MagicMock()
builder._prepare_for_mode.side_effect = None

model = builder.build()
builder.serve_settings.telemetry_opt_out = True

builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
predictor = model.deploy(model_data_download_timeout=1800)

assert builder.image_uri == MOCK_IMAGE_CONFIG
assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800"
assert isinstance(predictor, TransformersLocalModePredictor)

assert builder.nb_instance_type == "ml.g5.24xlarge"

builder._original_deploy = MagicMock()
builder._prepare_for_mode.return_value = (None, {})
predictor = model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")
assert "HF_MODEL_ID" in model.env

with self.assertRaises(ValueError) as _:
model.deploy(mode=Mode.IN_PROCESS)

0 comments on commit 65cc586

Please sign in to comment.