OpenAI v1 Completions API (#170)
tgaddair authored Jan 9, 2024
1 parent b14207c commit 8e479ce
Showing 3 changed files with 288 additions and 6 deletions.
2 changes: 1 addition & 1 deletion router/src/infer.rs
@@ -118,7 +118,7 @@ impl Infer {
         })?;

         let mut adapter_id = request.parameters.adapter_id.clone();
-        if adapter_id.is_none() {
+        if adapter_id.is_none() || adapter_id.as_ref().unwrap().is_empty() {
             adapter_id = Some(BASE_MODEL_ADAPTER_ID.to_string());
         }
         let mut adapter_source = request.parameters.adapter_source.clone();
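This widened guard is likely motivated by the new OpenAI route below, where `model` is copied into `adapter_id` and may arrive as an empty string (`Some("")` rather than `None`). A minimal standalone sketch of the fallback behavior; the constant's value here is an assumption for illustration, the real constant is defined elsewhere in the router crate:

```rust
// Assumed value for illustration only.
const BASE_MODEL_ADAPTER_ID: &str = "__base_model__";

fn resolve_adapter_id(requested: Option<String>) -> String {
    let mut adapter_id = requested;
    // Both None and Some("") now fall back to the base model.
    if adapter_id.is_none() || adapter_id.as_ref().unwrap().is_empty() {
        adapter_id = Some(BASE_MODEL_ADAPTER_ID.to_string());
    }
    adapter_id.unwrap()
}

fn main() {
    assert_eq!(resolve_adapter_id(None), BASE_MODEL_ADAPTER_ID);
    assert_eq!(resolve_adapter_id(Some(String::new())), BASE_MODEL_ADAPTER_ID);
    assert_eq!(resolve_adapter_id(Some("my-adapter".into())), "my-adapter");
}
```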
210 changes: 210 additions & 0 deletions router/src/lib.rs
@@ -308,6 +308,216 @@ pub(crate) struct ErrorResponse {
    pub error_type: String,
}

// OpenAI compatible structs

#[derive(Serialize, ToSchema)]
struct UsageInfo {
    prompt_tokens: u32,
    total_tokens: u32,
    completion_tokens: Option<u32>,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct ChatCompletionRequest {
    model: String,
    messages: Vec<String>,
    temperature: Option<f32>,
    top_p: Option<f32>,
    n: Option<i32>,
    max_tokens: Option<i32>,
    #[serde(default)]
    stop: Vec<String>,
    stream: Option<bool>,
    presence_penalty: Option<f32>,
    frequency_penalty: Option<f32>,
    logit_bias: Option<std::collections::HashMap<String, f32>>,
    user: Option<String>,
    // Additional parameters
    // TODO(travis): add other LoRAX params here
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct CompletionRequest {
    model: String,
    prompt: String,
    suffix: Option<String>,
    max_tokens: Option<i32>,
    temperature: Option<f32>,
    top_p: Option<f32>,
    n: Option<i32>,
    stream: Option<bool>,
    logprobs: Option<i32>,
    echo: Option<bool>,
    #[serde(default)]
    stop: Vec<String>,
    presence_penalty: Option<f32>,
    frequency_penalty: Option<f32>,
    best_of: Option<i32>,
    logit_bias: Option<std::collections::HashMap<String, f32>>,
    user: Option<String>,
    // Additional parameters
    // TODO(travis): add other LoRAX params here
}
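Every sampling field above is `Option`al and `stop` carries `#[serde(default)]`, so a minimal OpenAI-style body deserializes without them. A small self-contained sketch of that behavior, using a trimmed copy of the struct and assuming `serde`/`serde_json` as dependencies:

```rust
use serde::Deserialize;

// Trimmed copy of CompletionRequest above, just to keep the example self-contained.
#[derive(Debug, Deserialize)]
struct CompletionRequest {
    model: String,
    prompt: String,
    max_tokens: Option<i32>,
    stream: Option<bool>,
    #[serde(default)]
    stop: Vec<String>,
}

fn main() {
    let body = r#"{"model": "my-adapter", "prompt": "Hello", "max_tokens": 16}"#;
    let req: CompletionRequest = serde_json::from_str(body).unwrap();
    assert_eq!(req.max_tokens, Some(16));
    assert_eq!(req.stream, None); // treated as non-streaming by the handler
    assert!(req.stop.is_empty()); // #[serde(default)] supplies the empty Vec
}
```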

#[derive(Serialize, ToSchema)]
struct LogProbs {
    text_offset: Vec<i32>,
    token_logprobs: Vec<Option<f32>>,
    tokens: Vec<String>,
    top_logprobs: Option<Vec<Option<std::collections::HashMap<i32, f32>>>>,
}

#[derive(Serialize, ToSchema)]
struct CompletionResponseChoice {
    index: i32,
    text: String,
    logprobs: Option<LogProbs>,
    finish_reason: Option<String>, // Literal replaced with String
}

#[derive(Serialize, ToSchema)]
struct CompletionResponse {
    id: String,
    object: String,
    created: i64,
    model: String,
    choices: Vec<CompletionResponseChoice>,
    usage: UsageInfo,
}

#[derive(Serialize, ToSchema)]
struct CompletionResponseStreamChoice {
    index: i32,
    text: String,
    logprobs: Option<LogProbs>,
    finish_reason: Option<String>, // Literal replaced with String
}

#[derive(Serialize, ToSchema)]
struct CompletionStreamResponse {
    id: String,
    object: String,
    created: i64,
    model: String,
    choices: Vec<CompletionResponseStreamChoice>,
    usage: Option<UsageInfo>,
}

#[derive(Serialize, ToSchema)]
struct ChatMessage {
    role: String,
    content: String,
}

#[derive(Serialize, ToSchema)]
struct ChatCompletionResponseChoice {
    index: i32,
    message: ChatMessage,
    finish_reason: Option<String>, // Literal replaced with String
}

#[derive(Serialize, ToSchema)]
struct ChatCompletionResponse {
    id: String,
    object: String,
    created: i64,
    model: String,
    choices: Vec<ChatCompletionResponseChoice>,
    usage: UsageInfo,
}
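The chat structs are scaffolding only at this point: the route table at the end of this diff registers `/v1/completions` but no `/v1/chat/completions`. As a crate-internal sketch of the OpenAI-compatible JSON these response types serialize to (field values are illustrative):

```rust
#[test]
fn chat_completion_response_shape() {
    let resp = ChatCompletionResponse {
        id: "null".to_string(),
        object: "chat.completion".to_string(),
        created: 0,
        model: "null".to_string(),
        choices: vec![ChatCompletionResponseChoice {
            index: 0,
            message: ChatMessage {
                role: "assistant".to_string(),
                content: "Hello!".to_string(),
            },
            finish_reason: Some("stop".to_string()),
        }],
        usage: UsageInfo {
            prompt_tokens: 5,
            total_tokens: 12,
            completion_tokens: Some(7),
        },
    };
    let json = serde_json::to_value(&resp).unwrap();
    assert_eq!(json["choices"][0]["message"]["content"], "Hello!");
    assert_eq!(json["usage"]["total_tokens"], 12);
}
```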

impl From<CompletionRequest> for CompatGenerateRequest {
    fn from(req: CompletionRequest) -> Self {
        CompatGenerateRequest {
            inputs: req.prompt,
            parameters: GenerateParameters {
                adapter_id: req.model.parse().ok(),
                adapter_source: None,
                api_token: None,
                best_of: req.best_of.map(|x| x as usize),
                temperature: req.temperature,
                repetition_penalty: None,
                top_k: None,
                top_p: req.top_p,
                typical_p: None,
                do_sample: !req.n.is_none(),
                max_new_tokens: req
                    .max_tokens
                    .map(|x| x as u32)
                    .unwrap_or(default_max_new_tokens()),
                return_full_text: req.echo,
                stop: req.stop,
                truncate: None,
                watermark: false,
                details: true,
                decoder_input_details: req.logprobs.is_some(),
                seed: None,
            },
            stream: req.stream.unwrap_or(false),
        }
    }
}
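This `From` impl is where OpenAI semantics map onto LoRAX's native parameters: `model` becomes `adapter_id` (`parse()` into `String` is infallible, so this is effectively a copy), `echo` becomes `return_full_text`, `logprobs` toggles `decoder_input_details`, and `do_sample` turns on whenever `n` is supplied. A crate-internal usage sketch:

```rust
#[test]
fn completion_request_maps_to_compat_generate() {
    let req = CompletionRequest {
        model: "my-adapter".to_string(),
        prompt: "Hello".to_string(),
        suffix: None,
        max_tokens: Some(64),
        temperature: Some(0.7),
        top_p: None,
        n: None,
        stream: Some(true),
        logprobs: None,
        echo: Some(false),
        stop: vec!["\n".to_string()],
        presence_penalty: None,
        frequency_penalty: None,
        best_of: None,
        logit_bias: None,
        user: None,
    };

    let gen_req = CompatGenerateRequest::from(req);
    assert_eq!(gen_req.parameters.adapter_id.as_deref(), Some("my-adapter"));
    assert_eq!(gen_req.parameters.max_new_tokens, 64);
    assert!(!gen_req.parameters.do_sample); // n was None
    assert!(gen_req.stream);
}
```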

impl From<GenerateResponse> for CompletionResponse {
    fn from(resp: GenerateResponse) -> Self {
        let prompt_tokens = resp.details.as_ref().map(|x| x.prompt_tokens).unwrap_or(0);
        let completion_tokens = resp
            .details
            .as_ref()
            .map(|x| x.generated_tokens)
            .unwrap_or(0);
        let total_tokens = prompt_tokens + completion_tokens;

        CompletionResponse {
            id: "null".to_string(),
            object: "text_completion".to_string(),
            created: 0,
            model: "null".to_string(),
            choices: vec![CompletionResponseChoice {
                index: 0,
                text: resp.generated_text,
                logprobs: None,
                finish_reason: None,
            }],
            usage: UsageInfo {
                prompt_tokens: prompt_tokens,
                total_tokens: total_tokens,
                completion_tokens: Some(completion_tokens),
            },
        }
    }
}

impl From<StreamResponse> for CompletionStreamResponse {
    fn from(resp: StreamResponse) -> Self {
        let prompt_tokens = resp.details.as_ref().map(|x| x.prompt_tokens).unwrap_or(0);
        let completion_tokens = resp
            .details
            .as_ref()
            .map(|x| x.generated_tokens)
            .unwrap_or(0);
        let total_tokens = prompt_tokens + completion_tokens;

        CompletionStreamResponse {
            id: "null".to_string(),
            object: "text_completion".to_string(),
            created: 0,
            model: "null".to_string(),
            choices: vec![CompletionResponseStreamChoice {
                index: 0,
                text: resp.generated_text.unwrap_or_default(),
                logprobs: None,
                finish_reason: None,
            }],
            usage: Some(UsageInfo {
                prompt_tokens: prompt_tokens,
                total_tokens: total_tokens,
                completion_tokens: Some(completion_tokens),
            }),
        }
    }
}
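Each streamed chunk is turned into a server-sent event by the handler below via `Event::json_data`. Note that in the native streaming protocol `generated_text` is typically `None` until the final message, so intermediate chunks produced by this conversion carry empty `text`. A crate-internal sketch of the framing (values illustrative):

```rust
use axum::response::sse::Event;

// Crate-internal sketch: how one chunk becomes an SSE frame in the handler below.
fn sse_frame_sketch() -> Event {
    let chunk = CompletionStreamResponse {
        id: "null".to_string(),
        object: "text_completion".to_string(),
        created: 0,
        model: "null".to_string(),
        choices: vec![CompletionResponseStreamChoice {
            index: 0,
            text: "Hello".to_string(),
            logprobs: None,
            finish_reason: None,
        }],
        usage: None,
    };
    // json_data serializes the payload into the `data:` field of the event,
    // which the Sse response then writes as `data: {...}\n\n`.
    Event::default().json_data(&chunk).unwrap()
}
```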

#[cfg(test)]
mod tests {
    use std::io::Write;
82 changes: 77 additions & 5 deletions router/src/server.rs
@@ -3,9 +3,10 @@ use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
-    BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
-    StreamDetails, StreamResponse, Token, Validation,
+    BestOfSequence, CompatGenerateRequest, CompletionRequest, CompletionResponse,
+    CompletionStreamResponse, Details, ErrorResponse, FinishReason, GenerateParameters,
+    GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken, StreamDetails,
+    StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -77,6 +78,66 @@ async fn compat_generate(
    }
}

/// Generate tokens if `stream == false` or a stream of tokens if `stream == true`
#[utoipa::path(
    post,
    tag = "LoRAX",
    path = "/v1/completions",
    request_body = CompletionRequest,
    responses(
        (status = 200, description = "Generated Text",
            content(
                ("application/json" = CompletionResponse),
                ("text/event-stream" = CompletionStreamResponse),
            )),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json ! ({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json ! ({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json ! ({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json ! ({"error": "Incomplete generation"})),
    )
)]
#[instrument(skip(infer, req))]
async fn completions_v1(
    default_return_full_text: Extension<bool>,
    infer: Extension<Infer>,
    req: Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let req = req.0;
    let mut gen_req = CompatGenerateRequest::from(req);

    // default return_full_text given the pipeline_tag
    if gen_req.parameters.return_full_text.is_none() {
        gen_req.parameters.return_full_text = Some(default_return_full_text.0)
    }

    // switch on stream
    if gen_req.stream {
        let callback = move |resp: StreamResponse| {
            Event::default()
                .json_data(CompletionStreamResponse::from(resp))
                .map_or_else(
                    |err| {
                        tracing::error!("Failed to serialize CompletionStreamResponse: {err}");
                        Event::default()
                    },
                    |data| data,
                )
        };

        let (headers, stream) =
            generate_stream_with_callback(infer, Json(gen_req.into()), callback).await;
        Ok((headers, Sse::new(stream).keep_alive(KeepAlive::default())).into_response())
    } else {
        let (headers, generation) = generate(infer, Json(gen_req.into())).await?;
        // wrap generation inside a Vec to match api-inference
        Ok((headers, Json(vec![CompletionResponse::from(generation.0)])).into_response())
    }
}
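With the route registered (see the `run` changes at the end of this diff), any HTTP client can exercise the endpoint. A hedged, non-streaming usage sketch using `reqwest`, `tokio`, and `serde_json`; the base URL and adapter name are placeholders, not part of this commit:

```rust
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();
    let resp = client
        .post("http://localhost:8080/v1/completions") // placeholder address
        .json(&json!({
            "model": "my-adapter", // mapped to adapter_id; empty selects the base model
            "prompt": "Say hello",
            "max_tokens": 32,
            "stream": false
        }))
        .send()
        .await?
        .error_for_status()?;

    // The non-streaming branch wraps the response in a Vec to match api-inference.
    let body: serde_json::Value = resp.json().await?;
    println!("{}", body[0]["choices"][0]["text"]);
    Ok(())
}
```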

/// LoRAX endpoint info
#[utoipa::path(
    get,
@@ -351,6 +412,16 @@ async fn generate_stream(
    HeaderMap,
    Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
    let callback = |resp: StreamResponse| Event::default().json_data(resp).unwrap();
    let (headers, stream) = generate_stream_with_callback(infer, req, callback).await;
    (headers, Sse::new(stream).keep_alive(KeepAlive::default()))
}

async fn generate_stream_with_callback(
    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
    callback: impl Fn(StreamResponse) -> Event,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
    let span = tracing::Span::current();
    let start_time = Instant::now();
    metrics::increment_counter!("lorax_request_count");
@@ -479,7 +550,7 @@
                             details
                         };

-                        yield Ok(Event::default().json_data(stream_token).unwrap());
+                        yield Ok(callback(stream_token));
                         break;
                     }
                 }
@@ -510,7 +581,7 @@
         }
     };

-    (headers, Sse::new(stream).keep_alive(KeepAlive::default()))
+    (headers, stream)
 }

/// Prometheus metrics scrape endpoint
@@ -699,6 +770,7 @@ pub async fn run(
         .route("/info", get(get_model_info))
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
+        .route("/v1/completions", post(completions_v1))
         // AWS Sagemaker route
         .route("/invocations", post(compat_generate))
         // Base Health route
