
Commit

Merge pull request #4849 from oobabooga/dev
Merge dev branch
oobabooga authored Dec 8, 2023
2 parents 2694ef4 + 00aedf9 commit 884871c
Showing 14 changed files with 176 additions and 29 deletions.
17 changes: 15 additions & 2 deletions download-model.py
@@ -127,10 +127,23 @@ def get_download_links_from_huggingface(self, model, branch, text_only=False, sp
                 if classifications[i] in ['pytorch', 'pt']:
                     links.pop(i)

+        # For GGUF, try to download only the Q4_K_M if no specific file is specified.
+        # If not present, exclude all GGUFs, as that's likely a repository with both
+        # GGUF and fp16 files.
         if has_gguf and specific_file is None:
+            has_q4km = False
             for i in range(len(classifications) - 1, -1, -1):
-                if 'q4_k_m' not in links[i].lower():
-                    links.pop(i)
+                if 'q4_k_m' in links[i].lower():
+                    has_q4km = True
+
+            if has_q4km:
+                for i in range(len(classifications) - 1, -1, -1):
+                    if 'q4_k_m' not in links[i].lower():
+                        links.pop(i)
+            else:
+                for i in range(len(classifications) - 1, -1, -1):
+                    if links[i].lower().endswith('.gguf'):
+                        links.pop(i)

         is_llamacpp = has_gguf and specific_file is not None
         return links, sha256, is_lora, is_llamacpp
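
Context for the download-model.py change above: when no specific file is requested, the downloader now keeps only the Q4_K_M quantization if one exists, and otherwise drops every GGUF on the assumption that the repository mixes GGUF and fp16 files. A minimal standalone sketch of that rule (pick_gguf_links is a hypothetical helper, not code from the repository):

# Standalone sketch of the new GGUF selection rule (pick_gguf_links is a
# hypothetical helper, not part of download-model.py).
def pick_gguf_links(links, has_gguf=True, specific_file=None):
    if not (has_gguf and specific_file is None):
        return links

    # If any Q4_K_M file exists, keep only the Q4_K_M links.
    if any('q4_k_m' in link.lower() for link in links):
        return [link for link in links if 'q4_k_m' in link.lower()]

    # Otherwise drop every GGUF; the repo most likely mixes GGUF and fp16 files.
    return [link for link in links if not link.lower().endswith('.gguf')]


print(pick_gguf_links(['llama.Q4_K_M.gguf', 'llama.Q8_0.gguf']))
# -> ['llama.Q4_K_M.gguf']
print(pick_gguf_links(['llama.Q8_0.gguf', 'model.safetensors', 'config.json']))
# -> ['model.safetensors', 'config.json']
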
2 changes: 1 addition & 1 deletion extensions/openai/completions.py
@@ -236,7 +236,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -

     max_tokens = generate_params['max_new_tokens']
     if max_tokens in [None, 0]:
-        generate_params['max_new_tokens'] = 200
+        generate_params['max_new_tokens'] = 512
         generate_params['auto_max_new_tokens'] = True

     requested_model = generate_params.pop('model')
2 changes: 1 addition & 1 deletion extensions/openai/typing.py
@@ -10,7 +10,7 @@ class GenerationOptions(BaseModel):
     min_p: float = 0
     top_k: int = 0
     repetition_penalty: float = 1
-    repetition_penalty_range: int = 0
+    repetition_penalty_range: int = 1024
     typical_p: float = 1
     tfs: float = 1
     top_a: float = 0
19 changes: 18 additions & 1 deletion modules/exllama.py
@@ -165,10 +165,19 @@ def generate_with_streaming(self, prompt, state):
                 if has_leading_space:
                     decoded_text = ' ' + decoded_text

-                yield decoded_text
+                # Check the partial unicode character
+                if chr(0xfffd) in decoded_text:
+                    is_last = i == max_new_tokens - 1
+                    is_stopping = token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything
+                    # If we are not at the end of the generation, we skip this token
+                    if not (is_last or is_stopping):
+                        continue
+
                 if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
                     break

+                yield decoded_text
+
         # Case 2: CFG
         # Copied from https://github.com/turboderp/exllama/blob/master/example_cfg.py
         else:
@@ -205,6 +214,14 @@ def generate_with_streaming(self, prompt, state):
                 if has_leading_space:
                     decoded_text = ' ' + decoded_text

+                # Check the partial unicode character
+                if chr(0xfffd) in decoded_text:
+                    is_last = i == max_new_tokens - 1
+                    is_stopping = token.item() == self.tokenizer.eos_token_id or shared.stop_everything
+                    # If we are not at the end of the generation, we skip this token
+                    if not (is_last or is_stopping):
+                        continue
+
                 yield decoded_text
                 if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
                     break
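
Context for the streaming change above (and the identical one in exllamav2.py below): chr(0xfffd) is the Unicode replacement character U+FFFD, which a decoder emits while a multi-byte character is still incomplete. A small self-contained illustration, independent of the webui code:

# Decoding a multi-byte UTF-8 character one byte at a time produces U+FFFD
# until the full sequence has arrived; the streaming code above skips those
# intermediate chunks instead of yielding them.
emoji = "🤖".encode("utf-8")          # 4 bytes: b'\xf0\x9f\xa4\x96'

for end in range(1, len(emoji) + 1):
    chunk = emoji[:end].decode("utf-8", errors="replace")
    if chr(0xfffd) in chunk:
        print(f"{end} byte(s): incomplete, would be skipped")
    else:
        print(f"{end} byte(s): {chunk!r} is safe to yield")
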
10 changes: 9 additions & 1 deletion modules/exllamav2.py
@@ -138,11 +138,19 @@ def generate_with_streaming(self, prompt, state):
             if has_leading_space:
                 decoded_text = ' ' + decoded_text

-            yield decoded_text
+            # Check the partial unicode character
+            if chr(0xfffd) in decoded_text:
+                is_last = i == max_new_tokens - 1
+                is_stopping = token.item() == self.tokenizer.eos_token_id or shared.stop_everything
+                # If we are not at the end of the generation, we skip this token
+                if not (is_last or is_stopping):
+                    continue

             if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
                 break

+            yield decoded_text
+
     def generate(self, prompt, state):
         output = ''
         for output in self.generate_with_streaming(prompt, state):
42 changes: 42 additions & 0 deletions modules/loaders.py
@@ -143,6 +143,11 @@
         'no_mmap',
         'mlock'
     ],
+    'QuIP#': [
+        'trust_remote_code',
+        'no_use_fast',
+        'no_flash_attn',
+    ]
 })

 loaders_samplers = {
@@ -453,6 +458,43 @@
         'skip_special_tokens',
         'auto_max_new_tokens',
     },
+    'QuIP#': {
+        'temperature',
+        'temperature_last',
+        'top_p',
+        'min_p',
+        'top_k',
+        'typical_p',
+        'epsilon_cutoff',
+        'eta_cutoff',
+        'tfs',
+        'top_a',
+        'repetition_penalty',
+        'presence_penalty',
+        'frequency_penalty',
+        'repetition_penalty_range',
+        'encoder_repetition_penalty',
+        'no_repeat_ngram_size',
+        'min_length',
+        'seed',
+        'do_sample',
+        'penalty_alpha',
+        'num_beams',
+        'length_penalty',
+        'early_stopping',
+        'mirostat_mode',
+        'mirostat_tau',
+        'mirostat_eta',
+        'grammar_file_row',
+        'grammar_string',
+        'guidance_scale',
+        'negative_prompt',
+        'ban_eos_token',
+        'custom_token_bans',
+        'add_bos_token',
+        'skip_special_tokens',
+        'auto_max_new_tokens',
+    },
 }

 loaders_model_types = {
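
Context for the loaders.py change above: loaders_samplers maps each loader to the sampling parameters it supports. A hedged sketch of one plausible way such a mapping can be consumed (filter_params_for_loader and the abbreviated set are illustrative only, not webui code):

# Hedged sketch (not webui code): filter a full parameter dict down to the
# keys that the active loader's sampler whitelist actually supports.
loaders_samplers = {
    'QuIP#': {'temperature', 'top_p', 'top_k', 'seed'},   # abbreviated for the example
}


def filter_params_for_loader(loader, params):
    allowed = loaders_samplers.get(loader, set())
    return {k: v for k, v in params.items() if k in allowed}


print(filter_params_for_loader('QuIP#', {'temperature': 0.7, 'mirostat_mode': 2, 'seed': -1}))
# -> {'temperature': 0.7, 'seed': -1} with the abbreviated set above
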
34 changes: 34 additions & 0 deletions modules/models.py
@@ -1,4 +1,5 @@
 import gc
+import logging
 import os
 import re
 import time
@@ -23,6 +24,7 @@
 from modules import RoPE, llama_attn_hijack, sampler_hijack
 from modules.logging_colors import logger
 from modules.models_settings import get_model_metadata
+from modules.relative_imports import RelativeImport

 transformers.logging.set_verbosity_error()

@@ -69,6 +71,7 @@ def load_model(model_name, loader=None):
         'ExLlamav2_HF': ExLlamav2_HF_loader,
         'ctransformers': ctransformers_loader,
         'AutoAWQ': AutoAWQ_loader,
+        'QuIP#': QuipSharp_loader,
     }

     metadata = get_model_metadata(model_name)
@@ -321,6 +324,37 @@ def AutoAWQ_loader(model_name):
     return model


+def QuipSharp_loader(model_name):
+    try:
+        with RelativeImport("repositories/quip-sharp"):
+            from lib.utils.unsafe_import import model_from_hf_path
+    except:
+        logger.error(
+            "\nQuIP# has not been found. It must be installed manually for now.\n"
+            "For instructions on how to do that, please consult:\n"
+            "https://github.com/oobabooga/text-generation-webui/pull/4803\n"
+        )
+        return None, None
+
+    # This fixes duplicate logging messages after the import above.
+    handlers = logging.getLogger().handlers
+    if len(handlers) > 1:
+        logging.getLogger().removeHandler(handlers[1])
+
+    model_dir = Path(f'{shared.args.model_dir}/{model_name}')
+    if not all((model_dir / file).exists() for file in ['tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model']):
+        logger.error(f"Could not load the model because the tokenizer files could not be found in the model folder. Please download the following files from the original (unquantized) model into {model_dir}: special_tokens_map.json, tokenizer.json, tokenizer.model, tokenizer_config.json.")
+        return None, None
+
+    model, model_str = model_from_hf_path(
+        model_dir,
+        use_cuda_graph=False,
+        use_flash_attn=not shared.args.no_flash_attn
+    )
+
+    return model
+
+
 def GPTQ_loader(model_name):

     # Monkey patch
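
Context for the new QuipSharp_loader above: it guards an optional import, deduplicates logging handlers, and refuses to load unless the expected tokenizer files are present. A hedged, generic sketch of that guard-then-verify pattern (load_optional_backend, some_optional_backend, and backend.load are hypothetical names, not the QuIP# API):

# Hedged sketch (not webui code): import an optional backend lazily and verify
# required tokenizer files before attempting to load anything.
import importlib
from pathlib import Path


def load_optional_backend(model_dir, required=('tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model')):
    try:
        backend = importlib.import_module('some_optional_backend')  # hypothetical module name
    except ImportError:
        print("Backend not installed; see the project docs for manual installation.")
        return None

    model_dir = Path(model_dir)
    missing = [f for f in required if not (model_dir / f).exists()]
    if missing:
        print(f"Missing tokenizer files in {model_dir}: {', '.join(missing)}")
        return None

    return backend.load(model_dir)  # hypothetical API


print(load_optional_backend('models/example-quip-model'))  # falls back to the message above unless the backend exists
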
25 changes: 17 additions & 8 deletions modules/models_settings.py
@@ -33,14 +33,24 @@ def get_model_metadata(model):
             for k in settings[pat]:
                 model_settings[k] = settings[pat][k]

+
+    path = Path(f'{shared.args.model_dir}/{model}/config.json')
+    if path.exists():
+        hf_metadata = json.loads(open(path, 'r').read())
+    else:
+        hf_metadata = None
+
     if 'loader' not in model_settings:
-        loader = infer_loader(model, model_settings)
-        if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
-            loader = 'AutoGPTQ'
+        if hf_metadata is not None and 'quip_params' in hf_metadata:
+            model_settings['loader'] = 'QuIP#'
+        else:
+            loader = infer_loader(model, model_settings)
+            if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
+                loader = 'AutoGPTQ'

-        model_settings['loader'] = loader
+            model_settings['loader'] = loader

-    # Read GGUF metadata
+    # GGUF metadata
     if model_settings['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
         path = Path(f'{shared.args.model_dir}/{model}')
         if path.is_file():
@@ -57,9 +67,8 @@ def get_model_metadata(model):
                 model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']

     else:
-        # Read transformers metadata
-        path = Path(f'{shared.args.model_dir}/{model}/config.json')
-        if path.exists():
+        # Transformers metadata
+        if hf_metadata is not None:
             metadata = json.loads(open(path, 'r').read())
             if 'max_position_embeddings' in metadata:
                 model_settings['truncation_length'] = metadata['max_position_embeddings']
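
Context for the models_settings.py change above: config.json is now read once into hf_metadata, and a 'quip_params' key marks the model as a QuIP# quantization. A simplified sketch of that detection (detect_loader is a hypothetical helper; the real fallback is infer_loader):

# Hedged sketch (not webui code): pick the loader by checking config.json for
# a 'quip_params' key, falling back to a default otherwise.
import json
from pathlib import Path


def detect_loader(model_dir):
    config_path = Path(model_dir) / 'config.json'
    if config_path.exists():
        hf_metadata = json.loads(config_path.read_text())
        if 'quip_params' in hf_metadata:
            return 'QuIP#'

    return 'Transformers'  # simplified fallback; the real code calls infer_loader()


print(detect_loader('models/example-model'))  # 'Transformers' unless config.json advertises quip_params
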
2 changes: 1 addition & 1 deletion modules/presets.py
@@ -18,7 +18,7 @@ def default_preset():
         'repetition_penalty': 1,
         'presence_penalty': 0,
         'frequency_penalty': 0,
-        'repetition_penalty_range': 0,
+        'repetition_penalty_range': 1024,
         'typical_p': 1,
         'tfs': 1,
         'top_a': 0,
4 changes: 3 additions & 1 deletion modules/shared.py
@@ -36,7 +36,7 @@
     'prompt-default': 'QA',
     'prompt-notebook': 'QA',
     'preset': 'simple-1',
-    'max_new_tokens': 200,
+    'max_new_tokens': 512,
     'max_new_tokens_min': 1,
     'max_new_tokens_max': 4096,
     'negative_prompt': '',
@@ -241,6 +241,8 @@ def fix_loader_name(name):
         return 'ctransformers'
     elif name in ['autoawq', 'awq', 'auto-awq']:
         return 'AutoAWQ'
+    elif name in ['quip#', 'quip-sharp', 'quipsharp', 'quip_sharp']:
+        return 'QuIP#'


 def add_extension(name, last=False):
24 changes: 13 additions & 11 deletions modules/text_generation.py
@@ -264,14 +264,10 @@ def apply_stopping_strings(reply, all_stop_strings):


 def get_reply_from_output_ids(output_ids, state, starting_from=0):
-    if shared.is_seq2seq:
-        reply = decode(output_ids, state['skip_special_tokens'])
-    else:
-        reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
-        # Prevent LlamaTokenizer from skipping a space
-        if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > 0:
-            if shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁'):
-                reply = ' ' + reply
+    reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
+    if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > starting_from:
+        if shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁'):
+            reply = ' ' + reply

     return reply

@@ -343,7 +339,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
                 if cuda:
                     output = output.cuda()

-            yield get_reply_from_output_ids(output, state, starting_from=len(input_ids[0]))
+            starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
+            yield get_reply_from_output_ids(output, state, starting_from=starting_from)

         # Stream the reply 1 token at a time.
         # This is based on the trick of using 'stopping_criteria' to create an iterator.
@@ -360,12 +357,17 @@ def generate_with_streaming(**kwargs):

             with generate_with_streaming(**generate_params) as generator:
                 cumulative_reply = ''
-                starting_from = len(input_ids[0])
+                starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
                 for output in generator:
                     if output[-1] in eos_token_ids:
                         break

-                    cumulative_reply += get_reply_from_output_ids(output, state, starting_from=starting_from)
+                    new_content = get_reply_from_output_ids(output, state, starting_from=starting_from)
+                    # check the partial unicode character
+                    if chr(0xfffd) in new_content:
+                        continue
+
+                    cumulative_reply += new_content
                     starting_from = len(output)
                     yield cumulative_reply

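
Context for the text_generation.py change above: the streaming loop now decodes only the newly generated ids, skips chunks that still contain U+FFFD, and advances starting_from only after accepting a chunk. A toy, self-contained sketch of the same pattern (stream_decode and the byte-level decoder are illustrative, not webui code):

# Hedged sketch (not webui code) of the streaming pattern: decode only the new
# ids each step, hold back chunks with the replacement character, and advance
# starting_from only after a chunk has been accepted.
def stream_decode(chunks_of_ids, decode):
    cumulative_reply = ''
    starting_from = 0
    for output in chunks_of_ids:          # each item is the full id list so far
        new_content = decode(output[starting_from:])
        if chr(0xfffd) in new_content:    # partial unicode character: wait for more ids
            continue

        cumulative_reply += new_content
        starting_from = len(output)
        yield cumulative_reply


# Toy decoder: ids are UTF-8 byte values.
decode = lambda ids: bytes(ids).decode('utf-8', errors='replace')
steps = [[72], [72, 105], [72, 105, 0xf0], [72, 105, 0xf0, 0x9f, 0xa4, 0x96]]
print(list(stream_decode(steps, decode)))   # ['H', 'Hi', 'Hi🤖']
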
10 changes: 9 additions & 1 deletion one_click.py
@@ -4,6 +4,7 @@
 import os
 import platform
 import re
+import signal
 import site
 import subprocess
 import sys
@@ -27,6 +28,13 @@
 flags = f"{' '.join([flag for flag in sys.argv[1:] if flag != '--update'])} {CMD_FLAGS}"


+def signal_handler(sig, frame):
+    sys.exit(0)
+
+
+signal.signal(signal.SIGINT, signal_handler)
+
+
 def is_linux():
     return sys.platform.startswith("linux")

@@ -210,7 +218,7 @@ def install_webui():
     elif is_linux() and (choice == "C" or choice == "N"):
         install_pytorch = "python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu"
     elif choice == "D":
-        install_pytorch = "python -m pip install torch==2.0.1a0 torchvision==0.15.2a0 intel_extension_for_pytorch==2.0.110+xpu -f https://developer.intel.com/ipex-whl-stable-xpu"
+        install_pytorch = "python -m pip install torch==2.0.1a0 torchvision==0.15.2a0 intel_extension_for_pytorch==2.0.110+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"

     # Install Git and then Pytorch
     run_cmd(f"{install_git} && {install_pytorch} && python -m pip install py-cpuinfo==9.0.0", assert_success=True, environment=True)
12 changes: 12 additions & 0 deletions server.py
@@ -21,6 +21,7 @@

 import json
 import os
+import signal
 import sys
 import time
 from functools import partial
@@ -55,6 +56,17 @@
 from modules.utils import gradio


+def signal_handler(sig, frame):
+    logger.info("Received Ctrl+C. Shutting down Text generation web UI gracefully.")
+    if 'interface' in shared.gradio:
+        shared.gradio['interface'].close()
+
+    sys.exit(0)
+
+
+signal.signal(signal.SIGINT, signal_handler)
+
+
 def create_interface():

     title = 'Text generation web UI'
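
Context for the server.py change above: a SIGINT handler closes the Gradio interface before exiting so Ctrl+C shuts the UI down cleanly. A minimal generic sketch of the same pattern (standalone script, not webui code):

# Hedged sketch (not webui code): register a SIGINT handler that releases
# resources before exiting.
import signal
import sys
import time


def signal_handler(sig, frame):
    print("Received Ctrl+C. Cleaning up before exit.")
    # close sockets, servers, UI interfaces, etc. here
    sys.exit(0)


signal.signal(signal.SIGINT, signal_handler)

if __name__ == '__main__':
    print("Running; press Ctrl+C to trigger the handler.")
    while True:
        time.sleep(1)
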