Skip to content

Commit

Permalink
ModelBuilder: Add functionalities to get and set deployment config. (a…
Browse files Browse the repository at this point in the history
…ws#4614)

* Add functionalities to get and set deployment config

* Resolve PR comments

* ModelBuilder-JS

* Add Unit tests

* Refactoring

* Testing with Notebook

* Test backward compatibility

* Remove Accelerated column if all not enabled

* Fix docstring

* Resolved PR Review comments

* Docstring

* increase code coverage

---------

Co-authored-by: Jonathan Makunga <[email protected]>
  • Loading branch information
2 people authored and benieric committed May 15, 2024
1 parent bf91274 commit 006e577
Show file tree
Hide file tree
Showing 9 changed files with 348 additions and 69 deletions.
34 changes: 30 additions & 4 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import absolute_import

from functools import lru_cache
from typing import Dict, List, Optional, Union, Any
from typing import Dict, List, Optional, Any, Union
import pandas as pd
from botocore.exceptions import ClientError

Expand Down Expand Up @@ -441,14 +441,23 @@ def set_deployment_config(self, config_name: Optional[str]) -> None:
model_id=self.model_id, model_version=self.model_version, config_name=config_name
)

@property
def deployment_config(self) -> Optional[Dict[str, Any]]:
    """Deployment config currently selected for this model.

    Looks up the config whose name is stored in ``self.config_name`` by
    delegating to ``_retrieve_selected_deployment_config``.

    Returns:
        Optional[Dict[str, Any]]: Deployment config that will be applied to the model.
    """
    selected_name = self.config_name
    return self._retrieve_selected_deployment_config(selected_name)

@property
def benchmark_metrics(self) -> pd.DataFrame:
"""Benchmark Metrics for deployment configs
Returns:
Metrics: Pandas DataFrame object.
"""
return pd.DataFrame(self._get_benchmark_data(self.config_name))
return pd.DataFrame(self._get_benchmarks_data(self.config_name))

def display_benchmark_metrics(self) -> None:
"""Display Benchmark Metrics for deployment configs."""
Expand Down Expand Up @@ -851,8 +860,8 @@ def register_deploy_wrapper(*args, **kwargs):
return model_package

@lru_cache
def _get_benchmark_data(self, config_name: str) -> Dict[str, List[str]]:
"""Constructs deployment configs benchmark data.
def _get_benchmarks_data(self, config_name: str) -> Dict[str, List[str]]:
"""Deployment configs benchmark metrics.
Args:
config_name (str): The name of the selected deployment config.
Expand All @@ -864,6 +873,23 @@ def _get_benchmark_data(self, config_name: str) -> Dict[str, List[str]]:
config_name,
)

@lru_cache
def _retrieve_selected_deployment_config(self, config_name: str) -> Optional[Dict[str, Any]]:
"""Retrieve the deployment config to apply to the model.
Args:
config_name (str): The name of the deployment config to retrieve.
Returns:
Optional[Dict[str, Any]]: The retrieved deployment config.
"""
if config_name is None:
return None

for deployment_config in self._deployment_configs:
if deployment_config.get("DeploymentConfigName") == config_name:
return deployment_config
return None

def _convert_to_deployment_config_metadata(
self, config_name: str, metadata_config: JumpStartMetadataConfig
) -> Dict[str, Any]:
Expand Down
18 changes: 10 additions & 8 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2254,17 +2254,17 @@ def to_json(self) -> Dict[str, Any]:
return json_obj


class DeploymentConfig(BaseDeploymentConfigDataHolder):
class DeploymentArgs(BaseDeploymentConfigDataHolder):
"""Dataclass representing a Deployment Config."""

__slots__ = [
"model_data_download_timeout",
"container_startup_health_check_timeout",
"image_uri",
"model_data",
"instance_type",
"environment",
"instance_type",
"compute_resource_requirements",
"model_data_download_timeout",
"container_startup_health_check_timeout",
]

def __init__(
Expand All @@ -2291,9 +2291,10 @@ class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):
"""Dataclass representing a Deployment Config Metadata"""

__slots__ = [
"config_name",
"deployment_config_name",
"deployment_args",
"acceleration_configs",
"benchmark_metrics",
"deployment_config",
]

def __init__(
Expand All @@ -2304,6 +2305,7 @@ def __init__(
deploy_kwargs: JumpStartModelDeployKwargs,
):
"""Instantiates DeploymentConfigMetadata object."""
self.config_name = config_name
self.deployment_config_name = config_name
self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs)
self.acceleration_configs = None
self.benchmark_metrics = benchmark_metrics
self.deployment_config = DeploymentConfig(init_kwargs, deploy_kwargs)
28 changes: 23 additions & 5 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,24 +1040,40 @@ def extract_metrics_from_deployment_configs(
config_name (str): The name of the deployment config use by the model.
"""

data = {"Config Name": [], "Instance Type": [], "Selected": []}
data = {"Config Name": [], "Instance Type": [], "Selected": [], "Accelerated": []}

for index, deployment_config in enumerate(deployment_configs):
if deployment_config.get("DeploymentConfig") is None:
if deployment_config.get("DeploymentArgs") is None:
continue

benchmark_metrics = deployment_config.get("BenchmarkMetrics")
if benchmark_metrics is not None:
data["Config Name"].append(deployment_config.get("ConfigName"))
data["Config Name"].append(deployment_config.get("DeploymentConfigName"))
data["Instance Type"].append(
deployment_config.get("DeploymentConfig").get("InstanceType")
deployment_config.get("DeploymentArgs").get("InstanceType")
)
data["Selected"].append(
"Yes"
if (config_name is not None and config_name == deployment_config.get("ConfigName"))
if (
config_name is not None
and config_name == deployment_config.get("DeploymentConfigName")
)
else "No"
)

accelerated_configs = deployment_config.get("AccelerationConfigs")
if accelerated_configs is None:
data["Accelerated"].append("No")
else:
data["Accelerated"].append(
"Yes"
if (
len(accelerated_configs) > 0
and accelerated_configs[0].get("Enabled", False)
)
else "No"
)

if index == 0:
for benchmark_metric in benchmark_metrics:
column_name = f"{benchmark_metric.get('name')} ({benchmark_metric.get('unit')})"
Expand All @@ -1068,4 +1084,6 @@ def extract_metrics_from_deployment_configs(
if column_name in data.keys():
data[column_name].append(benchmark_metric.get("value"))

if "Yes" not in data["Accelerated"]:
del data["Accelerated"]
return data
57 changes: 41 additions & 16 deletions src/sagemaker/serve/builder/jumpstart_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import copy
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import Type, Any, List, Dict
from typing import Type, Any, List, Dict, Optional
import logging

from sagemaker.model import Model
Expand Down Expand Up @@ -431,8 +431,35 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration
)

def set_deployment_config(self, config_name: Optional[str]) -> None:
    """Apply the named deployment config to the underlying model.

    Args:
        config_name (Optional[str]):
            The name of the deployment config. Set to None to unset
            any existing config that is applied to the model.

    Raises:
        Exception: If ``self.pysdk_model`` is missing or ``None``.
    """
    model = getattr(self, "pysdk_model", None)
    if model is None:
        raise Exception("Cannot set deployment config to an uninitialized model.")

    model.set_deployment_config(config_name)

def get_deployment_config(self) -> Optional[Dict[str, Any]]:
    """Return the deployment config selected on the underlying model.

    If ``self.pysdk_model`` is missing or ``None``, it is first created with
    ``self._create_pre_trained_js_model()`` and stored on the builder.

    Returns:
        Optional[Dict[str, Any]]: Deployment config to apply to this model.
    """
    if getattr(self, "pysdk_model", None) is None:
        self.pysdk_model = self._create_pre_trained_js_model()

    model = self.pysdk_model
    return model.deployment_config

def display_benchmark_metrics(self):
    """Display Markdown Benchmark Metrics for deployment configs.

    If ``self.pysdk_model`` is missing or ``None``, it is first created with
    ``self._create_pre_trained_js_model()``; the display call is then
    delegated to that model.
    """
    if getattr(self, "pysdk_model", None) is None:
        self.pysdk_model = self._create_pre_trained_js_model()

    model = self.pysdk_model
    model.display_benchmark_metrics()

def list_deployment_configs(self) -> List[Dict[str, Any]]:
Expand All @@ -441,6 +468,9 @@ def list_deployment_configs(self) -> List[Dict[str, Any]]:
Returns:
List[Dict[str, Any]]: A list of deployment configs.
"""
if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
self.pysdk_model = self._create_pre_trained_js_model()

return self.pysdk_model.list_deployment_configs()

def _build_for_jumpstart(self):
Expand All @@ -449,32 +479,29 @@ def _build_for_jumpstart(self):
self.secret_key = None
self.jumpstart = True

pysdk_model = self._create_pre_trained_js_model()

image_uri = pysdk_model.image_uri
if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
self.pysdk_model = self._create_pre_trained_js_model()

logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri)
logger.info(
"JumpStart ID %s is packaged with Image URI: %s", self.model, self.pysdk_model.image_uri
)

if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT:
if self._is_gated_model() and self.mode != Mode.SAGEMAKER_ENDPOINT:
raise ValueError(
"JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
)

if "djl-inference" in image_uri:
if "djl-inference" in self.pysdk_model.image_uri:
logger.info("Building for DJL JumpStart Model ID...")
self.model_server = ModelServer.DJL_SERVING

self.pysdk_model = pysdk_model
self.image_uri = self.pysdk_model.image_uri

self._build_for_djl_jumpstart()

self.pysdk_model.tune = self.tune_for_djl_jumpstart
elif "tgi-inference" in image_uri:
elif "tgi-inference" in self.pysdk_model.image_uri:
logger.info("Building for TGI JumpStart Model ID...")
self.model_server = ModelServer.TGI

self.pysdk_model = pysdk_model
self.image_uri = self.pysdk_model.image_uri

self._build_for_tgi_jumpstart()
Expand All @@ -487,15 +514,13 @@ def _build_for_jumpstart(self):

return self.pysdk_model

def _is_gated_model(self, model) -> bool:
def _is_gated_model(self) -> bool:
"""Determine if ``this`` Model is Gated
Args:
model (Model): Jumpstart Model
Returns:
bool: ``True`` if ``this`` Model is Gated
"""
s3_uri = model.model_data
s3_uri = self.pysdk_model.model_data
if isinstance(s3_uri, dict):
s3_uri = s3_uri.get("S3DataSource").get("S3Uri")

Expand Down
Loading

0 comments on commit 006e577

Please sign in to comment.