-
Notifications
You must be signed in to change notification settings - Fork 129
/
api.py
218 lines (175 loc) · 6.98 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
from abc import ABC, abstractmethod
from typing import Callable, Optional

from pydantic import BaseModel

from litserve.specs.base import LitSpec
def no_batch_unbatch_message_no_stream(obj, data):
    """Build the help text raised when the default batch()/unbatch() pair
    cannot handle ``data`` and streaming is disabled.

    The message names the offending type and the user's class, and shows a
    non-streaming example implementation (``unbatch`` returns a list).
    """
    offending_type = type(data)
    owner_class = obj.__class__.__name__
    template = """
    You set `max_batch_size > 1`, but the default implementation for batch() and unbatch() only supports
    PyTorch tensors or NumPy ndarrays, while we found {found}.
    Please implement these two methods in {owner}.
    Example:
    def batch(self, inputs):
        return np.stack(inputs)
    def unbatch(self, output):
        return list(output)
    """
    return template.format(found=offending_type, owner=owner_class)
def no_batch_unbatch_message_stream(obj, data):
    """Build the help text raised when the default batch()/unbatch() pair
    cannot handle ``data`` and streaming is enabled.

    The message names the offending type and the user's class, and shows a
    streaming example implementation (``unbatch`` yields per-item output).
    """
    offending_type = type(data)
    owner_class = obj.__class__.__name__
    template = """
    You set `max_batch_size > 1`, but the default implementation for batch() and unbatch() only supports
    PyTorch tensors or NumPy ndarrays, while we found {found}.
    Please implement these two methods in {owner}.
    Example:
    def batch(self, inputs):
        return np.stack(inputs)
    def unbatch(self, output):
        for out in output:
            yield list(out)
    """
    return template.format(found=offending_type, owner=owner_class)
class LitAPI(ABC):
    """Describes how a model is served by LitServe.

    Subclasses implement `setup` and `predict`; the remaining hooks
    (`decode_request`, `batch`, `unbatch`, `encode_response`) have sensible
    defaults that the server may override via an attached `LitSpec`.
    """

    # Whether responses are streamed; toggled via the `stream` property.
    _stream: bool = False
    # Default unbatch implementation; assigned in `_sanitize` depending on
    # whether streaming is enabled.
    _default_unbatch: Optional[Callable] = None
    # Optional spec that takes over request decoding / response encoding.
    _spec: Optional[LitSpec] = None
    # Device string assigned by the server (e.g. "cpu", "cuda:0").
    _device: Optional[str] = None
    # Per-request timeout in seconds; None means no timeout.
    request_timeout: Optional[float] = None

    @abstractmethod
    def setup(self, devices):
        """Setup the model so it can be called in `predict`."""

    def decode_request(self, request, **kwargs):
        """Convert the request payload to your model input.

        Delegates to the attached spec when one is set; otherwise the raw
        request is passed through unchanged.
        """
        if self._spec:
            return self._spec.decode_request(request, **kwargs)
        return request

    def batch(self, inputs):
        """Convert a list of inputs to a batched input.

        The default implementation stacks along a new first dimension and
        only supports PyTorch tensors and NumPy ndarrays; anything else
        raises `NotImplementedError` with guidance for the user.
        """
        # consider assigning an implementation when starting server
        # to avoid the runtime cost of checking (should be negligible)
        if hasattr(inputs[0], "__torch_function__"):
            import torch

            return torch.stack(inputs)
        if inputs[0].__class__.__name__ == "ndarray":
            import numpy

            return numpy.stack(inputs)

        # Unsupported element type: pick the help message matching the mode.
        if self.stream:
            message = no_batch_unbatch_message_stream(self, inputs)
        else:
            message = no_batch_unbatch_message_no_stream(self, inputs)
        raise NotImplementedError(message)

    @abstractmethod
    def predict(self, x, **kwargs):
        """Run the model on the input and return or yield the output."""

    def _unbatch_no_stream(self, output):
        # Non-streaming default: split the batched tensor/ndarray along its
        # first dimension.
        if hasattr(output, "__torch_function__") or output.__class__.__name__ == "ndarray":
            return list(output)
        message = no_batch_unbatch_message_no_stream(self, output)
        raise NotImplementedError(message)

    def _unbatch_stream(self, output_stream):
        # Streaming default: split each yielded batch along its first
        # dimension.
        for output in output_stream:
            if hasattr(output, "__torch_function__") or output.__class__.__name__ == "ndarray":
                yield list(output)
            else:
                # Fix: this path only runs when streaming, so raise with the
                # streaming help text (its example `unbatch` uses `yield`),
                # not the non-streaming one as before.
                message = no_batch_unbatch_message_stream(self, output)
                raise NotImplementedError(message)

    def unbatch(self, output):
        """Convert a batched output to a list of outputs."""
        return self._default_unbatch(output)

    def encode_response(self, output, **kwargs):
        """Convert the model output to a response payload.

        To enable streaming, it should yield the output.
        """
        if self._spec:
            return self._spec.encode_response(output, **kwargs)
        return output

    def format_encoded_response(self, data):
        """Serialize an encoded response item to a newline-terminated string.

        Dicts are JSON-encoded and pydantic models dumped via
        `model_dump_json`; any other value is returned unchanged.
        """
        if isinstance(data, dict):
            return json.dumps(data) + "\n"
        if isinstance(data, BaseModel):
            return data.model_dump_json() + "\n"
        return data

    @property
    def stream(self):
        return self._stream

    @stream.setter
    def stream(self, value):
        self._stream = value

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, value):
        self._device = value

    def _sanitize(self, max_batch_size: int, spec: Optional[LitSpec]):
        """Validate the user's hook implementations for the configured mode
        and wire up the default unbatch strategy.

        Raises:
            ValueError: if streaming is enabled but the required hooks are
                not generator functions.
        """
        if self.stream:
            self._default_unbatch = self._unbatch_stream
        else:
            self._default_unbatch = self._unbatch_no_stream

        # we will sanitize regularly if no spec
        # in case, we have spec then:
        # case 1: spec implements a streaming API
        # Case 2: spec implements a non-streaming API
        if spec:
            # TODO: Implement sanitization
            self._spec = spec
            return

        # True when the user did not override `unbatch` (default is fine
        # because `_unbatch_stream` is already a generator).
        original = self.unbatch.__code__ is LitAPI.unbatch.__code__
        if (
            self.stream
            and max_batch_size > 1
            and not all([
                inspect.isgeneratorfunction(self.predict),
                inspect.isgeneratorfunction(self.encode_response),
                (original or inspect.isgeneratorfunction(self.unbatch)),
            ])
        ):
            raise ValueError(
                """When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and
                `lit_api.unbatch` must generate values using `yield`.

             Example:

                def predict(self, inputs):
                    ...
                    for i in range(max_token_length):
                        yield prediction

                def encode_response(self, outputs):
                    for output in outputs:
                        encoded_output = ...
                        yield encoded_output

                def unbatch(self, outputs):
                    for output in outputs:
                        unbatched_output = ...
                        yield unbatched_output
                """
            )

        if self.stream and not all([
            inspect.isgeneratorfunction(self.predict),
            inspect.isgeneratorfunction(self.encode_response),
        ]):
            raise ValueError(
                """When `stream=True` both `lit_api.predict` and
             `lit_api.encode_response` must generate values using `yield`.

             Example:

                def predict(self, inputs):
                    ...
                    for i in range(max_token_length):
                        yield prediction

                def encode_response(self, outputs):
                    for output in outputs:
                        encoded_output = ...
                        yield encoded_output
             """
            )