Generate: speculative decoding #27979

Merged Dec 19, 2023 · 8 commits
108 changes: 90 additions & 18 deletions src/transformers/generation/utils.py
@@ -4624,40 +4624,57 @@ def assisted_decoding(
for i in range(candidate_length + 1):
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])

# 3. Obtain the next tokens from the original model logits.
if do_sample:
probs = new_logits.softmax(dim=-1)
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
# 3. Select the accepted tokens. There are two possible cases:
# Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
# 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
max_matches = max_len - cur_len - 1
if do_sample and candidate_logits is not None:
next_sampled_tokens, n_matches = _speculative_sampling(
candidate_input_ids,
candidate_logits,
candidate_length,
new_logits,
last_assistant_token_is_eos,
max_matches,
)
# The selected tokens include the matches plus the next sampled tokens
selected_tokens = torch.cat((candidate_input_ids[:, :n_matches], next_sampled_tokens), dim=-1)

# Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
# original model logits with the candidate tokens. We can keep the candidate tokens until the first
# mismatch, or until the max length is reached.
else:
selected_tokens = new_logits.argmax(dim=-1)
if do_sample:
probs = new_logits.softmax(dim=-1)
Contributor:
is this case still relevant? Not sure it's a good idea to have two "assisted decoding" do_sample=True cases in our generate. Should we maybe just deprecate this case?

selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
selected_tokens = new_logits.argmax(dim=-1)
Comment on lines +4647 to +4651
@patrickvonplaten (Contributor) commented on Dec 15, 2023:

Suggested change
if do_sample:
probs = new_logits.softmax(dim=-1)
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
selected_tokens = new_logits.argmax(dim=-1)

It's probably time to soon factor this out into something like:

selected_tokens = Categorical(new_logits / temperature).sample()

everywhere in generate

Member Author:

Yes! Then equivalent sampling/non-sampling methods (e.g. greedy decoding/sampling) could be merged into a single function, facilitating maintenance. I'm going to leave it to a follow-up PR, though, to keep this PR exclusively about speculative decoding.
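
As an illustration of the refactor discussed in this thread, a factored-out helper could look roughly like the sketch below; the helper name _select_tokens is hypothetical and not part of this PR.

import torch
from torch.distributions import Categorical


def _select_tokens(logits: torch.Tensor, do_sample: bool, temperature: float = 1.0) -> torch.Tensor:
    # logits has shape (batch, seq_len, vocab_size); returns token ids of shape (batch, seq_len)
    if do_sample:
        # Categorical normalizes the temperature-scaled logits and samples one token per position
        return Categorical(logits=logits / temperature).sample()
    return logits.argmax(dim=-1)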


# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
gante marked this conversation as resolved.

# 5. Update variables according to the number of matching assistant tokens. Remember: the token generated
# Ensure we don't generate beyond max_len or an EOS token
if last_assistant_token_is_eos and n_matches == candidate_length:
n_matches -= 1
n_matches = min(n_matches, max_matches)

# 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
# by the model after the last candidate match is also valid, as it is generated from a correct sequence.
# Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
# is no match.

# 5.1. Ensure we don't generate beyond max_len or an EOS token
if last_assistant_token_is_eos and n_matches == candidate_length:
n_matches -= 1
n_matches = min(n_matches, max_len - cur_len - 1)

# 5.2. Get the valid continuation, after the matching tokens
# 4.1. Get the valid continuation, after the matching tokens
valid_tokens = selected_tokens[:, : n_matches + 1]
input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
if streamer is not None:
streamer.put(valid_tokens.cpu())
new_cur_len = input_ids.shape[-1]

# 5.3. Discard past key values relative to unused assistant tokens
# 4.2. Discard past key values relative to unused assistant tokens
new_cache_size = new_cur_len - 1
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)

# 6. Update the candidate generation strategy if needed
# 5. Update the candidate generation strategy if needed
candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)

if synced_gpus and this_peer_finished:
@@ -4755,6 +4772,61 @@ def assisted_decoding(
return input_ids


def _speculative_sampling(
candidate_input_ids,
candidate_logits,
candidate_length,
new_logits,
last_assistant_token_is_eos,
max_matches,
):
"""
Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
the next selected token, as well as the number of candidate matches.

NOTE: Unless otherwise stated, the variable names match those in the paper.
"""
# Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
# selected by the assistant, respectively.
q = candidate_logits.softmax(dim=-1)
q_i = q[
:,
torch.range(0, candidate_length - 1, dtype=torch.int),
candidate_input_ids[:, -candidate_length:],
].squeeze(0, 1)
p = new_logits.softmax(dim=-1)
p_i = p[
:,
torch.range(0, candidate_length - 1, dtype=torch.int),
candidate_input_ids[:, -candidate_length:],
].squeeze(0, 1)
probability_ratio = p_i / q_i

# When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
# than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
# (= keep with p = probability_ratio). Keep all the tokens until the first rejection
r_i = torch.rand_like(probability_ratio)
is_accepted = r_i <= probability_ratio
n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1

# Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
if last_assistant_token_is_eos and n_matches == candidate_length:
n_matches -= 1
n_matches = min(n_matches, max_matches)

# Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
gamma = candidate_logits.shape[1]
p_n_plus_1 = p[:, n_matches, :]
if n_matches < gamma:
q_n_plus_1 = q[:, n_matches, :]
p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0).softmax(dim=-1)
else:
p_prime = p_n_plus_1
t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]

return t, n_matches
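
As a self-contained sanity check of the acceptance rule above (toy probabilities, not taken from the PR), the snippet below shows how n_matches counts candidate tokens up to the first rejection.

import torch

torch.manual_seed(0)
q_i = torch.tensor([0.50, 0.10, 0.40])  # assistant probabilities of its own candidate tokens
p_i = torch.tensor([0.60, 0.02, 0.40])  # main model probabilities of the same tokens
probability_ratio = p_i / q_i           # tensor([1.20, 0.20, 1.00])
r_i = torch.rand_like(probability_ratio)
is_accepted = r_i <= probability_ratio  # first token is always accepted; second only with p = 0.2
n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()
print(int(n_matches))  # number of candidates kept before the first rejection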


def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
"""
Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
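
For reference, after this PR speculative decoding is reached through the existing assisted-generation entry point: passing an assistant model to generate together with do_sample=True takes the new _speculative_sampling path. A minimal usage sketch (the checkpoints are illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
assistant = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # smaller draft model

inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    do_sample=True,
    temperature=0.7,
    max_new_tokens=32,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])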