Add batched_generate_fn() (#1702)
Co-authored-by: Sebastian Raschka <[email protected]>
apaz-cli and rasbt committed Sep 5, 2024
1 parent fdf6a12 commit 1d37f9a
Showing 3 changed files with 367 additions and 2 deletions.
146 changes: 146 additions & 0 deletions litgpt/generate/base.py
@@ -127,6 +127,23 @@ def generate_fn(
include_prompt: bool,
include_eos: bool,
) -> Iterator[torch.Tensor]:
"""
Generates tokens for a single prompt.
Args:
model: The model to use.
prompt: The tokenized prompt to generate from.
max_returned_tokens: The maximum number of new tokens to return. Does not include the prompt tokens.
temperature: The sampling temperature to pass to sample().
top_k: The top_k to pass to sample().
top_p: The top_p to pass to sample().
stop_tokens: A tuple of stop sequences. If any of the sequences are generated, the generation stops early before max_returned_tokens.
include_prompt: Whether to output the prompt tokens.
include_eos: Whether to output the stop tokens if generation stops early.
"""



prompt_size = prompt.size(0)
device = prompt.device

@@ -194,6 +211,135 @@ def generate_fn(
yield from tokens[yielded_idx:]
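
For orientation, here is a minimal sketch of how this single-prompt generator might be consumed. It is illustrative only and not part of the commit: the prompt token ids and eos_id are placeholders, and the model is assumed to be a litgpt GPT with its KV cache already set up.

# Hypothetical usage sketch, not part of this diff.
prompt = torch.tensor([1, 15043, 29892, 590, 1024, 338], device="cuda")  # placeholder token ids
eos_id = 2  # placeholder end-of-sequence id
new_tokens = []
for token in generate_fn(
    model,
    prompt,
    max_returned_tokens=64,
    temperature=0.8,
    top_k=50,
    top_p=1.0,
    stop_tokens=([eos_id],),
    include_prompt=False,
    include_eos=False,
):
    new_tokens.append(token.view(-1))
generated = torch.cat(new_tokens)  # 1D tensor of the newly generated token ids
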


# TODO: Make include_eos work.
# TODO: Rewrite unbatched generate_fn to use batched_generate_fn.
@torch.inference_mode()
def batched_generate_fn(
model: GPT,
prompts: torch.Tensor,
max_returned_tokens: int,
*,
sample_args: Union[list[dict], dict],
stop_tokens: Tuple[List[int], ...] = (),
include_prompt: bool,
include_eos: bool,
) -> Iterator[list[Union[torch.Tensor, None]]]:
"""
Generates tokens for a batch of prompts.
Args:
model: The model to use.
prompts: A 2D tensor of shape [batch_size, prompt_length].
max_returned_tokens: The maximum number of new tokens to return. Does not include the prompt tokens.
sample_args: The dictionary of kwargs to pass to sample() for each token, or a list with one such dictionary per index in the batch.
stop_tokens: A tuple of stop sequences. If any of the sequences are generated, the generation stops early before max_returned_tokens.
include_prompt: Whether to output the prompt tokens.
include_eos: Whether to output the stop tokens if generation stops early.
Yields:
A list of tokens for each prompt in the batch, or None if a stop sequence has already been encountered for that index in the batch.
"""

if prompts.ndim == 1:
prompts = prompts.unsqueeze(0)
assert prompts.ndim == 2, "Prompts must be a 2D tensor."

batch_size = prompts.size(0)
max_prompt_size = prompts.size(1)
device = prompts.device

if isinstance(sample_args, dict):
sample_args = [sample_args] * len(prompts)
else:
assert len(sample_args) == batch_size, "sample_args must have the same length as the batch size."

# TODO: This check (and the one in generate_fn) is not sufficient. We do the proper checks in LLM.generate().
assert max_returned_tokens > max_prompt_size, f"Not enough space for {max_prompt_size} prompt tokens in a context length of {max_returned_tokens}."
if model.max_seq_length < max_returned_tokens - 1:
raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")

# Yield the prompts if include_prompt is True
if include_prompt:
# TODO: Prompt length is padded, but they shouldn't all be the same length.
for i in range(max_prompt_size):
yield [prompt[i].view(-1) for prompt in prompts]

stop_progresses = [[0] * len(stop_tokens) for _ in range(batch_size)] # [batch_size, ~len(stop_tokens)]
stop_idxes = [-1] * batch_size
yielded_idx = 0

# Generate output tokens.
# The first token generated is the prefill token.
# The input_pos for this token is the width of the entire prompt.
# For subsequent iterations, it's the index in the context for the token that we're generating.
token_lists = [[] for _ in range(batch_size)]
tokens: torch.Tensor = prompts
prefill_token = True
input_pos = torch.arange(0, max_prompt_size, device=device, dtype=torch.int64)
for current_idx in range(max_returned_tokens - max_prompt_size):

# Generate the next token for each prompt in the batch.
# This is of shape [batch_size, 1].
tokens = batched_next_token(model, input_pos, tokens, sample_args)
for i in range(batch_size):
token_lists[i].append(tokens[i])
int_tokens = [token.item() for token in tokens]

# Check for stop sequences
# For each stop sequence, we keep a running total of how many are matched in stop_progress.
# If the current token matches the next token in the stop sequence, we increment the
# running total and hold off on yielding the token.
for batch_idx, int_token in enumerate(int_tokens):
if stop_idxes[batch_idx] != -1:
continue
for seq_idx, seq in enumerate(stop_tokens):
seq_pos = stop_progresses[batch_idx][seq_idx]
if seq_pos >= len(seq):
continue
if int_token == seq[seq_pos]:
stop_progresses[batch_idx][seq_idx] += 1
if stop_progresses[batch_idx][seq_idx] == len(seq):
stop_idxes[batch_idx] = current_idx
else:
stop_progresses[batch_idx][seq_idx] = 0

# Yield tokens that are not part of a stop sequence in progress.
# If there are no stop sequences, then that's all of them.
if len(stop_tokens) != 0:
safe_idxes = [len(token_lists[i]) - max(stop_progresses[i]) for i in range(batch_size)]
else:
safe_idxes = [current_idx + 1] # include the token just generated
safe_idx = min(safe_idxes)

if yielded_idx < safe_idx:
for idx in range(yielded_idx, safe_idx):
y_tokens = [token_lists[i][idx] if (stop_idxes[i] == -1 or idx < stop_idxes[i]) else None for i in range(batch_size)]
if all(y is None for y in y_tokens):
return
yield y_tokens
yielded_idx = safe_idx

# Update input_pos for the next iteration.
if prefill_token:
prefill_token = False

# TODO: Make the model support a batched input_pos of shape [batch_size, 1].
# The kvcache has been fixed, but the rope cache is still broken.
input_pos = torch.tensor([max_prompt_size], device=device, dtype=torch.int64)
else:
input_pos.add_(1)

# Yield any remaining tokens
max_token_lists = max(len(l) for l in token_lists)
if yielded_idx < max_token_lists:
for idx in range(yielded_idx, max_token_lists):
y_tokens = [token_lists[i][idx] if (stop_idxes[i] == -1 or idx < stop_idxes[i]) else None for i in range(batch_size)]
if all(y is None for y in y_tokens):
return
yield y_tokens
return
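
A corresponding consumption sketch for the batched variant, again illustrative only and not part of the commit: the token ids are placeholders, and the sample_args keys are assumed to match sample()'s keyword arguments (temperature, top_k, top_p).

# Hypothetical usage sketch, not part of this diff.
prompts = torch.tensor([[1, 3920, 306], [1, 9038, 2501]], device="cuda")  # placeholder ids, shape [2, 3]
sample_args = [{"temperature": 0.2, "top_k": 20}, {"temperature": 0.9, "top_k": 50}]  # one dict per batch index
eos_id = 2  # placeholder end-of-sequence id
streams = [[] for _ in range(prompts.size(0))]
for step in batched_generate_fn(
    model,
    prompts,
    max_returned_tokens=32,
    sample_args=sample_args,
    stop_tokens=([eos_id],),
    include_prompt=False,
    include_eos=False,
):
    for i, token in enumerate(step):
        if token is not None:  # None once this batch index has hit a stop sequence
            streams[i].append(token.view(-1))
sequences = [torch.cat(toks) for toks in streams]  # one 1D tensor of new tokens per prompt
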

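The stop-sequence bookkeeping above (stop_progresses and stop_idxes) amounts to an incremental prefix match over the generated ids. The same idea on plain Python ints, as an isolated illustration rather than part of the commit:

# Standalone illustration of the matching logic, not part of this diff.
def find_stop_idx(generated: list[int], stop_tokens: tuple) -> int:
    """Return the index at which any stop sequence completes, or -1."""
    progress = [0] * len(stop_tokens)
    for idx, tok in enumerate(generated):
        for s, seq in enumerate(stop_tokens):
            if progress[s] >= len(seq):
                continue
            if tok == seq[progress[s]]:
                progress[s] += 1
                if progress[s] == len(seq):
                    return idx
            else:
                progress[s] = 0
    return -1

find_stop_idx([5, 7, 2, 2], stop_tokens=([2, 2],))  # -> 3: the [2, 2] stop sequence completes at index 3
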

@torch.inference_mode()
def generate(
model: GPT,
1 change: 1 addition & 0 deletions pyproject.toml
@@ -10,6 +10,7 @@ license = { file = "LICENSE" }

dependencies = [
"torch>=2.2.0",
"numpy<2.0",
"lightning==2.4.0.dev20240728",
"jsonargparse[signatures]>=4.27.6",
"huggingface_hub>=0.23.5", # download models
(diff for the remaining changed file not shown)
