Commit

speculative decoding
gante committed Dec 12, 2023
1 parent 4b759da commit 9b51da1
Showing 3 changed files with 129 additions and 84 deletions.
78 changes: 27 additions & 51 deletions src/transformers/generation/candidate_generator.py
@@ -15,7 +15,7 @@

import copy
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch

@@ -28,7 +28,7 @@
class CandidateGenerator:
"""Abstract base class for all candidate generators that can be applied during assisted generation."""

def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
@@ -37,8 +37,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Return:
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be assessed by
the model.
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated with each candidate.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
@@ -152,7 +153,7 @@ def __init__(
)
self.logits_processor = logits_processor

def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
@@ -161,7 +162,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Return:
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated with each candidate.
"""
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round)
@@ -179,51 +182,24 @@ def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

# 2. Forecast next N tokens using the assistant model. This `for` block can be replaced with a `.generate()`
# call if we decide to add `past_key_values` as a possible output of generate, as we need access to the
# assistant cache to secure strong speedups.
candidate_input_ids = input_ids
for _ in range(int(self.num_assistant_tokens)):
# 2.1 prepare assistant model inputs
assistant_inputs = self.assistant_model.prepare_inputs_for_generation(
candidate_input_ids,
**self.assistant_kwargs,
)

# 2.2. check if the input ids length is correct
has_past_key_values = assistant_inputs.get("past_key_values", None) is not None
if has_past_key_values and assistant_inputs[self.input_ids_key].shape[-1] not in (1, 2):
raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")

# 2.3. use the assistant model to obtain the next candidate logits
assistant_model_outputs = self.assistant_model(**assistant_inputs)

# 2.4. greedily select the next candidate token
if len(self.logits_processor) > 0:
assistant_model_outputs.logits[:, -1, :] = self.logits_processor(
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
)
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)

# 2.5. update assistant model inputs
if self.assistant_kwargs.get(self.attention_key, None) is not None:
mask = self.assistant_kwargs[self.attention_key]
self.assistant_kwargs[self.attention_key] = torch.cat(
[mask, mask.new_ones((mask.shape[0], 1))], dim=-1
)
self.assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values

# 2.6. stop assistant generation on EOS
if self.eos_token_id_tensor is not None:
last_assistant_token_is_eos = new_token.tile(self.eos_token_id_tensor.shape[0], 1)
last_assistant_token_is_eos = (
~last_assistant_token_is_eos.ne(self.eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
)
if last_assistant_token_is_eos:
break

return candidate_input_ids
# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs = {
self.input_ids_key: input_ids,
"do_sample": False,
"num_beams": 1,
"max_new_tokens": int(self.num_assistant_tokens),
"return_dict_in_generate": True,
"output_scores": True,
}
assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)

# 3. Update variables for the next round of candidate generation
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

# 4. Prepare variables for output
candidate_logits = torch.stack(assistant_output.scores, dim=1)
candidate_ids = assistant_output.sequences
return candidate_ids, candidate_logits

def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
"""
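Note on the rewritten `get_candidates` above: instead of a manual token-by-token loop, it now delegates to the assistant's `generate()` with `return_dict_in_generate=True` and `output_scores=True`, then stacks the per-step scores into a `(batch_size, candidate_length, vocab_size)` logits tensor. A minimal standalone sketch of that pattern follows; the `gpt2`/`distilgpt2` checkpoints and the 5-token budget are illustrative placeholders, not part of the commit.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder assistant model; any small causal LM sharing the main model's tokenizer would do.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids

assistant_output = assistant.generate(
    input_ids,
    do_sample=False,
    num_beams=1,
    max_new_tokens=5,  # plays the role of `num_assistant_tokens`
    return_dict_in_generate=True,
    output_scores=True,
)

# `scores` is a tuple with one `(batch_size, vocab_size)` tensor per generated token;
# stacking on dim=1 gives the `(batch_size, candidate_length, vocab_size)` logits that
# `get_candidates` now returns alongside the candidate ids.
candidate_logits = torch.stack(assistant_output.scores, dim=1)
candidate_ids = assistant_output.sequences
print(candidate_ids.shape, candidate_logits.shape)  # e.g. torch.Size([1, 9]), torch.Size([1, 5, 50257])
```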
87 changes: 68 additions & 19 deletions src/transformers/generation/utils.py
@@ -4585,7 +4585,7 @@ def assisted_decoding(
cur_len = input_ids.shape[-1]

# 1. Fetch candidate sequences from a `CandidateGenerator`
candidate_input_ids = candidate_generator.get_candidates(input_ids)
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
last_assistant_token_is_eos = (
~candidate_input_ids[:, -1]
@@ -4624,40 +4624,89 @@
for i in range(candidate_length + 1):
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])

# 3. Obtain the next tokens from the original model logits.
if do_sample:
probs = new_logits.softmax(dim=-1)
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
# 3. Select the accepted tokens. There are two possible cases:
# Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
# 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
# NOTE: Unless otherwise stated, the variable names match those in the paper.
if do_sample and candidate_logits is not None:
# Gets the probabilities from the logits. q_i and p_i denote the assistant and model (respectively)
# probabilities of the tokens selected by the assistant.
q = candidate_logits.softmax(dim=-1)
q_i = q[
:,
torch.range(0, candidate_length - 1, dtype=torch.int),
candidate_input_ids[:, -candidate_length:],
].squeeze(0, 1)
p = new_logits.softmax(dim=-1)
p_i = p[
:,
torch.range(0, candidate_length - 1, dtype=torch.int),
candidate_input_ids[:, -candidate_length:],
].squeeze(0, 1)
probability_ratio = p_i / q_i

# When probability_ratio > 1 (i.e. q_i(x) < p_i(x)), keep the token. Otherwise reject with
# p = 1 - probability_ratio (= keep with p=probability_ratio). Keep all the tokens until the first
# rejection
r_i = torch.rand_like(probability_ratio)
is_rejected = r_i > probability_ratio # equivalent: is_accepted = r_i <= probability_ratio
n_matches = (is_rejected.cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1

# Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct
# behavior)
if last_assistant_token_is_eos and n_matches == candidate_length:
n_matches -= 1
n_matches = min(n_matches, max_len - cur_len - 1)

# Next token selection: if there is a rejection, adjust the distribution from the main model before
# sampling.
gamma = candidate_logits.shape[1]
p_n_plus_1 = p[:, n_matches, :]
if n_matches < gamma:
q_n_plus_1 = q[:, n_matches, :]
p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0).softmax(dim=-1)
else:
p_prime = p_n_plus_1
t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]

# The selected tokens include the matches plus the next sampled token
selected_tokens = torch.cat((candidate_input_ids[:, :n_matches], t), dim=-1)

# Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
# original model logits with the candidate tokens. We can keep the candidate tokens until the first
# mismatch, or until the max length is reached.
else:
selected_tokens = new_logits.argmax(dim=-1)
if do_sample:
probs = new_logits.softmax(dim=-1)
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
selected_tokens = new_logits.argmax(dim=-1)

candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
# Ensure we don't generate beyond max_len or an EOS token
if last_assistant_token_is_eos and n_matches == candidate_length:
n_matches -= 1
n_matches = min(n_matches, max_len - cur_len - 1)

# 5. Update variables according to the number of matching assistant tokens. Remember: the token generated
# 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
# by the model after the last candidate match is also valid, as it is generated from a correct sequence.
# Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
# is no match.

# 5.1. Ensure we don't generate beyond max_len or an EOS token
if last_assistant_token_is_eos and n_matches == candidate_length:
n_matches -= 1
n_matches = min(n_matches, max_len - cur_len - 1)

# 5.2. Get the valid continuation, after the matching tokens
# 4.1. Get the valid continuation, after the matching tokens
valid_tokens = selected_tokens[:, : n_matches + 1]
input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
if streamer is not None:
streamer.put(valid_tokens.cpu())
new_cur_len = input_ids.shape[-1]

# 5.3. Discard past key values relative to unused assistant tokens
# 4.2. Discard past key values relative to unused assistant tokens
new_cache_size = new_cur_len - 1
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)

# 6. Update the candidate generation strategy if needed
# 5. Update the candidate generation strategy if needed
candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)

if synced_gpus and this_peer_finished:
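For reference, a self-contained sketch (toy tensors, not the real generation loop) of the acceptance rule that Case 1 above implements, following algorithm 1 of https://arxiv.org/pdf/2211.17192.pdf: accept candidate token i while r_i <= p_i / q_i, and on the first rejection resample from the residual distribution built from max(0, p - q). The sketch renormalizes the residual by its sum, as in the paper; the hunk above applies a softmax to the clamped difference instead.

```python
import torch

torch.manual_seed(0)
vocab_size, candidate_length = 8, 4

candidate_logits = torch.randn(1, candidate_length, vocab_size)      # assistant logits (q)
new_logits = torch.randn(1, candidate_length + 1, vocab_size)        # main model logits (p)
candidate_tokens = torch.randint(vocab_size, (1, candidate_length))  # tokens proposed by the assistant

q = candidate_logits.softmax(dim=-1)
p = new_logits.softmax(dim=-1)
idx = torch.arange(candidate_length)
q_i = q[0, idx, candidate_tokens[0]]  # assistant probability of each proposed token
p_i = p[0, idx, candidate_tokens[0]]  # main model probability of the same tokens
probability_ratio = p_i / q_i

# Accept each token with probability min(1, p_i / q_i); keep tokens up to the first rejection.
r_i = torch.rand_like(probability_ratio)
is_accepted = r_i <= probability_ratio
n_matches = int((~is_accepted).cumsum(dim=-1).eq(0).sum())

# On rejection, sample the next token from the normalized residual max(0, p - q); otherwise
# sample directly from the main model's distribution at the position after the last candidate.
if n_matches < candidate_length:
    residual = torch.clamp(p[:, n_matches, :] - q[:, n_matches, :], min=0)
    p_prime = residual / residual.sum(dim=-1, keepdim=True)
else:
    p_prime = p[:, n_matches, :]
next_token = torch.multinomial(p_prime, num_samples=1)

valid_tokens = torch.cat((candidate_tokens[:, :n_matches], next_token), dim=-1)
print(n_matches, valid_tokens)
```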
48 changes: 34 additions & 14 deletions tests/generation/test_utils.py
@@ -3128,21 +3128,26 @@ def test_model_kwarg_assisted_decoding_decoder_only(self):
self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())

def test_model_kwarg_assisted_decoding_encoder_decoder(self):
"""
Tests that the following scenario is compatible with assisted generation:
1. encoder-decoder main model
2. encoder-decoder assistant model
3. both have a custom input
(e.g. Whisper)
"""

# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg that distorts the output
class FakeBart(BartForConditionalGeneration):
def forward(self, input_ids, foo=False, **kwargs):
outs = super().forward(input_ids, **kwargs)

if foo:
outs["logits"][:, :, :] = 0.0

return outs

def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)

inputs["foo"] = foo
return inputs

@@ -3160,10 +3165,7 @@ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None,
self.assertEqual(outputs_normal.shape, (1, 20))

# Should be different with foo
outputs_foo = model.generate(
input_ids,
foo=True,
)
outputs_foo = model.generate(input_ids, foo=True)
with self.assertRaises(AssertionError):
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())

@@ -3192,25 +3194,43 @@ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None,
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())

def test_assisted_decoding_encoder_decoder_shared_encoder(self):
"""
Tests that the following scenario is compatible with assisted generation:
1. encoder-decoder main model
2. decoder-only assistant model
3. both have a custom input
(e.g. DistilWhisper)
"""

# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg called foo that distorts the output
class FakeBart(BartForConditionalGeneration):
class FakeBartSeq2Seq(BartForConditionalGeneration):
def forward(self, input_ids, foo=False, **kwargs):
outs = super().forward(input_ids, **kwargs)

if foo:
outs["logits"][:, :, :] = 0.0

return outs

def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
inputs["foo"] = foo
return inputs

class FakeBartCausalLM(BartForCausalLM):
def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs):
outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs)
if foo:
outs["logits"][:, :, :] = 0.0
return outs

def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
inputs["foo"] = foo
return inputs

model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
@@ -3229,9 +3249,9 @@ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None,
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())

# Assistant model
assistant = BartForCausalLM.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
torch_device
)
assistant = FakeBartCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).to(torch_device)

# If assisted generation passes model_kwargs correctly, should be same as previous
outputs_assisted = model.generate(
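A hedged usage sketch of the entry point these tests exercise: passing `assistant_model` to `generate()` enables assisted generation, and with `do_sample=True` the sampling path now runs through the speculative decoding branch (Case 1 in `utils.py` above), since the candidate generator returns logits. The tiny testing checkpoint matches the one used in the tests; real speedups would need a larger main model paired with a small assistant that shares its tokenizer.

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint = "hf-internal-testing/tiny-random-BartForConditionalGeneration"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
assistant = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)  # stand-in for a smaller draft model

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids

torch.manual_seed(0)
outputs = model.generate(
    input_ids,
    assistant_model=assistant,
    do_sample=True,       # sampling + candidate logits -> speculative decoding (Case 1)
    max_new_tokens=20,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```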
