diff --git a/examples/non_persistent/text-generation-bloom560-example.py b/examples/non_persistent/text-generation-bloom560-example.py index 92bb9159..67fe0cf1 100644 --- a/examples/non_persistent/text-generation-bloom560-example.py +++ b/examples/non_persistent/text-generation-bloom560-example.py @@ -5,10 +5,11 @@ import mii mii_configs = {"tensor_parallel": 1, "dtype": "fp16"} -generator = mii.deploy_non_persistent(task='text-generation', - model="bigscience/bloom-560m", - deployment_name="bloom560m_deployment", - deployment_type=mii.constants.DeploymentType.NON_PERSISTENT, - mii_config=mii_configs) +generator = mii.deploy_non_persistent( + task='text-generation', + model="bigscience/bloom-560m", + deployment_name="bloom560m_deployment", + deployment_type=mii.constants.DeploymentType.NON_PERSISTENT, + mii_config=mii_configs) result = generator.query({'query': ["DeepSpeed is the", "Seattle is"]}) print(result) diff --git a/mii/deployment.py b/mii/deployment.py index a5fc8970..7ad95b5f 100644 --- a/mii/deployment.py +++ b/mii/deployment.py @@ -66,6 +66,7 @@ def deploy(task, If deployment_type is `LOCAL`, returns just the name of the deployment that can be used to create a query handle using `mii.mii_query_handle(deployment_name)` """ + # parse and validate mii config mii_config = mii.config.MIIConfig(**mii_config) if enable_zero: @@ -175,17 +176,11 @@ def _deploy_aml(deployment_name, model_name, version): print("Please run 'deploy.sh' to bring your deployment online") -def deploy_non_persistent(task, - model, - deployment_name, - **kwargs): +def deploy_non_persistent(task, model, deployment_name, **kwargs): if "deployment_type" not in kwargs: kwargs["deployment_type"] = DeploymentType.NON_PERSISTENT assert kwargs.get("deployment_type") == DeploymentType.NON_PERSISTENT, "Only non-persistent deployment type can be used with `deploy_non_persistent`" - deploy(task, - model, - deployment_name, - **kwargs) + deploy(task, model, deployment_name, **kwargs) return mii.mii_query_handle(deployment_name)