Skip to content

Commit

Permalink
Fixing formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
TosinSeg committed Jun 22, 2023
1 parent df2a6ed commit 228e9f2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 13 deletions.
11 changes: 6 additions & 5 deletions examples/non_persistent/text-generation-bloom560-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import mii

mii_configs = {"tensor_parallel": 1, "dtype": "fp16"}
generator = mii.deploy_non_persistent(task='text-generation',
model="bigscience/bloom-560m",
deployment_name="bloom560m_deployment",
deployment_type=mii.constants.DeploymentType.NON_PERSISTENT,
mii_config=mii_configs)
generator = mii.deploy_non_persistent(
task='text-generation',
model="bigscience/bloom-560m",
deployment_name="bloom560m_deployment",
deployment_type=mii.constants.DeploymentType.NON_PERSISTENT,
mii_config=mii_configs)
result = generator.query({'query': ["DeepSpeed is the", "Seattle is"]})
print(result)
11 changes: 3 additions & 8 deletions mii/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def deploy(task,
If deployment_type is `LOCAL`, returns just the name of the deployment that can be used to create a query handle using `mii.mii_query_handle(deployment_name)`
"""

# parse and validate mii config
mii_config = mii.config.MIIConfig(**mii_config)
if enable_zero:
Expand Down Expand Up @@ -175,17 +176,11 @@ def _deploy_aml(deployment_name, model_name, version):
print("Please run 'deploy.sh' to bring your deployment online")


def deploy_non_persistent(task,
model,
deployment_name,
**kwargs):
def deploy_non_persistent(task, model, deployment_name, **kwargs):
if "deployment_type" not in kwargs:
kwargs["deployment_type"] = DeploymentType.NON_PERSISTENT
assert kwargs.get("deployment_type") == DeploymentType.NON_PERSISTENT, "Only non-persistent deployment type can be used with `deploy_non_persistent`"
deploy(task,
model,
deployment_name,
**kwargs)
deploy(task, model, deployment_name, **kwargs)
return mii.mii_query_handle(deployment_name)


Expand Down

0 comments on commit 228e9f2

Please sign in to comment.