Add temperature_last parameter (oobabooga#4472)
oobabooga authored Nov 4, 2023
1 parent 1ab8700 commit aa5d671
Showing 7 changed files with 38 additions and 7 deletions.
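
Note: sampler order matters because truncation samplers such as top_p act on whatever probability distribution they receive. A minimal sketch of the effect in plain Python (not code from this commit; the four-token logits are made up):

import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_p_keep(probs, p=0.9):
    # indices kept by nucleus (top-p) truncation
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= p:
            break
    return kept

logits = [5.0, 3.0, 2.0, 1.0]
temp = 2.0  # > 1 flattens the distribution

# Temperature first (the previous behavior): top_p sees the flattened distribution.
print(top_p_keep(softmax([l / temp for l in logits])))  # [0, 1, 2]
# Temperature last (this commit's new option): top_p sees the raw distribution.
print(top_p_keep(softmax(logits)))                      # [0, 1]

With temperature applied last, a high temperature adds variety among the tokens that survive truncation without letting extra low-probability tokens into the pool.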
1 change: 1 addition & 0 deletions extensions/api/util.py

@@ -25,6 +25,7 @@ def build_parameters(body, chat=False):
         'max_tokens_second': int(body.get('max_tokens_second', 0)),
         'do_sample': bool(body.get('do_sample', True)),
         'temperature': float(body.get('temperature', 0.5)),
+        'temperature_last': bool(body.get('temperature_last', False)),
         'top_p': float(body.get('top_p', 1)),
         'min_p': float(body.get('min_p', 0)),
         'typical_p': float(body.get('typical_p', body.get('typical', 1))),
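
For API callers, the new key can be sent alongside the existing sampling parameters. A hypothetical request sketch (the endpoint, port, and response shape are assumptions based on the api extension's usual blocking endpoint, not part of this commit):

import json
import urllib.request

body = {
    'prompt': 'Once upon a time',
    'max_new_tokens': 50,
    'temperature': 1.2,
    'temperature_last': True,  # new in this commit; defaults to False when omitted
    'top_p': 0.9,
}
request = urllib.request.Request(
    'http://localhost:5000/api/v1/generate',
    data=json.dumps(body).encode(),
    headers={'Content-Type': 'application/json'},
)
with urllib.request.urlopen(request) as response:
    print(json.load(response)['results'][0]['text'])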
7 changes: 7 additions & 0 deletions modules/loaders.py

@@ -148,6 +148,7 @@
 loaders_samplers = {
     'Transformers': {
         'temperature',
+        'temperature_last',
         'top_p',
         'min_p',
         'top_k',
@@ -184,6 +185,7 @@
     },
     'ExLlama_HF': {
         'temperature',
+        'temperature_last',
         'top_p',
         'min_p',
         'top_k',
@@ -245,6 +247,7 @@
     },
     'ExLlamav2_HF': {
         'temperature',
+        'temperature_last',
         'top_p',
         'min_p',
         'top_k',
@@ -277,6 +280,7 @@
     },
     'AutoGPTQ': {
         'temperature',
+        'temperature_last',
         'top_p',
         'min_p',
         'top_k',
@@ -313,6 +317,7 @@
     },
     'GPTQ-for-LLaMa': {
         'temperature',
+        'temperature_last',
         'top_p',
         'min_p',
         'top_k',
@@ -365,6 +370,7 @@
     },
     'llamacpp_HF': {
         'temperature',
+        'temperature_last',
         'top_p',
         'min_p',
         'top_k',
@@ -404,6 +410,7 @@
     },
     'AutoAWQ': {
         'temperature',
+        'temperature_last',
         'top_p',
         'min_p',
         'top_k',
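
Each loader advertises the samplers it supports in loaders_samplers, which is why the same line is added once per HF-based loader. A rough sketch of the gating idea (the helper below is illustrative, not the module's actual API):

def loader_supports(loader_name, sampler_name):
    # True if the given loader lists the sampler, e.g.
    # loader_supports('ExLlamav2_HF', 'temperature_last') -> True
    return sampler_name in loaders_samplers.get(loader_name, set())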
1 change: 1 addition & 0 deletions modules/presets.py

@@ -8,6 +8,7 @@ def default_preset():
     return {
         'do_sample': True,
         'temperature': 1,
+        'temperature_last': False,
         'top_p': 1,
         'min_p': 0,
         'top_k': 0,
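
Since default_preset() now supplies False for the new key, presets saved before this change keep their old behavior. A sketch of the fallback (the merge shown is illustrative; preset loading itself lives elsewhere):

params = default_preset()
params.update({'temperature': 0.7, 'temperature_last': True})  # hypothetical preset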
32 changes: 26 additions & 6 deletions modules/sampler_hijack.py

@@ -12,6 +12,7 @@

 global_scores = None

+
 class MinPLogitsWarper(LogitsWarper):
     def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         if min_p < 0 or min_p > 1.0:
@@ -41,6 +42,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
         scores = scores.masked_fill(indices_to_remove, self.filter_value)
         return scores

+
 class TailFreeLogitsWarper(LogitsWarper):
     def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         tfs = float(tfs)
@@ -214,19 +216,36 @@ def get_logits_warper_patch(self, generation_config):
             if not isinstance(warper, TemperatureLogitsWarper):
                 warpers.remove(warper)
     else:
-        if generation_config.tfs is not None and 0.0 <= generation_config.tfs <= 1.0:
+        if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
             warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
-        if generation_config.top_a is not None and 0.0 <= generation_config.top_a <= 1.0:
+        if generation_config.top_a is not None and 0.0 < generation_config.top_a <= 1.0:
             warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
-        if generation_config.min_p is not None and 0.0 <= generation_config.min_p <= 1.0:
+        if generation_config.min_p is not None and 0.0 < generation_config.min_p <= 1.0:
             warpers_to_add.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))

-    if warpers and isinstance(warpers[-1], LogitNormalization):
-        warpers = warpers[:-1] + warpers_to_add + [warpers[-1]]
+    if len(warpers) > 0 and isinstance(warpers[-1], LogitNormalization):
+        normalize = warpers.pop(-1)
     else:
-        warpers += warpers_to_add
+        normalize = None
+
+    warpers += warpers_to_add
+    if generation_config.temperature_last:
+        temperature_idx = None
+        for i in range(len(warpers)):
+            if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
+                temperature_idx = i
+                break
+
+        if temperature_idx is not None:
+            warpers = warpers[:temperature_idx] + warpers[temperature_idx + 1:] + [warpers[temperature_idx]]
+            warpers = LogitsProcessorList(warpers)

+    if normalize is not None:
+        warpers.append(normalize)
+
     warpers.append(SpyLogitsWarper())
+    # for i in range(len(warpers)):
+    #     print(warpers[i].__class__.__name__)
     return warpers


@@ -261,6 +280,7 @@ def generation_config_init_patch(self, **kwargs):
     self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
     self.presence_penalty = kwargs.pop("presence_penalty", 0)
     self.frequency_penalty = kwargs.pop("frequency_penalty", 0)
+    self.temperature_last = kwargs.pop("temperature_last", False)


 def hijack_samplers():
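
The reordering in get_logits_warper_patch works by slicing the temperature warper out of the chain and re-appending it, while LogitNormalization is popped up front and restored afterwards so normalization still runs after all the samplers. A standalone sketch of the slicing, with plain strings standing in for warper objects:

warpers = ['TopKLogitsWarper', 'TemperatureLogitsWarper', 'TopPLogitsWarper', 'MinPLogitsWarper']

temperature_idx = None
for i in range(len(warpers)):
    if warpers[i] == 'TemperatureLogitsWarper':
        temperature_idx = i
        break

if temperature_idx is not None:
    warpers = warpers[:temperature_idx] + warpers[temperature_idx + 1:] + [warpers[temperature_idx]]

print(warpers)
# ['TopKLogitsWarper', 'TopPLogitsWarper', 'MinPLogitsWarper', 'TemperatureLogitsWarper']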
2 changes: 1 addition & 1 deletion modules/text_generation.py

@@ -274,7 +274,7 @@ def apply_stopping_strings(reply, all_stop_strings):

 def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     generate_params = {}
-    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
+    for k in ['max_new_tokens', 'do_sample', 'temperature', 'temperature_last', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
         generate_params[k] = state[k]

     if state['negative_prompt'] != '':
1 change: 1 addition & 0 deletions modules/ui.py

@@ -104,6 +104,7 @@ def list_interface_input_elements():
         'max_tokens_second',
         'seed',
         'temperature',
+        'temperature_last',
         'top_p',
         'min_p',
         'top_k',
1 change: 1 addition & 0 deletions modules/ui_parameters.py

@@ -48,6 +48,7 @@ def create_ui(default_preset):
             shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
             shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
             shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
+            shared.gradio['temperature_last'] = gr.Checkbox(value=generate_params['temperature_last'], label='temperature_last', info='Makes temperature the last sampler instead of the first.')
             shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
             shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
         with gr.Accordion('Other parameters', open=False):
