diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index d23f7f9245d7e8..b3bc4cd8d875cc 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4624,40 +4624,57 @@ def assisted_decoding(
                 for i in range(candidate_length + 1):
                     new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])

-            # 3. Obtain the next tokens from the original model logits.
-            if do_sample:
-                probs = new_logits.softmax(dim=-1)
-                selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
+            # 3. Select the accepted tokens. There are two possible cases:
+            # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
+            # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
+            max_matches = max_len - cur_len - 1
+            if do_sample and candidate_logits is not None:
+                next_sampled_tokens, n_matches = _speculative_sampling(
+                    candidate_input_ids,
+                    candidate_logits,
+                    candidate_length,
+                    new_logits,
+                    last_assistant_token_is_eos,
+                    max_matches,
+                )
+                # The selected tokens include the matches plus the next sampled tokens
+                selected_tokens = torch.cat((candidate_input_ids[:, cur_len : cur_len + n_matches], next_sampled_tokens), dim=-1)
+
+            # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
+            # original model logits with the candidate tokens. We can keep the candidate tokens until the first
+            # mismatch, or until the max length is reached.
             else:
-                selected_tokens = new_logits.argmax(dim=-1)
+                if do_sample:
+                    probs = new_logits.softmax(dim=-1)
+                    selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
+                else:
+                    selected_tokens = new_logits.argmax(dim=-1)

-            # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
-            # the assistant forecasted tokens until the first mismatch, or until the max length is reached.
-            candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
-            n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
+                candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
+                n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

-            # 5. Update variables according to the number of matching assistant tokens. Remember: the token generated
+                # Ensure we don't generate beyond max_len or an EOS token
+                if last_assistant_token_is_eos and n_matches == candidate_length:
+                    n_matches -= 1
+                n_matches = min(n_matches, max_matches)
+
+            # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
             # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
             # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
             # is no match.
-            # 5.1. Ensure we don't generate beyond max_len or an EOS token
-            if last_assistant_token_is_eos and n_matches == candidate_length:
-                n_matches -= 1
-            n_matches = min(n_matches, max_len - cur_len - 1)
-
-            # 5.2. Get the valid continuation, after the matching tokens
+            # 4.1. Get the valid continuation, after the matching tokens
             valid_tokens = selected_tokens[:, : n_matches + 1]
             input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
             if streamer is not None:
                 streamer.put(valid_tokens.cpu())
             new_cur_len = input_ids.shape[-1]

-            # 5.3. Discard past key values relative to unused assistant tokens
+            # 4.2. Discard past key values relative to unused assistant tokens
             new_cache_size = new_cur_len - 1
             outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)

-            # 6. Update the candidate generation strategy if needed
+            # 5. Update the candidate generation strategy if needed
             candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)

             if synced_gpus and this_peer_finished:
@@ -4755,6 +4772,62 @@ def assisted_decoding(
             return input_ids


+def _speculative_sampling(
+    candidate_input_ids,
+    candidate_logits,
+    candidate_length,
+    new_logits,
+    last_assistant_token_is_eos,
+    max_matches,
+):
+    """
+    Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
+    the next selected token, as well as the number of candidate matches.
+
+    NOTE: Unless otherwise stated, the variable names match those in the paper.
+    """
+    # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
+    # selected by the assistant, respectively.
+    q = candidate_logits.softmax(dim=-1)
+    q_i = q[
+        :,
+        torch.range(0, candidate_length - 1, dtype=torch.int),
+        candidate_input_ids[:, -candidate_length:],
+    ].squeeze(0, 1)
+    p = new_logits.softmax(dim=-1)
+    p_i = p[
+        :,
+        torch.range(0, candidate_length - 1, dtype=torch.int),
+        candidate_input_ids[:, -candidate_length:],
+    ].squeeze(0, 1)
+    probability_ratio = p_i / q_i
+
+    # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
+    # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
+    # (= keep with p = probability_ratio). Keep all the tokens until the first rejection.
+    r_i = torch.rand_like(probability_ratio)
+    is_accepted = r_i <= probability_ratio
+    n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1
+
+    # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
+    if last_assistant_token_is_eos and n_matches == candidate_length:
+        n_matches -= 1
+    n_matches = min(n_matches, max_matches)
+
+    # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
+    gamma = candidate_logits.shape[1]
+    p_n_plus_1 = p[:, n_matches, :]
+    if n_matches < gamma:
+        q_n_plus_1 = q[:, n_matches, :]
+        p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
+        p_prime.div_(p_prime.sum())
+    else:
+        p_prime = p_n_plus_1
+    t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
+
+    return t, n_matches
+
+
 def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
     """
     Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple