Fix candidate generation more than max_length
ofirzaf committed Jan 16, 2024
1 parent e5e0625 commit e979262
Showing 2 changed files with 10 additions and 6 deletions. Together, the changes stop assisted generation from drafting candidates past generation_config.max_length: the assistant's max_new_tokens is now clamped (and the assistant call skipped when the budget is zero), and the token-matching step in assisted_decoding handles the case where no candidates were produced.
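For context, the fixed code paths run during assisted generation (speculative decoding), where a small draft model proposes candidate tokens that the main model then verifies. A minimal sketch of the setting in which the bug surfaced, with placeholder model names (any pair sharing a tokenizer works; the key detail is a max_length the draft model could previously overshoot):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")
assistant = AutoModelForCausalLM.from_pretrained("gpt2")  # small draft model

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Before this commit the assistant could draft candidates past max_length;
# after it, the assistant's max_new_tokens is clamped and the assistant is
# skipped entirely once the remaining token budget reaches zero.
outputs = model.generate(**inputs, assistant_model=assistant, max_length=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))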
src/transformers/generation/candidate_generator.py (8 changes: 5 additions & 3 deletions)
@@ -173,10 +173,12 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
 
         # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
         # (which implicitly contains the number of accepted candidates from the previous round)
+        new_cur_len = input_ids.shape[-1]
+        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
+        if max_new_tokens == 0:
+            return input_ids, None
         has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
         if has_past_key_values:
-            new_cur_len = input_ids.shape[-1]
-
             new_cache_size = new_cur_len - 1
             self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
                 self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1

@@ -190,7 +192,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         # 2. Forecast next N tokens using the assistant model.
         assistant_generation_kwargs = {
             self.input_ids_key: input_ids,
-            "max_new_tokens": int(self.num_assistant_tokens),
+            "max_new_tokens": max_new_tokens,
             "generation_config": self.generation_config,
             "logits_processor": self.logits_processor,
         }
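The core of this first change is the clamp on the assistant's token budget: the draft model may propose at most num_assistant_tokens candidates, but never so many that the sequence would reach max_length, and the extra -1 leaves room for the one token the main model itself appends during verification. A small self-contained rerun of that arithmetic with hypothetical values (the real code reads them from the candidate generator and its generation config):

num_assistant_tokens = 5  # assistant's usual speculation window
max_length = 20           # generation_config.max_length

for new_cur_len in (10, 16, 19):
    max_new_tokens = min(int(num_assistant_tokens), max_length - new_cur_len - 1)
    if max_new_tokens == 0:
        # Mirrors `return input_ids, None` above: no candidates, skip the assistant.
        print(f"cur_len={new_cur_len}: budget exhausted, assistant skipped")
    else:
        print(f"cur_len={new_cur_len}: assistant may propose {max_new_tokens} token(s)")

# cur_len=10 -> 5 (full window), cur_len=16 -> 3 (clamped),
# cur_len=19 -> 0 (early return; previously this overflowed max_length)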
src/transformers/generation/utils.py (8 changes: 5 additions & 3 deletions)
@@ -4658,9 +4658,11 @@ def assisted_decoding(
             else:
                 selected_tokens = new_logits.argmax(dim=-1)
 
-            candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
-            n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
-
+            if candidate_length > 0:
+                candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
+                n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
+            else:
+                n_matches = 0
             # Ensure we don't generate beyond max_len or an EOS token
             if last_assistant_token_is_eos and n_matches == candidate_length:
                 n_matches -= 1
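This second change guards the prefix-matching step. The cumsum expression counts how many leading candidate tokens agree with the main model's own picks, and it only makes sense when there are candidates at all: Python treats -0 as 0, so candidate_input_ids[:, -candidate_length:] with candidate_length == 0 returns every token rather than an empty slice, which is why the else branch sets n_matches to 0 explicitly. A standalone check of the matching rule on toy tensors:

import torch

# Main model's picks for each candidate position, plus one extra "bonus" token.
selected_tokens = torch.tensor([[5, 7, 9, 2]])
# Draft candidates: the first two agree with the main model, the third does not.
candidate_new_tokens = torch.tensor([[5, 7, 4]])

# A position counts while the running number of mismatches is still zero,
# i.e. this measures the length of the longest matching prefix.
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
print(n_matches.item())  # 2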
