Use Hugging Face pipeline API #1

Open · wants to merge 1 commit into base: master
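For context, this is the Hugging Face `pipeline` API the change builds on. A minimal standalone sketch (the checkpoint below is only an example, not something this PR pins):

```python
# Illustrative use of the transformers pipeline API adopted by this PR.
# The checkpoint name is an example; any text-classification model works.
from transformers import pipeline

classifier = pipeline(
    "text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
print(classifier("KServe makes model serving easy"))
# e.g. [{'label': 'POSITIVE', 'score': 0.99}]
```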
197 changes: 94 additions & 103 deletions python/huggingfaceserver/huggingfaceserver/encoder_model.py
@@ -12,41 +12,90 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from concurrent.futures import ThreadPoolExecutor
from functools import partial
import pathlib
from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

import torch
from accelerate import init_empty_weights
from kserve import Model
from kserve.errors import InferenceError
from kserve.logging import logger
from kserve.model import PredictorConfig
from kserve.protocol.infer_type import InferInput, InferRequest, InferResponse
from kserve.utils.utils import (
from_np_dtype,
get_predict_input,
get_predict_response,
)
import asyncio
from torch import Tensor
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
BatchEncoding,
PreTrainedModel,
PreTrainedTokenizerBase,
PretrainedConfig,
TensorType,
Pipeline,
)

from .task import (
MLTask,
get_pipeline_for_task,
is_generative_task,
get_model_class_for_task,
infer_task_from_model_architecture,
)


class PredictorProxyModel(PreTrainedModel):
"""
This class acts like a Hugging Face PreTrainedModel, but its forward pass
delegates the request to the predictor server for inference.
"""

config: PretrainedConfig
_modules = {}
_parameters = {}
_buffers = {}
loop: asyncio.AbstractEventLoop

def __init__(
self,
config: PretrainedConfig,
predict: Callable,
model_name: str,
input_names: Optional[str] = None,
):
self.config = config
self.predict = predict
self.model_name = model_name
self.input_names = input_names

def __call__(self, **model_inputs):
"""
Run inference. We will send an inference request to the predictor to do
the actual work.
"""
infer_inputs = []
for key, input_tensor in model_inputs.items():
# Send only specific inputs if they have been provided, otherwise send everything.
if not self.input_names or key in self.input_names:
infer_input = InferInput(
name=key,
datatype=from_np_dtype(input_tensor.numpy().dtype),
shape=list(input_tensor.shape),
data=input_tensor.numpy(),
)
infer_inputs.append(infer_input)
infer_request = InferRequest(
infer_inputs=infer_inputs, model_name=self.model_name
)
# Since predict is async it returns a coroutine. This needs to be run in an event loop.
res = asyncio.run_coroutine_threadsafe(self.predict(infer_request), self.loop)
res = res.result()
return {out.name: torch.Tensor(out.data).view(out.shape) for out in res.outputs}


class HuggingfaceEncoderModel(Model): # pylint:disable=c-extension-no-member
task: MLTask
model_config: PretrainedConfig
@@ -60,9 +109,12 @@ class HuggingfaceEncoderModel(Model): # pylint:disable=c-extension-no-member
tokenizer_revision: Optional[str]
trust_remote_code: bool
ready: bool = False
pipeline: Pipeline
max_threadpool_workers: Optional[int] = None
_tokenizer: PreTrainedTokenizerBase
_model: Optional[PreTrainedModel] = None
_device: torch.device
_executor: ThreadPoolExecutor

def __init__(
self,
@@ -79,6 +131,7 @@ def __init__(
model_revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
trust_remote_code: bool = False,
max_threadpool_workers: Optional[int] = None,
predictor_config: Optional[PredictorConfig] = None,
):
super().__init__(model_name, predictor_config)
@@ -93,6 +146,7 @@ def __init__(
self.model_revision = model_revision
self.tokenizer_revision = tokenizer_revision
self.trust_remote_code = trust_remote_code
self.max_threadpool_workers = max_threadpool_workers

if model_config:
self.model_config = model_config
@@ -154,6 +208,9 @@ def load(self) -> bool:

# load huggingface model using from_pretrained for inference mode
if not self.predictor_host:
# If loading the model locally we want a single threadpool worker as we only
# want to run a single forward pass at a time.
max_threadpool_workers = self.max_threadpool_workers or 1
model_cls = get_model_class_for_task(self.task)
self._model = model_cls.from_pretrained(
model_id_or_path,
@@ -176,108 +233,42 @@
logger.info(
f"Successfully loaded huggingface model from path {model_id_or_path}"
)
self.ready = True
return self.ready

def preprocess(
self,
payload: Union[Dict, InferRequest],
context: Dict[str, Any],
) -> Union[BatchEncoding, InferRequest]:
instances = get_predict_input(payload)
# Serialize to tensor
if self.predictor_host:
inputs = self._tokenizer(
instances,
max_length=self.max_length,
add_special_tokens=self.add_special_tokens,
return_tensors=TensorType.NUMPY,
return_token_type_ids=self.return_token_type_ids,
padding=True,
truncation=True,
)
context["payload"] = payload
context["input_ids"] = inputs["input_ids"]
infer_inputs = []
for key, input_tensor in inputs.items():
if (not self.tensor_input_names) or (key in self.tensor_input_names):
infer_input = InferInput(
name=key,
datatype=from_np_dtype(input_tensor.dtype),
shape=list(input_tensor.shape),
data=input_tensor,
)
infer_inputs.append(infer_input)
infer_request = InferRequest(
infer_inputs=infer_inputs, model_name=self.name
)
return infer_request
else:
inputs = self._tokenizer(
instances,
max_length=self.max_length,
add_special_tokens=self.add_special_tokens,
return_tensors=TensorType.PYTORCH,
return_token_type_ids=self.return_token_type_ids,
padding=True,
truncation=True,
# If the model is remote use the default number of threadpool workers as it
# is configured to parallelize IO.
max_threadpool_workers = self.max_threadpool_workers
self._model = PredictorProxyModel(
self.model_config,
model_name=self.name,
predict=super().predict,
input_names=self.tensor_input_names,
)
context["payload"] = payload
context["input_ids"] = inputs["input_ids"]
return inputs
self.pipeline = get_pipeline_for_task(
self.task,
self._model,
self._tokenizer,
)
self._executor = ThreadPoolExecutor(max_workers=max_threadpool_workers)
self.ready = True
return self.ready

async def predict(
self,
input_batch: Union[BatchEncoding, InferRequest],
payload: Dict,
context: Dict[str, Any],
) -> Union[Tensor, InferResponse]:
if self.predictor_host:
# when predictor_host is provided, serialize the tensor and send to optimized model serving runtime
# like NVIDIA triton inference server
return await super().predict(input_batch, context)
else:
input_batch = input_batch.to(self._device)
try:
with torch.no_grad():
outputs = self._model(**input_batch).logits
return outputs
except Exception as e:
raise InferenceError(str(e))
if isinstance(self._model, PredictorProxyModel) and not hasattr(
self._model, "loop"
):
self._model.loop = asyncio.get_running_loop()
# Run the inference in a thread-pool executor. Since the call to `pipeline` is
# blocking this ensures we don't block the event loop.
output = await asyncio.get_running_loop().run_in_executor(
self._executor, partial(self.pipeline, **payload)
)
if self.task == MLTask.token_classification:
for s in output:
for entity in s:
entity["score"] = entity["score"].item()

def postprocess(
self, outputs: Union[Tensor, InferResponse], context: Dict[str, Any]
) -> Union[Dict, InferResponse]:
input_ids = context["input_ids"]
request = context["payload"]
if isinstance(outputs, InferResponse):
shape = torch.Size(outputs.outputs[0].shape)
data = torch.Tensor(outputs.outputs[0].data)
outputs = data.view(shape)
input_ids = torch.Tensor(input_ids)
inferences = []
if self.task == MLTask.sequence_classification:
num_rows, num_cols = outputs.shape
for i in range(num_rows):
out = outputs[i].unsqueeze(0)
predicted_idx = out.argmax().item()
inferences.append(predicted_idx)
return get_predict_response(request, inferences, self.name)
elif self.task == MLTask.fill_mask:
num_rows = outputs.shape[0]
for i in range(num_rows):
mask_pos = (input_ids == self._tokenizer.mask_token_id)[i]
mask_token_index = mask_pos.nonzero(as_tuple=True)[0]
predicted_token_id = outputs[i, mask_token_index].argmax(axis=-1)
inferences.append(self._tokenizer.decode(predicted_token_id))
return get_predict_response(request, inferences, self.name)
elif self.task == MLTask.token_classification:
num_rows = outputs.shape[0]
for i in range(num_rows):
output = outputs[i].unsqueeze(0)
predictions = torch.argmax(output, dim=2)
inferences.append(predictions.tolist())
return get_predict_response(request, inferences, self.name)
else:
raise ValueError(
f"Unsupported task {self.task}. Please check the supported `task` option."
)
return output
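The rewritten `predict` above hands the blocking `pipeline(...)` call to a thread-pool executor so the asyncio event loop stays free. A minimal standalone sketch of that pattern (the executor size, checkpoint, and function names here are illustrative, not this PR's API):

```python
# Sketch: run a blocking transformers pipeline off the event loop,
# mirroring the run_in_executor pattern used in predict() above.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from transformers import pipeline

executor = ThreadPoolExecutor(max_workers=1)  # one worker: one forward pass at a time
pipe = pipeline("fill-mask", model="distilbert-base-uncased")  # example checkpoint

async def predict(text: str):
    loop = asyncio.get_running_loop()
    # partial() defers the blocking call; run_in_executor awaits it in a worker thread.
    return await loop.run_in_executor(executor, partial(pipe, text))

# asyncio.run(predict("Paris is the [MASK] of France."))
```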
19 changes: 19 additions & 0 deletions python/huggingfaceserver/huggingfaceserver/task.py
@@ -26,6 +26,9 @@
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
PretrainedConfig,
Pipeline,
PreTrainedTokenizer,
pipeline,
)


@@ -76,6 +79,14 @@ def _missing_(cls, value: str):
MLTask.multiple_choice: AutoModelForMultipleChoice,
}

TASK_2_PIPELINE = {
MLTask.sequence_classification: "text-classification",
MLTask.question_answering: "question-answering",
MLTask.table_question_answering: "table-question-answering",
MLTask.token_classification: "token-classification",
MLTask.fill_mask: "fill-mask",
}

SUPPORTED_TASKS = {
MLTask.sequence_classification,
MLTask.token_classification,
@@ -114,5 +125,13 @@ def is_generative_task(task: MLTask) -> bool:
}


def get_pipeline_for_task(
task: MLTask, model, tokenizer: PreTrainedTokenizer
) -> Pipeline:
if task not in TASK_2_PIPELINE:
raise ValueError(f"Pipeline not found for task '{task.name}'")
return pipeline(TASK_2_PIPELINE[task], model=model, tokenizer=tokenizer)


def get_model_class_for_task(task: MLTask) -> Type[AutoModel]:
return TASK_2_CLS[task]
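A hedged usage sketch of the new `get_pipeline_for_task` helper, assuming the package layout shown in this diff and a small example checkpoint (neither is mandated by the PR):

```python
# Illustrative only: build a token-classification pipeline via the new helper.
from transformers import AutoModelForTokenClassification, AutoTokenizer

from huggingfaceserver.task import MLTask, get_pipeline_for_task  # assumed import path

model_id = "dslim/bert-base-NER"  # example NER checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForTokenClassification.from_pretrained(model_id)

pipe = get_pipeline_for_task(MLTask.token_classification, model, tokenizer)
print(pipe("KServe runs on Kubernetes"))
```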