Adds streaming option to generate #1424

Closed
wants to merge 10 commits
Changes from 9 commits
114 changes: 93 additions & 21 deletions litgpt/deploy/serve.py
@@ -1,4 +1,5 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import json
from pathlib import Path
from typing import Dict, Any, Optional
from litgpt.utils import check_valid_checkpoint_dir
@@ -23,7 +24,7 @@
LitAPI, LitServer = object, object


class SimpleLitAPI(LitAPI):
class BaseLitAPI(LitAPI):
def __init__(self,
checkpoint_dir: Path,
precision: Optional[str] = None,
@@ -81,19 +82,35 @@ def decode_request(self, request: Dict[str, Any]) -> Any:
encoded = self.tokenizer.encode(prompt, device=self.device)
return encoded


class SimpleLitAPI(BaseLitAPI):
def __init__(self,
checkpoint_dir: Path,
precision: Optional[str] = None,
temperature: float = 0.8,
top_k: int = 50,
top_p: float = 1.0,
max_new_tokens: int = 50):
super().__init__(checkpoint_dir, precision, temperature, top_k, top_p, max_new_tokens)

def setup(self, device: str):
super().setup(device)

def predict(self, inputs: torch.Tensor) -> Any:
# Run the model on the input and return the output.
prompt_length = inputs.size(0)
max_returned_tokens = prompt_length + self.max_new_tokens

y = generate(
self.model,
inputs,
max_returned_tokens,
temperature=self.temperature,
top_k=self.top_k,
top_p=self.top_p,
eos_id=self.tokenizer.eos_id
y = next(
generate(
self.model,
inputs,
max_returned_tokens,
temperature=self.temperature,
top_k=self.top_k,
top_p=self.top_p,
eos_id=self.tokenizer.eos_id
)
)

for block in self.model.transformer.h:
@@ -106,6 +123,41 @@ def encode_response(self, output: torch.Tensor) -> Dict[str, Any]:
return {"output": decoded_output}


class StreamLitAPI(BaseLitAPI):
def __init__(self,
checkpoint_dir: Path,
precision: Optional[str] = None,
temperature: float = 0.8,
top_k: int = 50,
top_p: float = 1.0,
max_new_tokens: int = 50):
super().__init__(checkpoint_dir, precision, temperature, top_k, top_p, max_new_tokens)

def setup(self, device: str):
super().setup(device)

def predict(self, inputs: torch.Tensor) -> Any:
# Run the model on the input and return the output.
prompt_length = inputs.size(0)
max_returned_tokens = prompt_length + self.max_new_tokens

for block in self.model.transformer.h:
block.attn.kv_cache.reset_parameters()

yield from generate(
self.model,
inputs,
max_returned_tokens,
temperature=self.temperature,
top_k=self.top_k,
top_p=self.top_p,
eos_id=self.tokenizer.eos_id
)

def encode_response(self, output):
yield {"output": self.tokenizer.decode(next(output))}


def run_server(
checkpoint_dir: Path = Path("checkpoints"),
precision: Optional[str] = None,
@@ -115,7 +167,8 @@ def run_server(
max_new_tokens: int = 50,
devices: int = 1,
accelerator: str = "auto",
port: int = 8000
port: int = 8000,
stream: bool = False
) -> None:
"""Serve a LitGPT model using LitServe

@@ -146,19 +199,38 @@
accelerator: The type of accelerator to use. For example, "auto", "cuda", "cpu", or "mps".
The "auto" setting (default) chooses a GPU if available, and otherwise uses a CPU.
port: The network port number on which the model is configured to be served.
stream: Whether to stream tokens back to the client as they are generated instead of returning the full response at once.
"""
check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth")

server = LitServer(
SimpleLitAPI(
checkpoint_dir=checkpoint_dir,
precision=precision,
temperature=temperature,
top_k=top_k,
top_p=top_p,
max_new_tokens=max_new_tokens,
),
accelerator=accelerator,
devices=devices)
if not stream:

server = LitServer(
SimpleLitAPI(
checkpoint_dir=checkpoint_dir,
precision=precision,
temperature=temperature,
top_k=top_k,
top_p=top_p,
max_new_tokens=max_new_tokens,
),
accelerator=accelerator,
devices=devices
)

else:
server = LitServer(
StreamLitAPI(
checkpoint_dir=checkpoint_dir,
precision=precision,
temperature=temperature,
top_k=top_k,
top_p=top_p,
max_new_tokens=max_new_tokens,
),
accelerator=accelerator,
devices=devices,
stream=True
)

server.run(port=port)
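
For context, here is a minimal sketch of how the new streaming mode might be exercised end to end. The `/predict` route and the `{"prompt": ...}` payload come from the tests in this PR; the line-by-line JSON framing of streamed chunks and the use of the `requests` library are assumptions, so treat the parsing below as illustrative rather than the definitive client API.

```python
import json
from pathlib import Path

import requests

from litgpt.deploy.serve import run_server

# Start the server in streaming mode (in practice, run this in its own process);
# checkpoint_dir must point at a downloaded LitGPT checkpoint.
# run_server(checkpoint_dir=Path("checkpoints"), stream=True, port=8000)

# Consume the streamed response chunk by chunk.
with requests.post(
    "http://127.0.0.1:8000/predict",
    json={"prompt": "Hello world"},
    stream=True,
    timeout=60,
) as response:
    response.raise_for_status()
    for line in response.iter_lines(decode_unicode=True):
        if not line:
            continue
        try:
            chunk = json.loads(line)
        except json.JSONDecodeError:
            continue  # skip framing lines that are not bare JSON chunks
        print(chunk.get("output", ""), end="", flush=True)
    print()
```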
34 changes: 28 additions & 6 deletions litgpt/generate/base.py
@@ -3,7 +3,7 @@
import sys
import time
from pathlib import Path
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Generator

import lightning as L
import torch
@@ -78,7 +78,8 @@ def generate(
top_k: Optional[int] = None,
top_p: float = 1.0,
eos_id: Optional[int] = None,
) -> torch.Tensor:
stream: bool = False,
) -> Generator[torch.Tensor, None, None]:
"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

The implementation of this function is modified from A. Karpathy's nanoGPT.
@@ -104,6 +105,8 @@
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
eos_id: If specified, stop generating any more token once the <eos> token is triggered.
stream: If True, yield each new token as it is generated; if False, yield the full generated sequence once as a single tensor.

"""
T = prompt.size(0)
assert max_returned_tokens > T
@@ -120,6 +123,10 @@
model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p
).clone()
tokens.append(token)

if stream: # yield the first generated token now; without this the stream is one token short (see tests)
yield token

for _ in range(2, max_returned_tokens - T + 1):
token = next_token(
model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p
@@ -128,7 +135,11 @@
if token == eos_id:
break
input_pos = input_pos.add_(1)
return torch.cat(tokens)
if stream:
yield token

if not stream:
yield torch.cat(tokens)


@torch.inference_mode()
@@ -144,6 +155,7 @@ def main(
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
precision: Optional[str] = None,
compile: bool = False,
stream: bool = False,
) -> None:
"""Generates text samples based on a pre-trained model and tokenizer.

@@ -175,6 +187,7 @@
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
precision: Indicates the Fabric precision setting to use.
compile: Whether to compile the model.
stream: If True, print tokens to the console as they are generated instead of printing the full output at the end.
"""
precision = precision or get_default_supported_precision(training=False)

@@ -231,12 +244,21 @@
L.seed_everything(1234)
for i in range(num_samples):
t0 = time.perf_counter()
y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id)
if stream:
tokens_generated = 0
for token in generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id, stream=stream):
tokens_generated += 1
fabric.print(tokenizer.decode(token), end="", flush=True)
fabric.print("")
else:
y_gen = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id, stream=stream)
y = list(y_gen)[0]
fabric.print(tokenizer.decode(y))
tokens_generated = y.size(0) - prompt_length

t = time.perf_counter() - t0
for block in model.transformer.h:
block.attn.kv_cache.reset_parameters()
fabric.print(tokenizer.decode(y))
tokens_generated = y.size(0) - prompt_length
fabric.print(
f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr
)
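
For reference, a minimal sketch of how callers consume the reworked `generate` in both modes; `model`, `encoded`, `max_returned_tokens`, and `tokenizer` are assumed to be set up as in `main` above.

```python
# Non-streaming: the generator yields exactly one tensor holding the full
# sequence (prompt plus generated tokens), so a single next() retrieves it.
y = next(generate(model, encoded, max_returned_tokens, temperature=0.8, top_k=50))
print(tokenizer.decode(y))

# Streaming: each yielded item is one newly generated token.
for token in generate(
    model, encoded, max_returned_tokens, temperature=0.8, top_k=50, stream=True
):
    print(tokenizer.decode(token), end="", flush=True)
print()
```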
62 changes: 46 additions & 16 deletions tests/test_generate.py
@@ -38,7 +38,7 @@ def multinomial(*args, **kwargs):
return out

with mock.patch("litgpt.generate.base.multinomial_num_samples_1", multinomial):
out = generate.generate(model, input_idx, T + max_new_tokens, top_k=4)
out = next(generate.generate(model, input_idx, T + max_new_tokens, top_k=4))

assert out.size(0) == T + max_new_tokens
multinomial_results = torch.hstack(multinomial_results)
@@ -47,39 +47,69 @@ def multinomial(*args, **kwargs):
torch.testing.assert_close(out, expected)


@pytest.mark.parametrize(
"max_returned_tokens",
[15, 25, 20]
)
def test_generate_stream(max_returned_tokens):
T = 5
prompt = torch.randint(10, size=(T,))
config = Config(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8)
model = GPT(config)
model.max_seq_length = 30
max_new_tokens = max_returned_tokens - T

model.set_kv_cache(batch_size=1)

multinomial_results = []

def multinomial(*args, **kwargs):
out = torch.multinomial(*args, **kwargs, num_samples=1)
multinomial_results.append(out)
return out

with mock.patch("litgpt.generate.base.multinomial_num_samples_1", multinomial):
token_generator = generate.generate(model, prompt, max_returned_tokens, stream=True)
generated_tokens = list(token_generator)

expected_length = min(max_new_tokens, len(multinomial_results))
assert len(generated_tokens) == expected_length


def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
config_path = fake_checkpoint_dir / "model_config.yaml"
config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1}
config_path.write_text(yaml.dump(config))

module_mock = Mock()
module_mock.config.block_size = 128
module_mock.max_seq_length = 150
load_mock = Mock()
load_mock.return_value = load_mock
load_mock.return_value = module_mock
monkeypatch.setattr(generate, "load_checkpoint", load_mock)
tokenizer_mock = Mock()
tokenizer_mock.return_value.encode.return_value = torch.tensor([1, 2, 3])
tokenizer_mock.return_value.decode.return_value = "foo bar baz"
monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock)
generate_mock = Mock()
generate_mock.return_value = torch.tensor([3, 2, 1])
monkeypatch.setattr(generate, "generate", generate_mock)

def generate_mock(model, prompt, max_returned_tokens, *, temperature, top_k, top_p, eos_id, stream):
if stream:
for i in range(max_returned_tokens - prompt.size(0)):
yield torch.tensor([3, 2, 1][i % 3])
else:
yield torch.cat([prompt] + [torch.tensor([3, 2, 1])] * (max_returned_tokens - prompt.size(0)))

generate_function_mock = Mock()
generate_function_mock.side_effect = generate_mock
monkeypatch.setattr(generate, "generate", generate_function_mock)

num_samples = 2
out, err = StringIO(), StringIO()
with redirect_stdout(out), redirect_stderr(err):
generate.main(temperature=2.0, top_k=2, top_p=0.9, num_samples=num_samples, checkpoint_dir=fake_checkpoint_dir)

assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples
assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
assert (
generate_mock.mock_calls
== [call(ANY, tensor_like, 53, temperature=2.0, top_k=2, top_p=0.9, eos_id=tokenizer_mock.return_value.eos_id)]
* num_samples
)
# only the generated result is printed to stdout
assert out.getvalue() == "foo bar baz\n" * num_samples

assert out.getvalue().strip().split('\n') == ["foo bar baz"] * num_samples
assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue()


@@ -117,8 +147,8 @@ def test_generate_different_results_with_different_top_p():
input_idx = torch.randint(10, size=(1,))

torch.manual_seed(123)
output1 = generate.generate(model, input_idx, 20, top_p=1.0)
output1 = next(generate.generate(model, input_idx, 20, top_p=1.0))
torch.manual_seed(123)
output2 = generate.generate(model, input_idx, 20, top_p=0.1)
output2 = next(generate.generate(model, input_idx, 20, top_p=0.1))

assert not torch.equal(output1, output2)
15 changes: 12 additions & 3 deletions tests/test_serve.py
@@ -1,5 +1,6 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from dataclasses import asdict
import json
import shutil

from lightning.fabric import seed_everything
@@ -10,7 +11,7 @@


from litgpt import GPT, Config
from litgpt.deploy.serve import SimpleLitAPI
from litgpt.deploy.serve import SimpleLitAPI, StreamLitAPI
from litgpt.scripts.download import download_from_hub


@@ -37,6 +38,14 @@ def test_simple(tmp_path):

with TestClient(server.app) as client:
response = client.post("/predict", json={"prompt": "Hello world"})
# Model is a small random model, not trained, hence the gibberish.
# We are just testing that the server works.
assert response.json()["output"][:19] == "Hello world statues"

# Test with streaming enabled
server = LitServer(
StreamLitAPI(checkpoint_dir=tmp_path, temperature=1, top_k=1),
accelerator=accelerator, devices=1, timeout=60, stream=True
)
with TestClient(server.app) as client:
response = client.post("/predict", json={"prompt": "Hello world"})

assert response.json()["output"][:19] == "Hello world statues"
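
The streaming branch of the test posts to the same `/predict` route but does not yet inspect the streamed body. A possible extension, assuming the LitServe version under test frames each streamed chunk as one JSON document per line (an assumption, not something this PR confirms), might look like:

```python
# Hypothetical extension of the streaming test; the one-JSON-object-per-line
# chunk framing is an assumption about LitServe's streaming wire format.
with TestClient(server.app) as client:
    response = client.post("/predict", json={"prompt": "Hello world"})
    pieces = []
    for line in response.iter_lines():
        if not line:
            continue
        try:
            pieces.append(json.loads(line)["output"])
        except (ValueError, KeyError):
            continue  # ignore framing that is not a bare JSON chunk
    # The tiny random model is untrained, so only check that chunks arrived.
    assert len(pieces) > 0
```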