# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team

import os
import pickle
import time
import importlib.util

import torch

import mii.legacy as mii
from types import SimpleNamespace
from huggingface_hub import HfApi
from mii.legacy.models.score.generate import generated_score_path
from mii.legacy.constants import (
    MII_CACHE_PATH,
    MII_CACHE_PATH_DEFAULT,
    ModelProvider,
    SUPPORTED_MODEL_TYPES,
    REQUIRED_KEYS_PER_TASK,
    MII_HF_CACHE_EXPIRATION,
    MII_HF_CACHE_EXPIRATION_DEFAULT,
)
from mii.legacy.config import TaskType


def _get_hf_models_by_type(model_type=None, task=None):
    cache_file_path = os.path.join(mii_cache_path(), "HF_model_cache.pkl")
    # Environment variables come back as strings, so cast the expiration to int
    cache_expiration_seconds = int(
        os.getenv(MII_HF_CACHE_EXPIRATION, MII_HF_CACHE_EXPIRATION_DEFAULT))

    # Load or initialize the cache
    model_data = {"cache_time": 0, "model_list": []}
    if os.path.isfile(cache_file_path):
        with open(cache_file_path, 'rb') as f:
            model_data = pickle.load(f)

    current_time = time.time()

    # Update the cache if it has expired
    if (model_data["cache_time"] + cache_expiration_seconds) < current_time:
        api = HfApi()
        model_data["model_list"] = [
            SimpleNamespace(id=m.id, pipeline_tag=m.pipeline_tag, tags=m.tags)
            for m in api.list_models()
        ]
        model_data["cache_time"] = current_time

        # Save the updated cache
        with open(cache_file_path, 'wb') as f:
            pickle.dump(model_data, f)

    # Filter the model list
    models = model_data["model_list"]
    if model_type is not None:
        models = [m for m in models if model_type in m.tags]
    if task is not None:
        models = [m for m in models if m.pipeline_tag == task]

    # Extract model IDs
    model_ids = [m.id for m in models]

    if task == TaskType.TEXT_GENERATION:
        # TODO: this is a temp solution to get around some HF models not having the correct tags
        model_ids.extend([
            "microsoft/bloom-deepspeed-inference-fp16",
            "microsoft/bloom-deepspeed-inference-int8",
            "EleutherAI/gpt-neox-20b",
        ])

    return model_ids
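
# Usage sketch (illustrative only, not executed by the module): filter the
# cached Hub listing by task and/or model-type tag. The "gpt2" tag below is an
# assumed example of a Hub tag, not a value defined by this module.
#
#   text_gen_ids = _get_hf_models_by_type(task=TaskType.TEXT_GENERATION)
#   gpt2_ids = _get_hf_models_by_type(model_type="gpt2", task=TaskType.TEXT_GENERATION)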


def get_supported_models(task):
    supported_models = []
    for model_type, provider in SUPPORTED_MODEL_TYPES.items():
        # Reset per model type so a provider branch that does not match
        # cannot reuse the previous iteration's list
        models = []
        if provider == ModelProvider.HUGGING_FACE:
            models = _get_hf_models_by_type(model_type, task)
        elif provider == ModelProvider.ELEUTHER_AI:
            if task == TaskType.TEXT_GENERATION:
                models = [model_type]
        elif provider == ModelProvider.DIFFUSERS:
            models = _get_hf_models_by_type(model_type, task)
        supported_models.extend(models)
    if not supported_models:
        raise ValueError(f"Task {task} not supported")
    return supported_models
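
# Illustrative call (requires network access on a cold cache, since the Hub
# listing above is fetched and pickled the first time):
#
#   text_gen_models = get_supported_models(TaskType.TEXT_GENERATION)
#   print(f"{len(text_gen_models)} models support text generation")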


def check_if_task_and_model_is_supported(task, model_name):
    supported_models = get_supported_models(task)
    assert (
        model_name in supported_models
    ), f"{task} is not supported by {model_name}. This task is supported by {len(supported_models)} other models; list them with `mii.get_supported_models(mii.{task})`."


def check_if_task_and_model_is_valid(task, model_name):
    valid_task_models = _get_hf_models_by_type(None, task)
    assert (
        model_name in valid_task_models
    ), f"{task} is not supported by {model_name}. This task is supported by {len(valid_task_models)} other models; list them with `mii.get_supported_models(mii.{task})`."


def full_model_path(model_path):
    aml_model_dir = os.environ.get('AZUREML_MODEL_DIR', None)
    if aml_model_dir:
        # (potentially) append the relative model_path to the AML model dir
        assert os.path.isabs(aml_model_dir), f"AZUREML_MODEL_DIR={aml_model_dir} must be an absolute path"
        if model_path:
            assert not os.path.isabs(model_path), f"model_path={model_path} must be relative to be appended to the AML path"
            return os.path.join(aml_model_dir, model_path)
        else:
            return aml_model_dir
    elif model_path:
        return model_path
    else:
        return mii.constants.MII_MODEL_PATH_DEFAULT
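
# Resolution sketch (paths below are hypothetical). Outside AML the given path
# (or the MII default) is returned as-is; inside AML a relative path is joined
# onto AZUREML_MODEL_DIR:
#
#   full_model_path("")               # -> mii.constants.MII_MODEL_PATH_DEFAULT
#   # with AZUREML_MODEL_DIR=/var/azureml-app/models:
#   full_model_path("bloom-560m")     # -> "/var/azureml-app/models/bloom-560m"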


def is_aml():
    return os.getenv("AZUREML_MODEL_DIR") is not None


def mii_cache_path():
    cache_path = os.environ.get(MII_CACHE_PATH, MII_CACHE_PATH_DEFAULT)
    if not os.path.isdir(cache_path):
        os.makedirs(cache_path)
    return cache_path


def import_score_file(deployment_name, deployment_type):
    score_path = generated_score_path(deployment_name, deployment_type)
    spec = importlib.util.spec_from_file_location("score", score_path)
    score = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(score)
    return score
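
# Illustrative use (deployment name and type are hypothetical; the return value
# is the generated score module loaded directly from its file path):
#
#   score = import_score_file("my-deployment", mii.constants.DeploymentType.LOCAL)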


dtype_proto_field = {
    str: "svalue",
    int: "ivalue",
    float: "fvalue",
    bool: "bvalue",
}


def kwarg_dict_to_proto(kwarg_dict):
    def get_proto_value(value):
        proto_value = mii.grpc_related.proto.legacymodelresponse_pb2.Value()
        setattr(proto_value, dtype_proto_field[type(value)], value)
        return proto_value

    return {k: get_proto_value(v) for k, v in kwarg_dict.items()}


def unpack_proto_query_kwargs(query_kwargs):
    query_kwargs = {
        k: getattr(v, v.WhichOneof("oneof_values"))
        for k, v in query_kwargs.items()
    }
    return query_kwargs
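
# Round-trip sketch for the two helpers above: kwargs are packed into proto
# `Value` messages for the gRPC request and unpacked back into plain Python
# types on the other side. The kwarg names are assumed generation parameters,
# not a fixed MII schema:
#
#   proto_kwargs = kwarg_dict_to_proto({"max_new_tokens": 64, "do_sample": True})
#   assert unpack_proto_query_kwargs(proto_kwargs) == {"max_new_tokens": 64, "do_sample": True}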


def extract_query_dict(task, request_dict):
    required_keys = REQUIRED_KEYS_PER_TASK[task]
    query_dict = {}
    for key in required_keys:
        value = request_dict.pop(key, None)
        if value is None:
            raise ValueError(f"Request for task: {task} is missing required key: {key}.")
        query_dict[key] = value
    return query_dict
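
# Extraction sketch (the required key names come from REQUIRED_KEYS_PER_TASK;
# "query" is only an assumed example key):
#
#   request = {"query": "DeepSpeed is", "max_new_tokens": 64}
#   query_dict = extract_query_dict(TaskType.TEXT_GENERATION, request)
#   # query_dict now holds the required keys; anything left in `request`
#   # can be forwarded as extra kwargs.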


def get_num_gpus(mii_config):
    num_gpus = mii_config.model_conf.tensor_parallel
    assert (
        torch.cuda.device_count() >= num_gpus
    ), f"Available GPU count: {torch.cuda.device_count()} does not meet the required GPU count: {num_gpus}"
    return num_gpus


def get_provider(model_name, task):
    if model_name == "gpt-neox":
        provider = ModelProvider.ELEUTHER_AI
    elif task in [TaskType.TEXT2IMG, TaskType.INPAINTING]:
        provider = ModelProvider.DIFFUSERS
    else:
        provider = ModelProvider.HUGGING_FACE
    return provider
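
# Routing sketch: "gpt-neox" goes through the EleutherAI path, image tasks go
# through Diffusers, and everything else defaults to Hugging Face. Model ids
# below are illustrative:
#
#   get_provider("gpt-neox", TaskType.TEXT_GENERATION)                  # ModelProvider.ELEUTHER_AI
#   get_provider("stabilityai/stable-diffusion-2", TaskType.TEXT2IMG)   # ModelProvider.DIFFUSERS
#   get_provider("gpt2", TaskType.TEXT_GENERATION)                      # ModelProvider.HUGGING_FACE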