diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py index 5a2f443b7..201a707dc 100644 --- a/litgpt/generate/base.py +++ b/litgpt/generate/base.py @@ -127,6 +127,23 @@ def generate_fn( include_prompt: bool, include_eos: bool, ) -> Iterator[torch.Tensor]: + """ + Generates tokens for a single prompt. + + Args: + model: The model to use. + prompt: The tokenized prompt to generate from. + max_returned_tokens: The maximum number of new tokens to return. Does not include the prompt tokens. + temperature: The temp to pass to sample(). + top_k: The top_k to pass to sample(). + top_p: The top_p to pass to sample(). + stop_tokens: A tuple of stop sequences. If any of the sequences are generated, the generation stops early before max_returned_tokens. + include_prompt: Whether to output the prompt tokens. + include_eos: Whether to output the stop tokens if generation stops early. + """ + + + prompt_size = prompt.size(0) device = prompt.device @@ -194,6 +211,135 @@ def generate_fn( yield from tokens[yielded_idx:] +# TODO: Make include_eos work. +# TODO: Rewrite unbatched generate_fn to use batched_generate_fn. +@torch.inference_mode() +def batched_generate_fn( + model: GPT, + prompts: torch.Tensor, + max_returned_tokens: int, + *, + sample_args: Union[list[dict], dict], + stop_tokens: Tuple[List[int], ...] = (), + include_prompt: bool, + include_eos: bool, +) -> Iterator[list[Union[torch.Tensor, None]]]: + """ + Generates tokens for a batch of prompts. + + Args: + model: The model to use. + prompts: A 2D tensor of shape [batch_size, prompt_length]. + max_returned_tokens: The maximum number of new tokens to return. Does not include the prompt tokens. + sample_args: The dictionary of kwargs to pass to sample() for each each token for each index in the batch. + stop_tokens: A tuple of stop sequences. If any of the sequences are generated, the generation stops early before max_returned_tokens. + include_prompt: Whether to output the prompt tokens. + include_eos: Whether to output the stop tokens if generation stops early. + + Yields: + A list of tokens for each prompt in the batch, or None if a stop sequence has already been encountered for that index in the batch. + """ + + if prompts.ndim == 1: + prompts = prompts.unsqueeze(0) + assert prompts.ndim == 2, "Prompts must be a 2D tensor." + + batch_size = prompts.size(0) + max_prompt_size = prompts.size(1) + device = prompts.device + + if isinstance(sample_args, dict): + sample_args = [sample_args] * len(prompts) + else: + assert len(sample_args) == batch_size, "sample_args must have the length as the batch size." + + # TODO: This check (and the one in generate_fn) is not sufficient. We do the proper checks in LLM.generate(). + assert max_returned_tokens > max_prompt_size, f"Not enough space for {max_prompt_size} prompt tokens in a context length of {max_returned_tokens}." + if model.max_seq_length < max_returned_tokens - 1: + raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}") + + # Yield the prompts if include_prompt is True + if include_prompt: + # TODO: Prompt length is padded, but they shouldn't all be the same length. + for i in range(max_prompt_size): + yield [prompt[i].view(-1) for prompt in prompts] + + stop_progresses = [[0] * len(stop_tokens) for _ in range(batch_size)] # [batch_size, ~len(stop_tokens)] + stop_idxes = [-1] * batch_size + yielded_idx = 0 + + # Generate output tokens. + # The first token generated is the prefill token. + # The input_pos for this token is the width of the entire prompt. + # For subsequent iterations, it's the index in the context for the token that we're generating. + token_lists = [[] for _ in range(batch_size)] + tokens: torch.Tensor = prompts + prefill_token = True + input_pos = torch.arange(0, max_prompt_size, device=device, dtype=torch.int64) + for current_idx in range(max_returned_tokens - max_prompt_size): + + # Generate the next token for each prompt in the batch. + # This is of shape [batch_size, 1]. + tokens = batched_next_token(model, input_pos, tokens, sample_args) + for i in range(batch_size): + token_lists[i].append(tokens[i]) + int_tokens = [token.item() for token in tokens] + + # Check for stop sequences + # For each stop sequence, we keep a running total of how many are matched in stop_progress. + # If the current token matches the next token in the stop sequence, we increment the + # running total and hold off on yielding the token. + for batch_idx, int_token in enumerate(int_tokens): + if stop_idxes[batch_idx] != -1: + continue + for seq_idx, seq in enumerate(stop_tokens): + seq_pos = stop_progresses[batch_idx][seq_idx] + if seq_pos >= len(seq): + continue + if int_token == seq[seq_pos]: + stop_progresses[batch_idx][seq_idx] += 1 + if stop_progresses[batch_idx][seq_idx] == len(seq): + stop_idxes[batch_idx] = current_idx + else: + stop_progresses[batch_idx][seq_idx] = 0 + + # Yield tokens that are not part of a stop sequence in progress. + # If there are no stop sequences, then that's all of them. + if len(stop_tokens) != 0: + safe_idxes = [len(token_lists[i]) - max(stop_progresses[i]) for i in range(batch_size)] + else: + safe_idxes = [current_idx + 1] # include the token just generated + safe_idx = min(safe_idxes) + + if yielded_idx < safe_idx: + for idx in range(yielded_idx, safe_idx): + y_tokens = [token_lists[i][idx] if (stop_idxes[i] == -1 or idx < stop_idxes[i]) else None for i in range(batch_size)] + if all(y is None for y in y_tokens): + return + yield y_tokens + yielded_idx = safe_idx + + # Update input_pos for the next iteration. + if prefill_token: + prefill_token = False + + # TODO: Make the model support a batched input_pos of shape [batch_size, 1]. + # The kvcache has been fixed, but the rope cache is still broken. + input_pos = torch.tensor([max_prompt_size], device=device, dtype=torch.int64) + else: + input_pos.add_(1) + + # Yield any remaining tokens + max_token_lists = max(len(l) for l in token_lists) + if yielded_idx < max_token_lists: + for idx in range(yielded_idx, max_token_lists): + y_tokens = [token_lists[i][idx] if (stop_idxes[i] == -1 or idx < stop_idxes[i]) else None for i in range(batch_size)] + if all(y is None for y in y_tokens): + return + yield y_tokens + return + + @torch.inference_mode() def generate( model: GPT, diff --git a/pyproject.toml b/pyproject.toml index 28de27175..dc7cd31e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ license = { file = "LICENSE" } dependencies = [ "torch>=2.2.0", + "numpy<2.0", "lightning==2.4.0.dev20240728", "jsonargparse[signatures]>=4.27.6", "huggingface_hub>=0.23.5", # download models diff --git a/tests/test_batch.py b/tests/test_batch.py index 0f6539515..47dbbec5f 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -2,14 +2,42 @@ import pytest import warnings from pathlib import Path +import lightning as L import litgpt -from litgpt.generate.base import next_token, batched_next_token +from litgpt.generate.base import ( + next_token, + batched_next_token, + batched_generate_fn, + generate_fn, +) from litgpt.api import LLM, GPT from litgpt.scripts.download import download_from_hub from tests.conftest import RunIf + warnings.filterwarnings("ignore") + +def create_llm(tmp_path, batch_size, max_seq_length, device) -> tuple[LLM, GPT]: + + L.seed_everything(42) + + model_name = "microsoft/phi-2" + download_from_hub(repo_id=model_name, tokenizer_only=True, checkpoint_dir=tmp_path) + + llm: LLM = LLM.load( + model_name, + tokenizer_dir=Path(tmp_path / model_name), + init="random", + ) + model: GPT = llm.model + model.set_kv_cache( + batch_size=batch_size, max_seq_length=max_seq_length, device=device + ) + + return llm, model + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires a GPU.") def test_batched_equivalence(tmp_path): @@ -54,7 +82,9 @@ def test_batched_equivalence(tmp_path): model.clear_kv_cache() model.set_kv_cache(batch_size=batch_size, max_seq_length=50, device="cuda:0") - toks_1: torch.Tensor = batched_next_token(model, input_pos_1, batch_x1, sample_kwargs) + toks_1: torch.Tensor = batched_next_token( + model, input_pos_1, batch_x1, sample_kwargs + ) toks_2: torch.Tensor = batched_next_token(model, input_pos_2, toks_1, sample_kwargs) assert toks_1.ndim == 2 @@ -106,3 +136,191 @@ def test_simple_batch(): torch.testing.assert_close(outs0, outs0_ref) torch.testing.assert_close(outs1, outs1_ref) torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32 + + +@RunIf(min_cuda_gpus=1) +def test_batch_generate(tmp_path): + + torch.use_deterministic_algorithms(True) + + device = "cuda:0" + batch_size = 3 + sample_kwargs = {"top_k": 1} + llm, model = create_llm(tmp_path, batch_size, 50, device) + + batch_x = torch.tensor( + [ + [43993, 25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410], + [25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596], + [1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596, 7596], + ], + device=device, + dtype=torch.int64, + ) + + # Generate tokens + tokens = [] + for l in batched_generate_fn( + model, + prompts=batch_x, + max_returned_tokens=50, + sample_args=sample_kwargs, + include_prompt=True, + include_eos=False, + ): + tokens.append([t.item() if t is not None else None for t in l]) + + def find_unique_stop(triplets): + # Initialize a dictionary to count all number occurrences + number_count = {} + + # Count occurrences of each number across all positions + for triplet in triplets: + for num in triplet: + number_count[num] = number_count.get(num, 0) + 1 + + # Initialize lists to store unique numbers for each position + unique_first = [] + unique_second = [] + unique_third = [] + + # Check each triplet + for a, b, c in triplets: + if number_count[a] == 1: + unique_first.append(a) + if number_count[b] == 1: + unique_second.append(b) + if number_count[c] == 1: + unique_third.append(c) + + import random # Seeded earlier + + random.shuffle(unique_first) + random.shuffle(unique_second) + random.shuffle(unique_third) + return [unique_first[0], unique_second[0], unique_third[0]] + + # Now that we know the randomly generated tokens, sample some tokens to stop each stream at. + stops = find_unique_stop(tokens[batch_x.size(1) :]) + first_stream = [t[0] for t in tokens if t[0] is not None] + second_stream = [t[1] for t in tokens if t[1] is not None] + third_stream = [t[2] for t in tokens if t[2] is not None] + + # Let's slice the streams at the stop tokens. + stop_idxes = [ + first_stream.index(stops[0]), + second_stream.index(stops[1]), + third_stream.index(stops[2]), + ] + + # While we're at it, grab the last token that would be generated before stopping. + last_tokens = [ + first_stream[stop_idxes[0] - 1], + second_stream[stop_idxes[1] - 1], + third_stream[stop_idxes[2] - 1], + ] + + for t in tokens: + print(t) + + # Now we generate again, stopping early at the stop tokens. + tokens = [] + for l in batched_generate_fn( + model, + prompts=batch_x, + max_returned_tokens=50, + stop_tokens=[(s,) for s in stops], + sample_args=sample_kwargs, + include_prompt=True, + include_eos=False, + ): + tokens.append([t.item() if t is not None else None for t in l]) + + # Finally, assert that the streams are correct. + + first_stream = [t[0] for t in tokens if t[0] is not None] + print(first_stream) + print(len(first_stream), stop_idxes[0]) + assert len(first_stream) == stop_idxes[0] + assert first_stream[-1] == last_tokens[0] + + second_stream = [t[1] for t in tokens if t[1] is not None] + print(second_stream) + print(len(second_stream), stop_idxes[1]) + assert len(second_stream) == stop_idxes[1] + assert second_stream[-1] == last_tokens[1] + + third_stream = [t[2] for t in tokens if t[2] is not None] + print(third_stream) + print(len(third_stream), stop_idxes[2]) + assert len(third_stream) == stop_idxes[2] + assert third_stream[-1] == last_tokens[2] + + torch.use_deterministic_algorithms(False) + + # for t in llm.tokenizer.decode_stream([torch.tensor(i) for i in first_stream]): + # print(t, end="", flush=True) + # print() + + +@RunIf(min_cuda_gpus=1) +def test_batch_generate_equivalence(tmp_path): + + torch.use_deterministic_algorithms(True) + + device = "cuda:0" + batch_size = 3 + sample_kwargs = {"top_k": 1} + llm, model = create_llm(tmp_path, batch_size, 50, device) + + batch_x = torch.tensor( + [ + [43993, 25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410], + [25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596], + [1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596, 7596], + ], + device=device, + dtype=torch.int64, + ) + + # The other test tests the stop_tokens functionality much more exhaustively, we'll just generate and compare 50 tokens here. + + batch_tokens = [] + for l in batched_generate_fn( + model, + prompts=batch_x, + max_returned_tokens=50, + sample_args=sample_kwargs, + include_prompt=False, + include_eos=False, + ): + batch_tokens.append([t.item() if t is not None else None for t in l]) + + first_stream = [t[0] for t in batch_tokens if t[0] is not None] + + batch_size = 1 + llm, model = create_llm(tmp_path, batch_size, 50, device) + + tokens = [] + for t in generate_fn( + model, + prompt=batch_x[0], + max_returned_tokens=50, + include_prompt=False, + include_eos=False, + **sample_kwargs, + ): + if t.size(0) == 1: + tokens.append(t.item()) + else: + tokens.extend(t.tolist()) + + torch.use_deterministic_algorithms(False) + + # TODO: (apaz-cli) This consistency test doesn't actually work at the moment. It's inconsistent. + # The output is really close... Something is going on here. For the moment, maybe this is close enough? + # Enough at least that we can start prototyping. + + print(first_stream) + print(tokens) + # assert first_stream == tokens