Optimize HF text generation #4814

Merged · 1 commit · Dec 5, 2023
modules/text_generation.py: 23 additions & 18 deletions
@@ -93,9 +93,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
                 last_update = time.time()
                 yield reply

-            # Limit updates to 24 per second to not stress low latency networks
+            # Limit updates to 24 or 5 per second to avoid lag
             else:
-                if cur_time - last_update > 0.041666666666666664:
+                min_update_interval = 0.2 if (shared.args.listen or shared.args.share) else 0.0417
+                if cur_time - last_update > min_update_interval:
                     last_update = cur_time
                     yield reply

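To illustrate the new throttle: with `--listen` or `--share` the UI receives at most 5 updates per second (0.2 s interval), otherwise about 24 per second (0.0417 s ≈ 1/24 s). Below is a minimal standalone sketch of that logic, assuming each streamed value is the cumulative reply so far; the helper name and its arguments are illustrative, not part of the PR.

```python
import time

def throttled_updates(replies, remote=False):
    # Hypothetical helper mirroring the throttle above: remote UIs
    # (--listen/--share) get at most 5 updates/s (0.2 s interval),
    # local ones at most ~24 updates/s (0.0417 s ≈ 1/24 s).
    min_update_interval = 0.2 if remote else 0.0417
    last_update = 0.0
    reply = ''
    for reply in replies:
        cur_time = time.time()
        if cur_time - last_update > min_update_interval:
            last_update = cur_time
            yield reply

    # Each element is assumed to be the full reply so far, so a skipped
    # intermediate update loses nothing; emit the final state at the end
    # so the last tokens are never dropped.
    yield reply
```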
@@ -218,20 +219,6 @@ def fix_galactica(s):
     return s


-def get_reply_from_output_ids(output_ids, input_ids, original_question, state, is_chat=False):
-    if shared.is_seq2seq:
-        reply = decode(output_ids, state['skip_special_tokens'])
-    else:
-        new_tokens = len(output_ids) - len(input_ids[0])
-        reply = decode(output_ids[-new_tokens:], state['skip_special_tokens'])
-        # Prevent LlamaTokenizer from skipping a space
-        if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > 0:
-            if shared.tokenizer.convert_ids_to_tokens(int(output_ids[-new_tokens])).startswith('▁'):
-                reply = ' ' + reply
-
-    return reply
-
-
 def set_manual_seed(seed):
     seed = int(seed)
     if seed == -1:

@@ -242,6 +229,7 @@ def set_manual_seed(seed):
         torch.cuda.manual_seed_all(seed)
     elif is_torch_xpu_available():
         torch.xpu.manual_seed_all(seed)
+
     return seed

@@ -274,6 +262,19 @@ def apply_stopping_strings(reply, all_stop_strings):
     return reply, stop_found


+def get_reply_from_output_ids(output_ids, state, starting_from=0):
+    if shared.is_seq2seq:
+        reply = decode(output_ids, state['skip_special_tokens'])
+    else:
+        reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
+        # Prevent LlamaTokenizer from skipping a space
+        if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > 0:
+            if shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁'):
+                reply = ' ' + reply
+
+    return reply
+
+
 def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     generate_params = {}
     for k in ['max_new_tokens', 'do_sample', 'temperature', 'temperature_last', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:

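The helper now takes a `starting_from` offset instead of `input_ids` and `original_question`: rather than recomputing `new_tokens = len(output_ids) - len(input_ids[0])` on every call, the caller passes the slice start directly. A rough sketch of the equivalence, using toy Python lists instead of real tensors (all ids are made up):

```python
# Toy ids: a 3-token prompt followed by 2 generated tokens (values are made up).
input_ids = [[101, 2009, 2003]]
output_ids = [101, 2009, 2003, 2307, 999]

# Old helper: recompute the number of new tokens on every call.
new_tokens = len(output_ids) - len(input_ids[0])
old_slice = output_ids[-new_tokens:]

# New helper: the caller passes the offset once via starting_from.
starting_from = len(input_ids[0])
new_slice = output_ids[starting_from:]

assert old_slice == new_slice == [2307, 999]
```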
@@ -341,7 +342,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
             if cuda:
                 output = output.cuda()

-            yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
+            yield get_reply_from_output_ids(output, state, starting_from=len(input_ids[0]))

        # Stream the reply 1 token at a time.
        # This is based on the trick of using 'stopping_criteria' to create an iterator.

@@ -357,11 +358,15 @@ def generate_with_streaming(**kwargs):
                 return Iteratorize(generate_with_callback, [], kwargs, callback=None)

             with generate_with_streaming(**generate_params) as generator:
+                cumulative_reply = ''
+                starting_from = len(input_ids[0])
                 for output in generator:
                     if output[-1] in eos_token_ids:
                         break

-                    yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
+                    cumulative_reply += get_reply_from_output_ids(output, state, starting_from=starting_from)
+                    starting_from = len(output)
+                    yield cumulative_reply

     except Exception:
         traceback.print_exc()
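This streaming loop is where the optimization lands: the old code re-decoded every token generated so far on each step, while the new code decodes only the tokens added since the previous step and appends them to `cumulative_reply`, so total decode work grows roughly linearly with output length instead of quadratically. A minimal self-contained sketch of that pattern, with a toy `decode` standing in for the real tokenizer:

```python
def toy_decode(token_ids):
    # Stand-in for the real tokenizer decode; the token ids here are just
    # character codes, purely for illustration.
    return ''.join(chr(t) for t in token_ids)

def stream_incrementally(outputs, prompt_length):
    # Decode only the tokens added since the last step and yield the running
    # text, mirroring the cumulative_reply / starting_from pattern above.
    cumulative_reply = ''
    starting_from = prompt_length
    for output in outputs:
        cumulative_reply += toy_decode(output[starting_from:])
        starting_from = len(output)
        yield cumulative_reply

# Example: a 2-token "prompt" followed by one new token per step.
steps = [
    [72, 105, 33],
    [72, 105, 33, 32],
    [72, 105, 33, 32, 58],
]
for text in stream_incrementally(steps, prompt_length=2):
    print(repr(text))  # '!', then '! ', then '! :'
```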