Skip to content

Commit

Permalink
Fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
samruds committed May 15, 2024
1 parent 3686dd9 commit 8cad99f
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions tests/unit/sagemaker/serve/builder/test_transformers_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import unittest
from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.mode.function_pointers import Mode
from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG
from tests.unit.sagemaker.serve.constants import MOCK_VPC_CONFIG

from sagemaker.serve.utils.predictors import TransformersLocalModePredictor

Expand Down Expand Up @@ -58,6 +58,9 @@
mock_schema_builder = MagicMock()
mock_schema_builder.sample_input = mock_sample_input
mock_schema_builder.sample_output = mock_sample_output
MOCK_IMAGE_CONFIG = \
"763104351884.dkr.ecr.us-west-2.amazonaws.com/" \
"huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0"


class TestTransformersBuilder(unittest.TestCase):
Expand Down Expand Up @@ -115,8 +118,7 @@ def test_image_uri(
model=mock_model_id,
schema_builder=mock_schema_builder,
mode=Mode.LOCAL_CONTAINER,
vpc_config=MOCK_VPC_CONFIG,
image_config=MOCK_IMAGE_CONFIG,
image_uri=MOCK_IMAGE_CONFIG,
)

builder._prepare_for_mode = MagicMock()
Expand All @@ -128,17 +130,11 @@ def test_image_uri(
builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
predictor = model.deploy(model_data_download_timeout=1800)

assert model.image_config == MOCK_IMAGE_CONFIG
assert model.vpc_config == MOCK_VPC_CONFIG
assert builder.image_uri == MOCK_IMAGE_CONFIG
assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800"
assert isinstance(predictor, TransformersLocalModePredictor)

assert builder.nb_instance_type == "ml.g5.24xlarge"

builder._original_deploy = MagicMock()
builder._prepare_for_mode.return_value = (None, {})
predictor = model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")
assert "HF_MODEL_ID" in model.env

with self.assertRaises(ValueError) as _:
model.deploy(mode=Mode.IN_PROCESS)

0 comments on commit 8cad99f

Please sign in to comment.