diff --git a/exllamav2/generator/streaming.py b/exllamav2/generator/streaming.py
index b5433bc7..c7fa9e0f 100644
--- a/exllamav2/generator/streaming.py
+++ b/exllamav2/generator/streaming.py
@@ -35,8 +35,8 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
     no_probs: torch.Tensor = None
     no_logits: torch.Tensor = None
 
-    first_token = False
-    heal_next_token = False
+    heal_prefix_token = None
+    heal_old_tail_len = None
 
     draft_model: ExLlamaV2 or None = None
     draft_cache: ExLlamaV2Cache or None = None
@@ -53,10 +53,12 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
     return_probabilities_k: int = 1                 # Number of probabilities to return per token
     return_logits: bool = False                     # Return raw logits prior to softmax, per token
 
-    active_loras = []
+    active_loras = None
     position_offsets = None
     input_mask = None
 
+    queued_logits = None
+
     def __init__(self, model, cache, tokenizer, draft_model = None, draft_cache = None, num_speculative_tokens = 5):
 
         super().__init__(model, cache, tokenizer)
@@ -119,7 +121,23 @@ def begin_stream(self, input_ids: torch.Tensor, gen_settings: ExLlamaV2Sampler.S
         self.settings = gen_settings
         self._gen_begin_reuse(input_ids, gen_settings)
 
-        self.heal_next_token = (token_healing and self.sequence_ids.shape[-1] >= 2)
+        self.queued_logits = []
+
+        # Initialize token healing
+        if token_healing and self.sequence_ids.shape[-1] >= max(2, self.tail_decode_tokens + 1):
+
+            # Pop the last token, remembering tail len for first stream decode
+
+            self.heal_old_tail_len = len(self.tokenizer.decode(self.sequence_ids[:, -(self.tail_decode_tokens + 1):])[0])
+            self.heal_prefix_token = self.sequence_ids[:, -1:]
+            self.sequence_ids = self.sequence_ids[:, :-1]
+            self.cache.current_seq_len -= 1
+
+            # Start filters
+
+            self.settings.begin_filters(self.tokenizer.get_id_to_piece_list()[self.heal_prefix_token])
+        else:
+            self.settings.begin_filters()
 
 
     def stream(self) -> Union[Tuple[str, bool, torch.Tensor],
@@ -145,53 +163,16 @@ def stream(self) -> Union[Tuple[str, bool, torch.Tensor],
 
     def _stream(self) -> (str, bool, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
 
-        # Token healing
-
-        if self.heal_next_token:
-
-            # Pop the last token
-
-            old_tail = self.tokenizer.decode(self.sequence_ids[:, -self.tail_decode_tokens:])[0]
-            last_token = self.sequence_ids[:, -1:]
-            self.sequence_ids = self.sequence_ids[:, :-1]
-            self.cache.current_seq_len -= 1
-
-            # Start filters
-
-            if self.first_token:
-
-                self.settings.begin_filters(self.tokenizer.get_id_to_piece_list()[last_token])
-                self.first_token = False
-
-            # Regenerate the last token again, with prefix
-
-            healed_token, _, _, eos, logits = self._gen_single_token(self.settings, prefix_token = last_token)
-            new_tail = self.tokenizer.decode(self.sequence_ids[:, -self.tail_decode_tokens:])[0]
-            self.held_text += new_tail[len(old_tail):]
-
-            self.heal_next_token = False
-
-            # In case we only needed the healed token
-
-            if eos: return self.held_text, True, self.no_tokens, self.no_probs, self.no_ptokens, self.no_logits
-
-        # Start filters when not healing
-
+        if self.heal_old_tail_len is not None:
+            old_tail_len = self.heal_old_tail_len
+            self.heal_old_tail_len = None
         else:
-
-            if self.first_token:
-
-                self.settings.begin_filters()
-                self.first_token = False
-
-
-        # Decode the current tail end of the sequence
-
-        old_tail = self.tokenizer.decode(self.sequence_ids[:1, -self.tail_decode_tokens:])[0]
+            old_tail_len = len(self.tokenizer.decode(self.sequence_ids[:1, -self.tail_decode_tokens:])[0])
 
         # Generate a single token and append to the sequence
 
-        next_token, next_ptokens, next_prob, eos, next_logits = self._gen_single_token(self.settings)
+        next_token, next_ptokens, next_prob, eos, next_logits = self._gen_single_token(self.settings, prefix_token = self.heal_prefix_token)
+        self.heal_prefix_token = None
 
         # End immediately if it was a stop token
 
@@ -201,7 +182,7 @@ def _stream(self) -> (str, bool, torch.Tensor, torch.Tensor, torch.Tensor, torch
         # Decode the tail end of the sequence with the added token to get (actual) characters added
 
         new_tail = self.tokenizer.decode(self.sequence_ids[:1, -(self.tail_decode_tokens + 1):])[0]
-        new_text = new_tail[len(old_tail):]
+        new_text = new_tail[old_tail_len:]
 
         next_token, new_text = self._catch_utf8(next_token, new_text)
 
@@ -321,8 +302,6 @@ def _gen_begin(self, in_tokens, gen_settings):
         self.future_logits = None
         self.future_tokens = None
 
-        self.first_token = True
-
 
     def _gen_begin_reuse(self, in_tokens, gen_settings):
 
@@ -367,11 +346,21 @@ def _gen_feed_tokens(self, in_tokens, gen_settings):
         self.future_tokens = None
 
+    def append_logits(self, logits):
+
+        assert self.draft_model is None
+        assert logits.shape[0] == self.sequence_ids.shape[0]
+
+        self.queued_logits.append(logits)
+
 
     def _gen_single_token(self, gen_settings, prefix_token = None):
 
         if self.draft_model is None:
 
-            logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, loras = self.active_loras, input_mask = self.input_mask, position_offsets = self.position_offsets).float().cpu()
+            if self.queued_logits:
+                logits = self.queued_logits.pop()
+            else:
+                logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, loras = self.active_loras, input_mask = self.input_mask, position_offsets = self.position_offsets).float().cpu()
             token, ptokens, prob, eos = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids[:1, :], random.random(), self.tokenizer, prefix_token, self.return_probabilities_k)
 
         else:
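
For reviewers, here is a minimal usage sketch of the changed surface: token healing is now set up entirely inside `begin_stream()`, and the new `append_logits()` lets a caller queue logits computed outside the generator, which the next single-token step pops instead of running `model.forward()`. The model path, sampler values, token budget and the 3-tuple unpacking of `stream()` below are illustrative assumptions, not part of this patch.

```python
import torch
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler

# Hypothetical model directory; replace with a real path
config = ExLlamaV2Config()
config.model_dir = "/path/to/model"
config.prepare()

model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)

generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.8

input_ids = tokenizer.encode("The quick brown fox")

# With this patch, token healing pops the last prompt token and starts the filters
# here, so the first stream() call already regenerates it with a prefix constraint.
generator.begin_stream(input_ids, settings, token_healing = True)

# Optionally queue logits computed elsewhere (e.g. by an external batched forward);
# the next single-token step pops them instead of calling model.forward(). The batch
# dimension must match sequence_ids, and no draft model may be attached.
# generator.append_logits(external_logits)

max_new_tokens = 200
generated = 0
while True:
    chunk, eos, _ = generator.stream()   # default 3-tuple; more fields are returned
    print(chunk, end = "")               # when return_probabilities / return_logits are set
    generated += 1
    if eos or generated >= max_new_tokens: break
```

The queued-logits path is what allows an outer scheduler to batch forward passes across several streaming generators while each generator still handles its own sampling, detokenization and stop conditions.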