Added prompt_tokens to the response (#165)
tgaddair authored Jan 9, 2024
1 parent 57d5470 commit d88ffed
Showing 14 changed files with 89 additions and 221 deletions.
4 changes: 4 additions & 0 deletions clients/python/README.md
@@ -181,6 +181,8 @@ class BestOfSequence:
class Details:
# Generation finish reason
finish_reason: FinishReason
# Number of prompt tokens
prompt_tokens: int
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
@@ -205,6 +207,8 @@ class Response:
class StreamDetails:
# Generation finish reason
finish_reason: FinishReason
# Number of prompt tokens
prompt_tokens: int
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
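With this change, the Python client's `Details` object reports the prompt length alongside the generated-token count. A minimal sketch of reading it, assuming a LoRAX server is reachable at the local URL below (the endpoint and prompt are illustrative, not part of this commit):

from lorax import Client

client = Client("http://127.0.0.1:8080")  # assumed local LoRAX endpoint
response = client.generate("What is deep learning?", max_new_tokens=32)

# The new field sits next to generated_tokens on response.details
print("prompt tokens:   ", response.details.prompt_tokens)
print("generated tokens:", response.details.generated_tokens)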
4 changes: 4 additions & 0 deletions clients/python/lorax/types.py
@@ -198,6 +198,8 @@ class BestOfSequence(BaseModel):
class Details(BaseModel):
# Generation finish reason
finish_reason: FinishReason
# Number of prompt tokens
prompt_tokens: int
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
@@ -222,6 +224,8 @@ class Response(BaseModel):
class StreamDetails(BaseModel):
# Generation finish reason
finish_reason: FinishReason
# Number of prompt tokens
prompt_tokens: int
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
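The same field is available on `StreamDetails` for streaming calls. A minimal streaming sketch, assuming the client's `generate_stream` yields token messages and attaches `details` only to the final message (endpoint and attribute layout of the stream response are assumptions, not shown in this diff):

from lorax import Client

client = Client("http://127.0.0.1:8080")  # assumed local LoRAX endpoint

text = ""
for response in client.generate_stream("What is deep learning?", max_new_tokens=32):
    if not response.token.special:
        text += response.token.text
    if response.details is not None:
        # StreamDetails now also carries prompt_tokens
        print("prompt tokens:", response.details.prompt_tokens)
print(text)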
51 changes: 0 additions & 51 deletions clients/python/tests/conftest.py
@@ -1,51 +0,0 @@
import pytest

from lorax import __version__
from huggingface_hub.utils import build_hf_headers


@pytest.fixture
def flan_t5_xxl():
return "google/flan-t5-xxl"


@pytest.fixture
def fake_model():
return "fake/model"


@pytest.fixture
def unsupported_model():
return "gpt2"


@pytest.fixture
def base_url():
return "https://api-inference.huggingface.co/models"


@pytest.fixture
def bloom_url(base_url, bloom_model):
return f"{base_url}/{bloom_model}"


@pytest.fixture
def flan_t5_xxl_url(base_url, flan_t5_xxl):
return f"{base_url}/{flan_t5_xxl}"


@pytest.fixture
def fake_url(base_url, fake_model):
return f"{base_url}/{fake_model}"


@pytest.fixture
def unsupported_url(base_url, unsupported_model):
return f"{base_url}/{unsupported_model}"


@pytest.fixture(scope="session")
def hf_headers():
return build_hf_headers(
library_name="lorax-tests", library_version=__version__
)
150 changes: 0 additions & 150 deletions clients/python/tests/test_client.py

This file was deleted.

14 changes: 14 additions & 0 deletions docs/reference/openapi.json
@@ -413,6 +413,7 @@
"type": "object",
"required": [
"finish_reason",
"prompt_tokens",
"generated_tokens",
"prefill",
"tokens"
@@ -428,6 +429,12 @@
"finish_reason": {
"$ref": "#/components/schemas/FinishReason"
},
"prompt_tokens": {
"type": "integer",
"format": "int32",
"example": 1,
"minimum": 0.0
},
"generated_tokens": {
"type": "integer",
"format": "int32",
@@ -773,12 +780,19 @@
"type": "object",
"required": [
"finish_reason",
"prompt_tokens",
"generated_tokens"
],
"properties": {
"finish_reason": {
"$ref": "#/components/schemas/FinishReason"
},
"prompt_tokens": {
"type": "integer",
"format": "int32",
"example": 1,
"minimum": 0.0
},
"generated_tokens": {
"type": "integer",
"format": "int32",
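At the HTTP level, this schema change means `prompt_tokens` appears in the `details` object of a generate response when details are requested. A raw-HTTP sketch, assuming a LoRAX server at the local URL below and a `/generate` route that accepts a `details` parameter (both assumptions for illustration):

import requests

resp = requests.post(
    "http://127.0.0.1:8080/generate",  # assumed local LoRAX endpoint
    json={
        "inputs": "What is deep learning?",
        "parameters": {"max_new_tokens": 32, "details": True},
    },
    timeout=60,
)
resp.raise_for_status()
details = resp.json()["details"]
# prompt_tokens sits next to generated_tokens, as declared in the schema above
print(details["prompt_tokens"], details["generated_tokens"])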
4 changes: 4 additions & 0 deletions docs/reference/python_client.md
@@ -181,6 +181,8 @@ class BestOfSequence:
class Details:
# Generation finish reason
finish_reason: FinishReason
# Number of prompt tokens
prompt_tokens: int
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
@@ -205,6 +207,8 @@ class Response:
class StreamDetails:
# Generation finish reason
finish_reason: FinishReason
# Number of prompt tokens
prompt_tokens: int
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
2 changes: 2 additions & 0 deletions proto/generate.proto
@@ -167,6 +167,8 @@ message Generation {
bool token_is_special = 6;
/// Complete generated text
optional GeneratedText generated_text = 7;
/// Prefill tokens length
uint32 prefill_tokens_length = 8;
}

message FilterBatchRequest {
37 changes: 26 additions & 11 deletions router/src/infer.rs
@@ -189,6 +189,7 @@ impl Infer {
// Return values
let mut result_prefill = Vec::new();
let mut result_tokens = Vec::new();
let mut result_prefill_length = 0;
let mut result_generated_text = None;
let mut result_start = None;
let mut result_queued = None;
@@ -197,16 +198,22 @@
while let Some(response) = stream.next().await {
match response? {
// Add prefill tokens
InferStreamResponse::Prefill(tokens) => {
InferStreamResponse::Prefill {
tokens,
tokens_length,
} => {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
result_prefill = tokens
.ids
.into_iter()
.zip(tokens.logprobs.into_iter())
.zip(tokens.texts.into_iter())
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
if let Some(tokens_val) = tokens {
result_prefill = tokens_val
.ids
.into_iter()
.zip(tokens_val.logprobs.into_iter())
.zip(tokens_val.texts.into_iter())
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
}
result_prefill_length = tokens_length;
}
// Push last token
InferStreamResponse::Token(token) => result_tokens.push(token),
@@ -233,6 +240,7 @@
Ok(InferResponse {
prefill: result_prefill,
tokens: result_tokens,
prompt_tokens: result_prefill_length,
generated_text,
queued,
start,
@@ -569,10 +577,13 @@ fn send_responses(

let mut stopped = false;

if let Some(prefill_tokens) = generation.prefill_tokens {
if generation.prefill_tokens_length > 0 {
// Send message
entry.response_tx.send_timeout(
Ok(InferStreamResponse::Prefill(prefill_tokens)),
Ok(InferStreamResponse::Prefill {
tokens: generation.prefill_tokens,
tokens_length: generation.prefill_tokens_length,
}),
Duration::from_millis(10),
)?;
}
@@ -629,7 +640,10 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
// Optional first message
Prefill(PrefillTokens),
Prefill {
tokens: Option<PrefillTokens>,
tokens_length: u32,
},
// Intermediate messages
Token(Token),
// Last message
@@ -645,6 +659,7 @@ pub(crate) enum InferStreamResponse {
pub(crate) struct InferResponse {
pub(crate) prefill: Vec<PrefillToken>,
pub(crate) tokens: Vec<Token>,
pub(crate) prompt_tokens: u32,
pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant,
pub(crate) start: Instant,
