From d76632cf3f7ab72d80cd58e018a2af0edc33a7db Mon Sep 17 00:00:00 2001
From: Curtis Maddalozzo
Date: Fri, 17 May 2024 14:42:08 -0400
Subject: [PATCH] Use Hugging Face pipeline API

Signed-off-by: Curtis Maddalozzo
---
 .../huggingfaceserver/encoder_model.py       | 197 +++++++++---------
 .../huggingfaceserver/task.py                |  19 ++
 2 files changed, 113 insertions(+), 103 deletions(-)

diff --git a/python/huggingfaceserver/huggingfaceserver/encoder_model.py b/python/huggingfaceserver/huggingfaceserver/encoder_model.py
index b4965b1506..7ad4d050bc 100644
--- a/python/huggingfaceserver/huggingfaceserver/encoder_model.py
+++ b/python/huggingfaceserver/huggingfaceserver/encoder_model.py
@@ -12,41 +12,90 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 import pathlib
-from typing import Any, Dict, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from accelerate import init_empty_weights
 from kserve import Model
-from kserve.errors import InferenceError
 from kserve.logging import logger
 from kserve.model import PredictorConfig
 from kserve.protocol.infer_type import InferInput, InferRequest, InferResponse
 from kserve.utils.utils import (
     from_np_dtype,
-    get_predict_input,
-    get_predict_response,
 )
 from torch import Tensor
 from transformers import (
     AutoConfig,
     AutoModel,
     AutoTokenizer,
-    BatchEncoding,
     PreTrainedModel,
     PreTrainedTokenizerBase,
     PretrainedConfig,
-    TensorType,
+    Pipeline,
 )
 
 from .task import (
     MLTask,
+    get_pipeline_for_task,
     is_generative_task,
     get_model_class_for_task,
     infer_task_from_model_architecture,
 )
 
 
+class PredictorProxyModel(PreTrainedModel):
+    """
+    Acts like a Hugging Face PreTrainedModel, but its forward pass sends the
+    request to the predictor server, which performs the actual inference.
+    """
+
+    config: PretrainedConfig
+    _modules = {}
+    _parameters = {}
+    _buffers = {}
+    loop: asyncio.AbstractEventLoop
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        predict: Callable,
+        model_name: str,
+        input_names: Optional[List[str]] = None,
+    ):
+        self.config = config
+        self.predict = predict
+        self.model_name = model_name
+        self.input_names = input_names
+
+    def __call__(self, **model_inputs):
+        """
+        Run inference by sending an inference request to the predictor, which
+        does the actual work.
+        """
+        infer_inputs = []
+        for key, input_tensor in model_inputs.items():
+            # Send only the specified inputs if any were provided; otherwise send everything.
+            if not self.input_names or key in self.input_names:
+                infer_input = InferInput(
+                    name=key,
+                    datatype=from_np_dtype(input_tensor.numpy().dtype),
+                    shape=list(input_tensor.shape),
+                    data=input_tensor.numpy(),
+                )
+                infer_inputs.append(infer_input)
+        infer_request = InferRequest(
+            infer_inputs=infer_inputs, model_name=self.model_name
+        )
+        # Since predict is async it returns a coroutine, which must be run on an event loop.
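+        # Because __call__ executes on a ThreadPoolExecutor worker thread rather
+        # than on the event loop thread, run_coroutine_threadsafe is used to
+        # schedule the coroutine on self.loop; it returns a
+        # concurrent.futures.Future whose result() blocks until the response arrives.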
+        res = asyncio.run_coroutine_threadsafe(self.predict(infer_request), self.loop)
+        res = res.result()
+        return {out.name: torch.Tensor(out.data).view(out.shape) for out in res.outputs}
+
+
 class HuggingfaceEncoderModel(Model):  # pylint:disable=c-extension-no-member
     task: MLTask
     model_config: PretrainedConfig
@@ -60,9 +109,12 @@ class HuggingfaceEncoderModel(Model):  # pylint:disable=c-extension-no-member
     tokenizer_revision: Optional[str]
     trust_remote_code: bool
     ready: bool = False
+    pipeline: Pipeline
+    max_threadpool_workers: Optional[int] = None
     _tokenizer: PreTrainedTokenizerBase
     _model: Optional[PreTrainedModel] = None
     _device: torch.device
+    _executor: ThreadPoolExecutor
 
     def __init__(
         self,
@@ -79,6 +131,7 @@ def __init__(
         model_revision: Optional[str] = None,
         tokenizer_revision: Optional[str] = None,
         trust_remote_code: bool = False,
+        max_threadpool_workers: Optional[int] = None,
         predictor_config: Optional[PredictorConfig] = None,
     ):
         super().__init__(model_name, predictor_config)
@@ -93,6 +146,7 @@ def __init__(
         self.model_revision = model_revision
         self.tokenizer_revision = tokenizer_revision
         self.trust_remote_code = trust_remote_code
+        self.max_threadpool_workers = max_threadpool_workers
 
         if model_config:
             self.model_config = model_config
@@ -154,6 +208,9 @@ def load(self) -> bool:
 
         # load huggingface model using from_pretrained for inference mode
         if not self.predictor_host:
+            # If loading the model locally, we want a single threadpool worker
+            # since we only run one forward pass at a time.
+            max_threadpool_workers = self.max_threadpool_workers or 1
             model_cls = get_model_class_for_task(self.task)
             self._model = model_cls.from_pretrained(
                 model_id_or_path,
@@ -176,108 +233,42 @@ def load(self) -> bool:
             logger.info(
                 f"Successfully loaded huggingface model from path {model_id_or_path}"
             )
-        self.ready = True
-        return self.ready
-
-    def preprocess(
-        self,
-        payload: Union[Dict, InferRequest],
-        context: Dict[str, Any],
-    ) -> Union[BatchEncoding, InferRequest]:
-        instances = get_predict_input(payload)
-        # Serialize to tensor
-        if self.predictor_host:
-            inputs = self._tokenizer(
-                instances,
-                max_length=self.max_length,
-                add_special_tokens=self.add_special_tokens,
-                return_tensors=TensorType.NUMPY,
-                return_token_type_ids=self.return_token_type_ids,
-                padding=True,
-                truncation=True,
-            )
-            context["payload"] = payload
-            context["input_ids"] = inputs["input_ids"]
-            infer_inputs = []
-            for key, input_tensor in inputs.items():
-                if (not self.tensor_input_names) or (key in self.tensor_input_names):
-                    infer_input = InferInput(
-                        name=key,
-                        datatype=from_np_dtype(input_tensor.dtype),
-                        shape=list(input_tensor.shape),
-                        data=input_tensor,
-                    )
-                    infer_inputs.append(infer_input)
-            infer_request = InferRequest(
-                infer_inputs=infer_inputs, model_name=self.name
-            )
-            return infer_request
         else:
-            inputs = self._tokenizer(
-                instances,
-                max_length=self.max_length,
-                add_special_tokens=self.add_special_tokens,
-                return_tensors=TensorType.PYTORCH,
-                return_token_type_ids=self.return_token_type_ids,
-                padding=True,
-                truncation=True,
+            # If the model is remote, use the configured number of threadpool
+            # workers, since the executor threads only parallelize IO.
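+            # Leaving this unset (None) makes ThreadPoolExecutor fall back to
+            # its own default of min(32, os.cpu_count() + 4) workers (Python 3.8+).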
+            max_threadpool_workers = self.max_threadpool_workers
+            self._model = PredictorProxyModel(
+                self.model_config,
+                model_name=self.name,
+                predict=super().predict,
+                input_names=self.tensor_input_names,
             )
-            context["payload"] = payload
-            context["input_ids"] = inputs["input_ids"]
-            return inputs
+        self.pipeline = get_pipeline_for_task(
+            self.task,
+            self._model,
+            self._tokenizer,
+        )
+        self._executor = ThreadPoolExecutor(max_workers=max_threadpool_workers)
+        self.ready = True
+        return self.ready
 
     async def predict(
         self,
-        input_batch: Union[BatchEncoding, InferRequest],
+        payload: Dict,
         context: Dict[str, Any],
     ) -> Union[Tensor, InferResponse]:
-        if self.predictor_host:
-            # when predictor_host is provided, serialize the tensor and send to optimized model serving runtime
-            # like NVIDIA triton inference server
-            return await super().predict(input_batch, context)
-        else:
-            input_batch = input_batch.to(self._device)
-            try:
-                with torch.no_grad():
-                    outputs = self._model(**input_batch).logits
-                    return outputs
-            except Exception as e:
-                raise InferenceError(str(e))
+        if isinstance(self._model, PredictorProxyModel) and not hasattr(
+            self._model, "loop"
+        ):
+            self._model.loop = asyncio.get_running_loop()
+        # Run the inference in a threadpool executor. Since the call to `pipeline`
+        # is blocking, this keeps the event loop free.
+        output = await asyncio.get_running_loop().run_in_executor(
+            self._executor, partial(self.pipeline, **payload)
+        )
+        if self.task == MLTask.token_classification:
+            # Scores come back as numpy floats; convert them so the response
+            # is JSON serializable.
+            for s in output:
+                for entity in s:
+                    entity["score"] = entity["score"].item()
 
-    def postprocess(
-        self, outputs: Union[Tensor, InferResponse], context: Dict[str, Any]
-    ) -> Union[Dict, InferResponse]:
-        input_ids = context["input_ids"]
-        request = context["payload"]
-        if isinstance(outputs, InferResponse):
-            shape = torch.Size(outputs.outputs[0].shape)
-            data = torch.Tensor(outputs.outputs[0].data)
-            outputs = data.view(shape)
-            input_ids = torch.Tensor(input_ids)
-        inferences = []
-        if self.task == MLTask.sequence_classification:
-            num_rows, num_cols = outputs.shape
-            for i in range(num_rows):
-                out = outputs[i].unsqueeze(0)
-                predicted_idx = out.argmax().item()
-                inferences.append(predicted_idx)
-            return get_predict_response(request, inferences, self.name)
-        elif self.task == MLTask.fill_mask:
-            num_rows = outputs.shape[0]
-            for i in range(num_rows):
-                mask_pos = (input_ids == self._tokenizer.mask_token_id)[i]
-                mask_token_index = mask_pos.nonzero(as_tuple=True)[0]
-                predicted_token_id = outputs[i, mask_token_index].argmax(axis=-1)
-                inferences.append(self._tokenizer.decode(predicted_token_id))
-            return get_predict_response(request, inferences, self.name)
-        elif self.task == MLTask.token_classification:
-            num_rows = outputs.shape[0]
-            for i in range(num_rows):
-                output = outputs[i].unsqueeze(0)
-                predictions = torch.argmax(output, dim=2)
-                inferences.append(predictions.tolist())
-            return get_predict_response(request, inferences, self.name)
-        else:
-            raise ValueError(
-                f"Unsupported task {self.task}. Please check the supported `task` option."
-            )
+        return output
diff --git a/python/huggingfaceserver/huggingfaceserver/task.py b/python/huggingfaceserver/huggingfaceserver/task.py
index 96213488f6..943c553e58 100644
--- a/python/huggingfaceserver/huggingfaceserver/task.py
+++ b/python/huggingfaceserver/huggingfaceserver/task.py
@@ -26,6 +26,9 @@
     AutoModelForTableQuestionAnswering,
     AutoModelForTokenClassification,
     PretrainedConfig,
+    Pipeline,
+    PreTrainedTokenizer,
+    pipeline,
 )
 
 
@@ -76,6 +79,14 @@ def _missing_(cls, value: str):
     MLTask.multiple_choice: AutoModelForMultipleChoice,
 }
 
+TASK_2_PIPELINE = {
+    MLTask.sequence_classification: "text-classification",
+    MLTask.question_answering: "question-answering",
+    MLTask.table_question_answering: "table-question-answering",
+    MLTask.token_classification: "token-classification",
+    MLTask.fill_mask: "fill-mask",
+}
+
 SUPPORTED_TASKS = {
     MLTask.sequence_classification,
     MLTask.token_classification,
@@ -114,5 +125,13 @@ def is_generative_task(task: MLTask) -> bool:
 }
 
 
+def get_pipeline_for_task(
+    task: MLTask, model, tokenizer: PreTrainedTokenizer
+) -> Pipeline:
+    if task not in TASK_2_PIPELINE:
+        raise ValueError(f"Pipeline not found for task '{task.name}'")
+    return pipeline(TASK_2_PIPELINE[task], model=model, tokenizer=tokenizer)
+
+
 def get_model_class_for_task(task: MLTask) -> Type[AutoModel]:
     return TASK_2_CLS[task]
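
Note for reviewers: below is a minimal, self-contained sketch of the execution pattern the new predict() relies on, for trying the idea standalone. It is not part of the patch; the model id, payload shape, and function names are illustrative only, not the server's API.

import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from transformers import pipeline

# One worker mirrors the local-model path: a single forward pass at a time.
executor = ThreadPoolExecutor(max_workers=1)
fill_mask = pipeline("fill-mask", model="distilbert-base-uncased")

async def predict(payload: dict):
    # The pipeline call is blocking, so run it on the executor to keep the
    # event loop free, the same way HuggingfaceEncoderModel.predict does.
    return await asyncio.get_running_loop().run_in_executor(
        executor, partial(fill_mask, **payload)
    )

async def main():
    outputs = await predict({"inputs": "Paris is the [MASK] of France."})
    print(outputs[0]["token_str"])  # highest-scoring fill, e.g. "capital"

if __name__ == "__main__":
    asyncio.run(main())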