diff --git a/README.md b/README.md index 22ceeeb8ea..fb22d09945 100644 --- a/README.md +++ b/README.md @@ -10,33 +10,29 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github. ## Features -* Multiple backends for text generation in a single UI and API, including [Transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp) (through [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)), [ExLlamaV2](https://github.com/turboderp/exllamav2), [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM). [AutoAWQ](https://github.com/casper-hansen/AutoAWQ), [HQQ](https://github.com/mobiusml/hqq), and [AQLM](https://github.com/Vahe1994/AQLM) are also supported through the Transformers loader. -* OpenAI-compatible API server with Chat and Completions endpoints – see the [examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples). -* Automatic prompt formatting for each model using the Jinja2 template in its metadata. -* Three chat modes: `instruct`, `chat-instruct`, and `chat`, allowing for both instruction-following and casual conversations with characters. `chat-instruct` mode automatically applies the model's template to the chat prompt, ensuring high-quality outputs without manual setup. -* "Past chats" menu to quickly switch between conversations and start new ones. -* Free-form generation in the Default/Notebook tabs without being limited to chat turns. Send formatted chat conversations from the Chat tab to these tabs. -* Multiple sampling parameters and generation options for sophisticated text generation control. -* Easy switching between different models through the UI without restarting, using the "Model" tab. -* Simple LoRA fine-tuning tool to customize models with your data. -* All in one folder. The requirements are installed in a self-contained `installer_files` folder that doesn't interfere with the system's environment. -* Extensions support, including numerous built-in and user-contributed extensions. See [the wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [the extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details. +- Supports multiple text generation backends in one UI/API, including [Transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp), and [ExLlamaV2](https://github.com/turboderp/exllamav2). [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ), [AutoAWQ](https://github.com/casper-hansen/AutoAWQ), [HQQ](https://github.com/mobiusml/hqq), and [AQLM](https://github.com/Vahe1994/AQLM) are also supported but you need to install them manually. +- OpenAI-compatible API with Chat and Completions endpoints – see [examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples). +- Automatic prompt formatting using Jinja2 templates. +- Three chat modes: `instruct`, `chat-instruct`, and `chat`, with automatic prompt templates in `chat-instruct`. +- "Past chats" menu to quickly switch between conversations. +- Free-form text generation in the Default/Notebook tabs without being limited to chat turns. You can send formatted conversations from the Chat tab to these. +- Multiple sampling parameters and generation options for sophisticated text generation control. +- Switch between different models easily in the UI without restarting. +- Simple LoRA fine-tuning tool. +- Requirements installed in a self-contained `installer_files` directory that doesn't interfere with the system environment. +- Extension support, with numerous built-in and user-contributed extensions available. See the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details. ## How to install -1) Clone or [download](https://github.com/oobabooga/text-generation-webui/archive/refs/heads/main.zip) the repository. -2) Run the `start_linux.sh`, `start_windows.bat`, `start_macos.sh`, or `start_wsl.bat` script depending on your OS. +1) Clone or [download the repository](https://github.com/oobabooga/text-generation-webui/archive/refs/heads/main.zip). +2) Run the script that matches your OS: `start_linux.sh`, `start_windows.bat`, `start_macos.sh`, or `start_wsl.bat`. 3) Select your GPU vendor when asked. 4) Once the installation ends, browse to `http://localhost:7860`. 5) Have fun! -To restart the web UI in the future, run the `start_` script again. +To restart the web UI later, just run the same `start_` script. If you need to reinstall, delete the `installer_files` folder created during setup and run the script again. -This script creates an `installer_files` folder where it sets up the project's requirements. If you need to reinstall the requirements, just delete that folder and start the web UI again. - -The script accepts command-line flags, such as `./start_linux.sh --help`. Alternatively, you can edit the `CMD_FLAGS.txt` file with a text editor and add your flags there, such as `--api` in case you need to use the API. - -To get updates in the future, run `update_wizard_linux.sh`, `update_wizard_windows.bat`, `update_wizard_macos.sh`, or `update_wizard_wsl.bat`. +You can use command-line flags, like `./start_linux.sh --help`, or add them to `CMD_FLAGS.txt` (such as `--api` to enable API use). To update the project, run `update_wizard_linux.sh`, `update_wizard_windows.bat`, `update_wizard_macos.sh`, or `update_wizard_wsl.bat`.
@@ -80,12 +76,12 @@ conda activate textgen | System | GPU | Command | |--------|---------|---------| -| Linux/WSL | NVIDIA | `pip3 install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121` | -| Linux/WSL | CPU only | `pip3 install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cpu` | -| Linux | AMD | `pip3 install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/rocm5.6` | -| MacOS + MPS | Any | `pip3 install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2` | -| Windows | NVIDIA | `pip3 install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121` | -| Windows | CPU only | `pip3 install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2` | +| Linux/WSL | NVIDIA | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121` | +| Linux/WSL | CPU only | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cpu` | +| Linux | AMD | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/rocm6.1` | +| MacOS + MPS | Any | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1` | +| Windows | NVIDIA | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121` | +| Windows | CPU only | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1` | The up-to-date commands can be found here: https://pytorch.org/get-started/locally/. @@ -150,7 +146,7 @@ Then browse to 1) For Kepler GPUs and older, you will need to install CUDA 11.8 instead of 12: ``` -pip3 install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118 +pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu118 conda install -y -c "nvidia/label/cuda-11.8.0" cuda-runtime ``` diff --git a/docs/12 - OpenAI API.md b/docs/12 - OpenAI API.md index b00a1f3447..9b4f89bf17 100644 --- a/docs/12 - OpenAI API.md +++ b/docs/12 - OpenAI API.md @@ -19,7 +19,7 @@ Add `--api` to your command-line flags. ### Examples -For the documentation with all the parameters and their types, consult `http://127.0.0.1:5000/docs` or the [typing.py](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/openai/typing.py) file. +For the documentation with all the endpoints, parameters and their types, consult `http://127.0.0.1:5000/docs` or the [typing.py](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/openai/typing.py) file. The official examples in the [OpenAI documentation](https://platform.openai.com/docs/api-reference) should also work, and the same parameters apply (although the API here has more optional parameters). @@ -114,6 +114,30 @@ curl -k http://127.0.0.1:5000/v1/internal/logits \ }' ``` +#### List models + +```shell +curl -k http://127.0.0.1:5000/v1/internal/model/list \ + -H "Content-Type: application/json" +``` + +#### Load model + +```shell +curl -k http://127.0.0.1:5000/v1/internal/model/load \ + -H "Content-Type: application/json" \ + -d '{ + "model_name": "model_name", + "args": { + "load_in_4bit": true, + "n_gpu_layers": 12 + }, + "settings": { + "instruction_template": "Alpaca" + } + }' +``` + #### Python chat example ```python diff --git a/extensions/Training_PRO/script.py b/extensions/Training_PRO/script.py index 8f29646232..5365154c35 100644 --- a/extensions/Training_PRO/script.py +++ b/extensions/Training_PRO/script.py @@ -241,7 +241,7 @@ def ui(): stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.') with gr.Column(): - max_length = gr.Slider(label='max_length', minimum=0, maximum=shared.settings['truncation_length_max'], value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.') + max_length = gr.Number(label='max_length', precision=0, step=256, value=0, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.') with gr.Row(): start_current_evaluation = gr.Button("Evaluate loaded model") diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py index 646dee2d91..6bd8f409a0 100644 --- a/extensions/openai/completions.py +++ b/extensions/openai/completions.py @@ -154,8 +154,9 @@ def convert_history(history): elif item['type'] == 'text' and isinstance(item['text'], str): content = item['text'] - if image_url and content: + if image_url: new_history.append({"image_url": image_url, "role": "user"}) + if content: new_history.append({"content": content, "role": "user"}) else: new_history.append(entry) @@ -234,7 +235,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p raise InvalidRequestError(message="messages: missing content", param='messages') # Chat Completions - object_type = 'chat.completions' if not stream else 'chat.completions.chunk' + object_type = 'chat.completion' if not stream else 'chat.completion.chunk' created_time = int(time.time()) cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000)) resp_list = 'data' if is_legacy else 'choices' diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py index 4015f6a1ce..f63c1f3911 100644 --- a/extensions/openai/typing.py +++ b/extensions/openai/typing.py @@ -37,6 +37,8 @@ class GenerationOptions(BaseModel): dry_base: float = 1.75 dry_allowed_length: int = 2 dry_sequence_breakers: str = '"\\n", ":", "\\"", "*"' + xtc_threshold: float = 0.1 + xtc_probability: float = 0 truncation_length: int = 0 max_tokens_second: int = 0 prompt_lookup_num_tokens: int = 0 diff --git a/js/main.js b/js/main.js index 899bd8f02d..71489f14e3 100644 --- a/js/main.js +++ b/js/main.js @@ -600,4 +600,15 @@ headerBar.addEventListener("click", (e) => { } }); +//------------------------------------------------ +// Add a confirmation dialog when leaving the page +// Useful to avoid data loss +//------------------------------------------------ +window.addEventListener('beforeunload', function (event) { + // Cancel the event + event.preventDefault(); + // Chrome requires returnValue to be set + event.returnValue = ''; +}); + moveToChatTab(); diff --git a/modules/LoRA.py b/modules/LoRA.py index 117022cfc8..4fd144ba00 100644 --- a/modules/LoRA.py +++ b/modules/LoRA.py @@ -1,7 +1,6 @@ from pathlib import Path import torch -from peft import PeftModel from transformers import is_torch_xpu_available import modules.shared as shared @@ -85,6 +84,9 @@ def add_lora_autogptq(lora_names): def add_lora_transformers(lora_names): + + from peft import PeftModel + prior_set = set(shared.lora_names) added_set = set(lora_names) - prior_set removed_set = prior_set - set(lora_names) diff --git a/modules/chat.py b/modules/chat.py index 00c4ffa956..b81cfea6ee 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -1059,7 +1059,12 @@ def handle_start_new_chat_click(state): convert_to_markdown.cache_clear() - return [history, html, gr.update(choices=histories, value=histories[0][1])] + if len(histories) > 0: + past_chats_update = gr.update(choices=histories, value=histories[0][1]) + else: + past_chats_update = gr.update(choices=histories) + + return [history, html, past_chats_update] def handle_delete_chat_confirm_click(state): @@ -1110,10 +1115,15 @@ def handle_upload_chat_history(load_chat_history, state): convert_to_markdown.cache_clear() + if len(histories) > 0: + past_chats_update = gr.update(choices=histories, value=histories[0][1]) + else: + past_chats_update = gr.update(choices=histories) + return [ history, html, - gr.update(choices=histories, value=histories[0][1]) + past_chats_update ] @@ -1132,6 +1142,11 @@ def handle_character_menu_change(state): convert_to_markdown.cache_clear() + if len(histories) > 0: + past_chats_update = gr.update(choices=histories, value=histories[0][1]) + else: + past_chats_update = gr.update(choices=histories) + return [ history, html, @@ -1140,7 +1155,7 @@ def handle_character_menu_change(state): picture, greeting, context, - gr.update(choices=histories, value=histories[0][1]), + past_chats_update, ] @@ -1151,12 +1166,17 @@ def handle_mode_change(state): convert_to_markdown.cache_clear() + if len(histories) > 0: + past_chats_update = gr.update(choices=histories, value=histories[0][1]) + else: + past_chats_update = gr.update(choices=histories) + return [ history, html, gr.update(visible=state['mode'] != 'instruct'), gr.update(visible=state['mode'] == 'chat-instruct'), - gr.update(choices=histories, value=histories[0][1]) + past_chats_update ] diff --git a/modules/exllamav2.py b/modules/exllamav2.py index a770e34257..0498c4882e 100644 --- a/modules/exllamav2.py +++ b/modules/exllamav2.py @@ -7,6 +7,7 @@ ExLlamaV2Cache, ExLlamaV2Cache_8bit, ExLlamaV2Cache_Q4, + ExLlamaV2Cache_TP, ExLlamaV2Config, ExLlamaV2Tokenizer ) @@ -18,14 +19,6 @@ try: import flash_attn -except ModuleNotFoundError: - logger.warning( - 'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage ' - 'to be a lot higher than it could be.\n' - 'Try installing flash-attention following the instructions here: ' - 'https://github.com/Dao-AILab/flash-attention#installation-and-features' - ) - pass except Exception: logger.warning('Failed to load flash-attention due to the following error:\n') traceback.print_exc() @@ -54,21 +47,30 @@ def from_pretrained(self, path_to_model): model = ExLlamaV2(config) - if not shared.args.autosplit: - split = None - if shared.args.gpu_split: - split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] + split = None + if shared.args.gpu_split: + split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] + if shared.args.enable_tp: + model.load_tp(split) + elif not shared.args.autosplit: model.load(split) + # Determine the correct cache type if shared.args.cache_8bit: - cache = ExLlamaV2Cache_8bit(model, lazy=shared.args.autosplit) + cache_type = ExLlamaV2Cache_8bit elif shared.args.cache_4bit: - cache = ExLlamaV2Cache_Q4(model, lazy=shared.args.autosplit) + cache_type = ExLlamaV2Cache_Q4 else: - cache = ExLlamaV2Cache(model, lazy=shared.args.autosplit) + cache_type = ExLlamaV2Cache - if shared.args.autosplit: + # Use TP if specified + if shared.args.enable_tp: + cache = ExLlamaV2Cache_TP(model, base=cache_type) + else: + cache = cache_type(model, lazy=shared.args.autosplit) + + if shared.args.autosplit and not shared.args.enable_tp: model.load_autosplit(cache) tokenizer = ExLlamaV2Tokenizer(config) diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py index 53143d9a92..320a8d2467 100644 --- a/modules/exllamav2_hf.py +++ b/modules/exllamav2_hf.py @@ -9,6 +9,7 @@ ExLlamaV2Cache, ExLlamaV2Cache_8bit, ExLlamaV2Cache_Q4, + ExLlamaV2Cache_TP, ExLlamaV2Config ) from torch.nn import CrossEntropyLoss @@ -20,14 +21,6 @@ try: import flash_attn -except ModuleNotFoundError: - logger.warning( - 'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage ' - 'to be a lot higher than it could be.\n' - 'Try installing flash-attention following the instructions here: ' - 'https://github.com/Dao-AILab/flash-attention#installation-and-features' - ) - pass except Exception: logger.warning('Failed to load flash-attention due to the following error:\n') traceback.print_exc() @@ -42,21 +35,30 @@ def __init__(self, config: ExLlamaV2Config): self.ex_model = ExLlamaV2(config) - if not shared.args.autosplit: - split = None - if shared.args.gpu_split: - split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] + split = None + if shared.args.gpu_split: + split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] + if shared.args.enable_tp: + self.ex_model.load_tp(split) + elif not shared.args.autosplit: self.ex_model.load(split) + # Determine the correct cache type if shared.args.cache_8bit: - self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit) + cache_type = ExLlamaV2Cache_8bit elif shared.args.cache_4bit: - self.ex_cache = ExLlamaV2Cache_Q4(self.ex_model, lazy=shared.args.autosplit) + cache_type = ExLlamaV2Cache_Q4 else: - self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit) + cache_type = ExLlamaV2Cache - if shared.args.autosplit: + # Use TP if specified + if shared.args.enable_tp: + self.ex_cache = ExLlamaV2Cache_TP(self.ex_model, base=cache_type) + else: + self.ex_cache = cache_type(self.ex_model, lazy=shared.args.autosplit) + + if shared.args.autosplit and not shared.args.enable_tp: self.ex_model.load_autosplit(self.ex_cache) self.past_seq = None diff --git a/modules/llama_cpp_python_hijack.py b/modules/llama_cpp_python_hijack.py index 64280dc9a0..2a9c10da2e 100644 --- a/modules/llama_cpp_python_hijack.py +++ b/modules/llama_cpp_python_hijack.py @@ -2,12 +2,12 @@ import platform from typing import Sequence +import numpy as np from tqdm import tqdm from modules import shared from modules.cache_utils import process_llamacpp_cache - imported_module = None @@ -57,11 +57,9 @@ def eval_with_progress(self, tokens: Sequence[int]): with tqdm to show prompt processing progress. """ - assert self._ctx.ctx is not None - assert self._batch.batch is not None self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) - if len(tokens) > 1: + if len(tokens) > self.n_batch: progress_bar = tqdm(range(0, len(tokens), self.n_batch), desc="Prompt evaluation", leave=False) else: progress_bar = range(0, len(tokens), self.n_batch) @@ -80,13 +78,20 @@ def eval_with_progress(self, tokens: Sequence[int]): if self.context_params.logits_all: rows = n_tokens cols = self._n_vocab - logits = self._ctx.get_logits()[: rows * cols] - self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits + logits = np.ctypeslib.as_array( + self._ctx.get_logits(), shape=(rows * cols,) + ) + self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits + self.last_updated_index = n_past + n_tokens - 1 else: rows = 1 cols = self._n_vocab - logits = self._ctx.get_logits()[: rows * cols] - self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits + logits = np.ctypeslib.as_array( + self._ctx.get_logits(), shape=(rows * cols,) + ) + last_token_index = min(n_past + n_tokens - 1, self.scores.shape[0] - 1) + self.scores[last_token_index, :] = logits.reshape(-1) + self.last_updated_index = last_token_index # Update n_tokens self.n_tokens += n_tokens diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py index 327e3a7b43..6611a7c1a8 100644 --- a/modules/llamacpp_hf.py +++ b/modules/llamacpp_hf.py @@ -127,7 +127,7 @@ def __call__(self, *args, **kwargs): self.model.reset() self.model.eval(seq) - logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(input_ids.device) + logits = torch.tensor(self.model.scores[self.model.last_updated_index, :]).view(1, 1, -1).to(input_ids.device) else: self.model.reset() self.model.eval(seq) @@ -205,5 +205,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P Llama = llama_cpp_lib().Llama model = Llama(**params) + model.last_updated_index = -1 return LlamacppHF(model, model_file) diff --git a/modules/loaders.py b/modules/loaders.py index 549de5fb02..16a3106e94 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -168,6 +168,8 @@ def transformers_samplers(): 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', + 'xtc_threshold', + 'xtc_probability', 'seed', 'do_sample', 'penalty_alpha', @@ -242,6 +244,8 @@ def transformers_samplers(): 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', + 'xtc_threshold', + 'xtc_probability', 'seed', 'do_sample', 'mirostat_mode', @@ -304,6 +308,8 @@ def transformers_samplers(): 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', + 'xtc_threshold', + 'xtc_probability', 'seed', 'do_sample', 'mirostat_mode', diff --git a/modules/models.py b/modules/models.py index b0e2346e18..3e8ae3f227 100644 --- a/modules/models.py +++ b/modules/models.py @@ -70,11 +70,11 @@ def load_model(model_name, loader=None): shared.model_name = model_name load_func_map = { 'Transformers': huggingface_loader, - 'AutoGPTQ': AutoGPTQ_loader, 'llama.cpp': llamacpp_loader, 'llamacpp_HF': llamacpp_HF_loader, 'ExLlamav2': ExLlamav2_loader, 'ExLlamav2_HF': ExLlamav2_HF_loader, + 'AutoGPTQ': AutoGPTQ_loader, 'HQQ': HQQ_loader, 'TensorRT-LLM': TensorRT_LLM_loader, } @@ -302,12 +302,6 @@ def llamacpp_HF_loader(model_name): return model -def AutoGPTQ_loader(model_name): - import modules.AutoGPTQ_loader - - return modules.AutoGPTQ_loader.load_quantized(model_name) - - def ExLlamav2_loader(model_name): from modules.exllamav2 import Exllamav2Model @@ -321,9 +315,21 @@ def ExLlamav2_HF_loader(model_name): return Exllamav2HF.from_pretrained(model_name) +def AutoGPTQ_loader(model_name): + try: + import modules.AutoGPTQ_loader + except ModuleNotFoundError: + raise ModuleNotFoundError("Failed to import 'autogptq'. Please install it manually following the instructions in the AutoGPTQ GitHub repository.") + + return modules.AutoGPTQ_loader.load_quantized(model_name) + + def HQQ_loader(model_name): - from hqq.core.quantize import HQQBackend, HQQLinear - from hqq.models.hf.base import AutoHQQHFModel + try: + from hqq.core.quantize import HQQBackend, HQQLinear + from hqq.models.hf.base import AutoHQQHFModel + except ModuleNotFoundError: + raise ModuleNotFoundError("Failed to import 'hqq'. Please install it manually following the instructions in the HQQ GitHub repository.") logger.info(f"Loading HQQ model with backend: \"{shared.args.hqq_backend}\"") @@ -334,7 +340,10 @@ def HQQ_loader(model_name): def TensorRT_LLM_loader(model_name): - from modules.tensorrt_llm import TensorRTLLMModel + try: + from modules.tensorrt_llm import TensorRTLLMModel + except ModuleNotFoundError: + raise ModuleNotFoundError("Failed to import 'tensorrt_llm'. Please install it manually following the instructions in the TensorRT-LLM GitHub repository.") model = TensorRTLLMModel.from_pretrained(model_name) return model diff --git a/modules/presets.py b/modules/presets.py index b00e829eb1..c8118fb3b7 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -44,7 +44,9 @@ def default_preset(): 'dry_base': 1.75, 'dry_allowed_length': 2, 'dry_sequence_breakers': '"\\n", ":", "\\"", "*"', - 'sampler_priority': 'temperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat' + 'xtc_threshold': 0.1, + 'xtc_probability': 0, + 'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram' } diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index 9fb661aecc..87f0b25e8e 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -1,6 +1,7 @@ import json import math import pprint +import random import torch import transformers @@ -191,6 +192,53 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores +# Exclude Top Choices (XTC) +class XTCLogitsWarper(LogitsWarper): + def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")): + self.threshold = threshold + self.probability = probability + self.filter_value = filter_value + self.special_token_ids = [ + shared.tokenizer.encode("\n")[-1], + ] + + if shared.tokenizer.eos_token_id is not None: + self.special_token_ids.append(shared.tokenizer.eos_token_id) + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + # `random` returns values in the half-open range [0, 1), so setting `probability` + # to 0 means the sampler never takes action, while setting it to 1 means the sampler + # always takes action. + # + # Note that while XTC is most intuitively described as "if multiple tokens meet + # the threshold, then with probability...", reversing the two conditions is logically + # equivalent, and improves performance because processing can immediately be stopped + # if the random check fails. + if random.random() >= self.probability: + return scores + + sorted_logits, sorted_indices = torch.sort(scores, descending=True) + probs = sorted_logits.softmax(dim=-1) + + sorted_indices_to_remove = torch.full_like(probs, False, dtype=torch.bool) + + # This operation sets exactly those indices to `True` for which the next index has + # probability above the threshold. Since `probs` is sorted, those are the indices + # of all tokens that meet the threshold, *except* the least probable one. + sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold + + # Convert sorted_indices_to_remove to the original indices + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + + # If newline or EOS tokens would be removed, return the original scores + if indices_to_remove[:, self.special_token_ids].any(): + return scores + + # Otherwise, remove tokens with the mask + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores + + class DRYLogitsProcessor(LogitsProcessor): def __init__(self, multiplier: float, base: float, allowed_length: int, sequence_breakers: set[int], _range: int): self.multiplier = multiplier @@ -323,62 +371,141 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor): - ''' - Copied from the transformers library - ''' - - def __init__(self, penalty: float, presence_penalty: float, frequency_penalty: float, _range: int): + def __init__(self, penalty: float, _range: int): if not (penalty > 0): raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}") - self.penalty = penalty - self.presence_penalty = presence_penalty - self.frequency_penalty = frequency_penalty self._range = _range + def apply_repetition_penalty(self, input_ids_row, scores_row): + unique_ids = torch.unique(input_ids_row) + score = torch.gather(scores_row, 0, unique_ids) + + # Apply multiplicative repetition penalty + score = torch.where(score < 0, score * self.penalty, score / self.penalty) + scores_row.scatter_(0, unique_ids, score) + return scores_row + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: input_ids = input_ids[:, -self._range:] + for input_ids_row, scores_row in zip(input_ids, scores): + scores_row = self.apply_repetition_penalty(input_ids_row, scores_row) + + return scores + - # We loop here because torch.unique() needs to process each row separately in the - # case that batch_size > 1. +class PresencePenaltyLogitsProcessor(LogitsProcessor): + def __init__(self, presence_penalty: float, _range: int): + self.presence_penalty = presence_penalty + self._range = _range + + def apply_presence_penalty(self, input_ids_row, scores_row): + unique_ids, counts = torch.unique(input_ids_row, return_counts=True) + + # Apply presence penalty + raw_presence_penalty = (counts > 0).to(scores_row.dtype) + presence_penalty = raw_presence_penalty * self.presence_penalty + scores_row.scatter_add_(0, unique_ids, -presence_penalty) + return scores_row + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + input_ids = input_ids[:, -self._range:] for input_ids_row, scores_row in zip(input_ids, scores): - unique_ids, counts = torch.unique(input_ids_row, return_counts=True) - score = torch.gather(scores_row, 0, unique_ids) + scores_row = self.apply_presence_penalty(input_ids_row, scores_row) + return scores + - # multiplicative repetition penalty - # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability - score = torch.where(score < 0, score * self.penalty, score / self.penalty) - scores_row.scatter_(0, unique_ids, score) +class FrequencyPenaltyLogitsProcessor(LogitsProcessor): + def __init__(self, frequency_penalty: float, _range: int): + self.frequency_penalty = frequency_penalty + self._range = _range - # presence_penalty and frequency_penalty - raw_presence_penalty = (counts > 0).to(scores.dtype) - raw_frequency_penalty = counts.to(scores.dtype) - additive_penalty = raw_presence_penalty * self.presence_penalty + raw_frequency_penalty * self.frequency_penalty - scores_row.scatter_add_(0, unique_ids, -additive_penalty) + def apply_frequency_penalty(self, input_ids_row, scores_row): + unique_ids, counts = torch.unique(input_ids_row, return_counts=True) + # Apply frequency penalty + raw_frequency_penalty = counts.to(scores_row.dtype) + frequency_penalty = raw_frequency_penalty * self.frequency_penalty + scores_row.scatter_add_(0, unique_ids, -frequency_penalty) + return scores_row + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + input_ids = input_ids[:, -self._range:] + for input_ids_row, scores_row in zip(input_ids, scores): + scores_row = self.apply_frequency_penalty(input_ids_row, scores_row) return scores -def get_logits_warper_patch(self, generation_config, **kwargs): +def get_logits_processor_patch(self, **kwargs): + generation_config = kwargs['generation_config'] # Parameter sanitization if isinstance(generation_config.temperature, int): generation_config.temperature = float(generation_config.temperature) # Must be float # Get the original warpers - warpers = self._get_logits_warper_old(generation_config, **kwargs) + warpers = self._get_logits_processor_old(**kwargs) - # Replace temperature with our modified class. - # Currently, it behaves identically to the original. - for i in range(len(warpers)): + for i in range(len(warpers) - 1, -1, -1): + # Replace temperature with our modified class. if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper': warpers[i] = TemperatureLogitsWarperCustom( generation_config.temperature, ) + # Stuff we don't need + elif warpers[i].__class__.__name__ in ['SuppressTokensLogitsProcessor', 'RepetitionPenaltyLogitsProcessor']: + del warpers[i] + # Add custom warpers warpers_to_add = LogitsProcessorList() min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1 + + if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: + warpers_to_add.append( + RepetitionPenaltyLogitsProcessorWithRange( + penalty=generation_config.repetition_penalty, + _range=generation_config.repetition_penalty_range + ) + ) + + if generation_config.presence_penalty is not None and generation_config.presence_penalty != 0.0: + warpers_to_add.append( + PresencePenaltyLogitsProcessor( + presence_penalty=generation_config.presence_penalty, + _range=generation_config.repetition_penalty_range + ) + ) + + if generation_config.frequency_penalty is not None and generation_config.frequency_penalty != 0.0: + warpers_to_add.append( + FrequencyPenaltyLogitsProcessor( + frequency_penalty=generation_config.frequency_penalty, + _range=generation_config.repetition_penalty_range + ) + ) + + if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0: + dry_sequence_breakers = generation_config.dry_sequence_breakers + + # Support both JSON array notation and comma-separated strings. + if not dry_sequence_breakers.startswith("["): + dry_sequence_breakers = "[" + dry_sequence_breakers + "]" + + sequence_breaker_strings = json.loads(dry_sequence_breakers) + # Prefix with 'a' to get the correct encoding of the token at the end of a text. + sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings} + + warpers.append( + DRYLogitsProcessor( + multiplier=generation_config.dry_multiplier, + base=generation_config.dry_base, + allowed_length=generation_config.dry_allowed_length, + sequence_breakers=sequence_breakers, + _range=generation_config.repetition_penalty_range, + ) + ) + if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0: warpers_to_add.append( TailFreeLogitsWarper( @@ -395,6 +522,14 @@ def get_logits_warper_patch(self, generation_config, **kwargs): ) ) + if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0: + warpers_to_add.append( + XTCLogitsWarper( + threshold=generation_config.xtc_threshold, + probability=generation_config.xtc_probability, + ) + ) + if generation_config.dynamic_temperature: warpers_to_add.append( DynamicTemperatureLogitsWarper( @@ -454,17 +589,23 @@ def get_logits_warper_patch(self, generation_config, **kwargs): 'TopALogitsWarper': 'top_a', 'TopKLogitsWarper': 'top_k', 'TopPLogitsWarper': 'top_p', - 'TypicalLogitsWarper': 'typical_p' + 'TypicalLogitsWarper': 'typical_p', + 'XTCLogitsWarper': 'xtc', + 'RepetitionPenaltyLogitsProcessorWithRange': 'repetition_penalty', + 'PresencePenaltyLogitsProcessor': 'presence_penalty', + 'FrequencyPenaltyLogitsProcessor': 'frequency_penalty', + 'DRYLogitsProcessor': 'dry', + 'EncoderRepetitionPenaltyLogitsProcessor': 'encoder_repetition_penalty', + 'NoRepeatNGramLogitsProcessor': 'no_repeat_ngram', } def custom_sort_key(obj): class_name = obj.__class__.__name__ - # Return a large value if class name is not mapped or if the mapped nickname is not in priority + # Return -1 if class_name is not mapped if class_name not in class_name_to_nickname or class_name_to_nickname[class_name] not in sampler_priority: - return float('inf') + return -1 - # Return the index of the nickname in the priority list for sorting return sampler_priority.index(class_name_to_nickname[class_name]) # Sort the list using the custom key function @@ -482,49 +623,6 @@ def custom_sort_key(obj): return warpers -def get_logits_processor_patch(self, **kwargs): - generation_config = kwargs['generation_config'] - - do_rep_pen_hijack = (generation_config.repetition_penalty > 1) or (generation_config.presence_penalty != 0) or (generation_config.frequency_penalty != 0) - if do_rep_pen_hijack: - generation_config.repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created - - result = self._get_logits_processor_old(**kwargs) - - if do_rep_pen_hijack: - for i in range(len(result)): - if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor': - result[i] = RepetitionPenaltyLogitsProcessorWithRange( - generation_config.repetition_penalty, - generation_config.presence_penalty, - generation_config.frequency_penalty, - generation_config.repetition_penalty_range - ) - - if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0: - dry_sequence_breakers = generation_config.dry_sequence_breakers - - # Support both JSON array notation and comma-separated strings. - if not dry_sequence_breakers.startswith("["): - dry_sequence_breakers = "[" + dry_sequence_breakers + "]" - - sequence_breaker_strings = json.loads(dry_sequence_breakers) - # Prefix with 'a' to get the correct encoding of the token at the end of a text. - sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings} - - result.append( - DRYLogitsProcessor( - multiplier=generation_config.dry_multiplier, - base=generation_config.dry_base, - allowed_length=generation_config.dry_allowed_length, - sequence_breakers=sequence_breakers, - _range=generation_config.repetition_penalty_range, - ) - ) - - return result - - def generation_config_init_patch(self, **kwargs): self.__init___old(**kwargs) self.min_p = kwargs.pop("min_p", 0.0) @@ -546,14 +644,13 @@ def generation_config_init_patch(self, **kwargs): self.dry_base = kwargs.pop("dry_base", 1.75) self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2) self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '"\\n", ":", "\\"", "*"') + self.xtc_threshold = kwargs.pop("xtc_threshold", 0.1) + self.xtc_probability = kwargs.pop("xtc_probability", 0) self.temperature_last = kwargs.pop("temperature_last", False) - self.sampler_priority = kwargs.pop("sampler_priority", ['temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat']) + self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram']) def hijack_samplers(): - transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper - transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch - transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch diff --git a/modules/shared.py b/modules/shared.py index 43533a1480..894ed6fe56 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -146,6 +146,7 @@ group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.') group.add_argument('--cache_4bit', action='store_true', help='Use Q4 cache to save VRAM.') group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.') +group.add_argument('--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) in ExLlamaV2.') # AutoGPTQ group = parser.add_argument_group('AutoGPTQ') diff --git a/modules/text_generation.py b/modules/text_generation.py index 75e5ef36ae..86245098ab 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -274,7 +274,12 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0): if (hasattr(shared.tokenizer, 'convert_ids_to_tokens') and len(output_ids) > starting_from) and not reply.startswith(' '): first_token = shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])) if isinstance(first_token, (bytes,)): - first_token = first_token.decode('utf8') + # try to decode the bytes to a string + # if it fails, which means it's not a string in this turn, just ignore it + try: + first_token = first_token.decode('utf8') + except UnicodeDecodeError: + first_token = '' if first_token.startswith('▁'): reply = ' ' + reply @@ -284,7 +289,7 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0): def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False): generate_params = {} - for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers']: + for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', 'xtc_threshold', 'xtc_probability']: if k in state: generate_params[k] = state[k] diff --git a/modules/training.py b/modules/training.py index b003fc8c28..11c4b8c5d7 100644 --- a/modules/training.py +++ b/modules/training.py @@ -18,14 +18,6 @@ import torch import transformers from datasets import Dataset, load_dataset -from peft import ( - LoraConfig, - get_peft_model, - prepare_model_for_kbit_training, - set_peft_model_state_dict -) -from peft.utils.other import \ - TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as model_to_lora_modules from transformers import is_torch_xpu_available from transformers.models.auto.modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING_NAMES @@ -292,6 +284,16 @@ def calc_trainable_parameters(model): def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str): + from peft import ( + LoraConfig, + get_peft_model, + prepare_model_for_kbit_training, + set_peft_model_state_dict + ) + from peft.utils.other import \ + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \ + model_to_lora_modules + global WANT_INTERRUPT WANT_INTERRUPT = False diff --git a/modules/ui.py b/modules/ui.py index 47f92cf0f9..76b1c009c4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -158,6 +158,8 @@ def list_interface_input_elements(): 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', + 'xtc_threshold', + 'xtc_probability', 'do_sample', 'penalty_alpha', 'mirostat_mode', diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index eff62c20e5..a2665e0d21 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -45,6 +45,10 @@ def create_ui(default_preset): shared.gradio['dry_base'] = gr.Slider(1, 4, value=generate_params['dry_base'], step=0.01, label='dry_base', info='Controls how fast the penalty grows with increasing sequence length.') shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=generate_params['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.') + with gr.Blocks(): + shared.gradio['xtc_threshold'] = gr.Slider(0, 0.5, value=generate_params['xtc_threshold'], step=0.01, label='xtc_threshold', info='If 2 or more tokens have probability above this threshold, consider removing all but the last one.') + shared.gradio['xtc_probability'] = gr.Slider(0, 1, value=generate_params['xtc_probability'], step=0.01, label='xtc_probability', info='Probability that the removal will actually happen. 0 disables the sampler. 1 makes it always happen.') + gr.Markdown("[Learn more](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab)") with gr.Column(): diff --git a/one_click.py b/one_click.py index 0a0412ba07..9bad25d1be 100644 --- a/one_click.py +++ b/one_click.py @@ -16,9 +16,9 @@ # Define the required PyTorch version -TORCH_VERSION = "2.2.2" -TORCHVISION_VERSION = "0.17.2" -TORCHAUDIO_VERSION = "2.2.2" +TORCH_VERSION = "2.4.1" +TORCHVISION_VERSION = "0.19.1" +TORCHAUDIO_VERSION = "2.4.1" # Environment script_dir = os.getcwd() @@ -117,7 +117,7 @@ def update_pytorch(): elif is_cuda: install_pytorch += "--index-url https://download.pytorch.org/whl/cu121" elif is_rocm: - install_pytorch += "--index-url https://download.pytorch.org/whl/rocm5.6" + install_pytorch += "--index-url https://download.pytorch.org/whl/rocm6.1" elif is_cpu: install_pytorch += "--index-url https://download.pytorch.org/whl/cpu" elif is_intel: @@ -189,8 +189,11 @@ def run_cmd(cmd, assert_success=False, environment=False, capture_output=False, conda_sh_path = os.path.join(script_dir, "installer_files", "conda", "etc", "profile.d", "conda.sh") cmd = f'. "{conda_sh_path}" && conda activate "{conda_env_path}" && {cmd}' + # Set executable to None for Windows, /bin/bash for everything else + executable = None if is_windows() else '/bin/bash' + # Run shell commands - result = subprocess.run(cmd, shell=True, capture_output=capture_output, env=env) + result = subprocess.run(cmd, shell=True, capture_output=capture_output, env=env, executable=executable) # Assert the command ran successfully if assert_success and result.returncode != 0: @@ -239,7 +242,7 @@ def install_webui(): "What is your GPU?", { 'A': 'NVIDIA', - 'B': 'AMD (Linux/MacOS only. Requires ROCm SDK 5.6 on Linux)', + 'B': 'AMD (Linux/MacOS only. Requires ROCm SDK 6.1 on Linux)', 'C': 'Apple M Series', 'D': 'Intel Arc (IPEX)', 'N': 'None (I want to run models in CPU mode)' @@ -294,7 +297,7 @@ def install_webui(): else: install_pytorch += "--index-url https://download.pytorch.org/whl/cu121" elif selected_gpu == "AMD": - install_pytorch += "--index-url https://download.pytorch.org/whl/rocm5.6" + install_pytorch += "--index-url https://download.pytorch.org/whl/rocm6.1" elif selected_gpu in ["APPLE", "NONE"]: install_pytorch += "--index-url https://download.pytorch.org/whl/cpu" elif selected_gpu == "INTEL": diff --git a/requirements.txt b/requirements.txt index 51558a845a..5623f1f7ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,10 @@ accelerate==0.33.* -aqlm[gpu,cpu]==1.1.6; platform_system == "Linux" -auto-gptq==0.7.1 -bitsandbytes==0.43.* +bitsandbytes==0.44.* colorama datasets einops fastapi==0.112.4 gradio==4.26.* -hqq==0.1.7.post3 jinja2==3.1.4 lm_eval==0.3.0 markdown @@ -26,7 +23,7 @@ safetensors==0.4.* scipy sentencepiece tensorboard -transformers==4.44.* +transformers==4.45.* tqdm wandb @@ -37,38 +34,30 @@ sse-starlette==1.6.5 tiktoken # llama-cpp-python (CPU only, AVX2) -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx2-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx2-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx2-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx2-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" # llama-cpp-python (CUDA, no tensor cores) -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.89+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.89+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.89+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.89+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.1+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.1+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.1+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.1+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" # llama-cpp-python (CUDA, tensor cores) -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.89+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.89+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.89+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.89+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.1+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.1+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.1+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.1+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" # CUDA wheels -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.2.2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.2.2-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.2.2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.2.2-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8-py3-none-any.whl; platform_system == "Linux" and platform_machine != "x86_64" -https://github.com/oobabooga/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu122torch2.2.2cxx11abiFALSE-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu122torch2.2.2cxx11abiFALSE-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" -https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/AutoAWQ/releases/download/0.2.6/autoawq-0.2.6-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/AutoAWQ/releases/download/0.2.6/autoawq-0.2.6-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" -https://github.com/oobabooga/AutoAWQ/releases/download/0.2.6/autoawq-0.2.6-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/AutoAWQ/releases/download/0.2.6/autoawq-0.2.6-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/AutoAWQ_kernels/releases/download/0.0.7/autoawq_kernels-0.0.7-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/AutoAWQ_kernels/releases/download/0.0.7/autoawq_kernels-0.0.7-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" -https://github.com/oobabooga/AutoAWQ_kernels/releases/download/0.0.7/autoawq_kernels-0.0.7-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/AutoAWQ_kernels/releases/download/0.0.7/autoawq_kernels-0.0.7-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu121.torch2.4.1-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu121.torch2.4.1-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu121.torch2.4.1-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu121.torch2.4.1-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3-py3-none-any.whl; platform_system == "Linux" and platform_machine != "x86_64" +https://github.com/oobabooga/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu122torch2.4.1cxx11abiFALSE-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu122torch2.4.1cxx11abiFALSE-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" diff --git a/requirements_amd.txt b/requirements_amd.txt index e9dfb1283e..d203c157c2 100644 --- a/requirements_amd.txt +++ b/requirements_amd.txt @@ -4,7 +4,6 @@ datasets einops fastapi==0.112.4 gradio==4.26.* -hqq==0.1.7.post3 jinja2==3.1.4 lm_eval==0.3.0 markdown @@ -23,7 +22,7 @@ safetensors==0.4.* scipy sentencepiece tensorboard -transformers==4.44.* +transformers==4.45.* tqdm wandb @@ -34,18 +33,14 @@ sse-starlette==1.6.5 tiktoken # llama-cpp-python (CPU only, AVX2) -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx2-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx2-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx2-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx2-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" # AMD wheels -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/rocm/llama_cpp_python_cuda-0.2.89+rocm5.6.1-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/rocm/llama_cpp_python_cuda-0.2.89+rocm5.6.1-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm5.6.torch2.2.2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm5.6.torch2.2.2-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64" -https://github.com/oobabooga/AutoAWQ/releases/download/0.2.6/autoawq-0.2.6-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/AutoAWQ/releases/download/0.2.6/autoawq-0.2.6-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/AutoAWQ_kernels/releases/download/0.0.7/autoawq_kernels-0.0.7+rocm561-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/AutoAWQ_kernels/releases/download/0.0.7/autoawq_kernels-0.0.7+rocm561-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/rocm/llama_cpp_python_cuda-0.3.1+rocm6.1.2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/rocm/llama_cpp_python_cuda-0.3.1+rocm6.1.2-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+rocm6.1.torch2.4.1-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+rocm6.1.torch2.4.1-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64" diff --git a/requirements_amd_noavx2.txt b/requirements_amd_noavx2.txt index 69a350ad54..3a7a950d54 100644 --- a/requirements_amd_noavx2.txt +++ b/requirements_amd_noavx2.txt @@ -4,7 +4,6 @@ datasets einops fastapi==0.112.4 gradio==4.26.* -hqq==0.1.7.post3 jinja2==3.1.4 lm_eval==0.3.0 markdown @@ -23,7 +22,7 @@ safetensors==0.4.* scipy sentencepiece tensorboard -transformers==4.44.* +transformers==4.45.* tqdm wandb @@ -34,16 +33,12 @@ sse-starlette==1.6.5 tiktoken # llama-cpp-python (CPU only, no AVX2) -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" # AMD wheels -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm5.6.torch2.2.2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm5.6.torch2.2.2-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64" -https://github.com/oobabooga/AutoAWQ/releases/download/0.2.6/autoawq-0.2.6-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/AutoAWQ/releases/download/0.2.6/autoawq-0.2.6-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/AutoAWQ_kernels/releases/download/0.0.7/autoawq_kernels-0.0.7+rocm561-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/AutoAWQ_kernels/releases/download/0.0.7/autoawq_kernels-0.0.7+rocm561-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+rocm6.1.torch2.4.1-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+rocm6.1.torch2.4.1-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64" diff --git a/requirements_apple_intel.txt b/requirements_apple_intel.txt index 0326916eee..61c5367ffb 100644 --- a/requirements_apple_intel.txt +++ b/requirements_apple_intel.txt @@ -4,7 +4,6 @@ datasets einops fastapi==0.112.4 gradio==4.26.* -hqq==0.1.7.post3 jinja2==3.1.4 lm_eval==0.3.0 markdown @@ -23,7 +22,7 @@ safetensors==0.4.* scipy sentencepiece tensorboard -transformers==4.44.* +transformers==4.45.* tqdm wandb @@ -34,8 +33,8 @@ sse-starlette==1.6.5 tiktoken # Mac wheels -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.2.89-cp311-cp311-macosx_12_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "21.0.0" and platform_release < "22.0.0" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.2.89-cp310-cp310-macosx_12_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "21.0.0" and platform_release < "22.0.0" and python_version == "3.10" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.2.89-cp311-cp311-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.2.89-cp310-cp310-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.10" -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8-py3-none-any.whl +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.1-cp311-cp311-macosx_12_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "21.0.0" and platform_release < "22.0.0" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.1-cp310-cp310-macosx_12_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "21.0.0" and platform_release < "22.0.0" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.1-cp311-cp311-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.1-cp310-cp310-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.10" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3-py3-none-any.whl diff --git a/requirements_apple_silicon.txt b/requirements_apple_silicon.txt index 70ee3d27a1..494c127afe 100644 --- a/requirements_apple_silicon.txt +++ b/requirements_apple_silicon.txt @@ -4,7 +4,6 @@ datasets einops fastapi==0.112.4 gradio==4.26.* -hqq==0.1.7.post3 jinja2==3.1.4 lm_eval==0.3.0 markdown @@ -23,7 +22,7 @@ safetensors==0.4.* scipy sentencepiece tensorboard -transformers==4.44.* +transformers==4.45.* tqdm wandb @@ -34,10 +33,10 @@ sse-starlette==1.6.5 tiktoken # Mac wheels -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.2.89-cp311-cp311-macosx_12_0_arm64.whl; platform_system == "Darwin" and platform_release >= "21.0.0" and platform_release < "22.0.0" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.2.89-cp310-cp310-macosx_12_0_arm64.whl; platform_system == "Darwin" and platform_release >= "21.0.0" and platform_release < "22.0.0" and python_version == "3.10" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.2.89-cp311-cp311-macosx_13_0_arm64.whl; platform_system == "Darwin" and platform_release >= "22.0.0" and platform_release < "23.0.0" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.2.89-cp310-cp310-macosx_13_0_arm64.whl; platform_system == "Darwin" and platform_release >= "22.0.0" and platform_release < "23.0.0" and python_version == "3.10" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.2.89-cp311-cp311-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.2.89-cp310-cp310-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.10" -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8-py3-none-any.whl +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.1-cp311-cp311-macosx_12_0_arm64.whl; platform_system == "Darwin" and platform_release >= "21.0.0" and platform_release < "22.0.0" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.1-cp310-cp310-macosx_12_0_arm64.whl; platform_system == "Darwin" and platform_release >= "21.0.0" and platform_release < "22.0.0" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.1-cp311-cp311-macosx_13_0_arm64.whl; platform_system == "Darwin" and platform_release >= "22.0.0" and platform_release < "23.0.0" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.1-cp310-cp310-macosx_13_0_arm64.whl; platform_system == "Darwin" and platform_release >= "22.0.0" and platform_release < "23.0.0" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.1-cp311-cp311-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.1-cp310-cp310-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.10" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3-py3-none-any.whl diff --git a/requirements_cpu_only.txt b/requirements_cpu_only.txt index fe2d796ccd..5838ac47c2 100644 --- a/requirements_cpu_only.txt +++ b/requirements_cpu_only.txt @@ -4,7 +4,6 @@ datasets einops fastapi==0.112.4 gradio==4.26.* -hqq==0.1.7.post3 jinja2==3.1.4 lm_eval==0.3.0 markdown @@ -23,7 +22,7 @@ safetensors==0.4.* scipy sentencepiece tensorboard -transformers==4.44.* +transformers==4.45.* tqdm wandb @@ -34,7 +33,7 @@ sse-starlette==1.6.5 tiktoken # llama-cpp-python (CPU only, AVX2) -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx2-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx2-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx2-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx2-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" diff --git a/requirements_cpu_only_noavx2.txt b/requirements_cpu_only_noavx2.txt index 982a12a162..97d4eb9063 100644 --- a/requirements_cpu_only_noavx2.txt +++ b/requirements_cpu_only_noavx2.txt @@ -4,7 +4,6 @@ datasets einops fastapi==0.112.4 gradio==4.26.* -hqq==0.1.7.post3 jinja2==3.1.4 lm_eval==0.3.0 markdown @@ -23,7 +22,7 @@ safetensors==0.4.* scipy sentencepiece tensorboard -transformers==4.44.* +transformers==4.45.* tqdm wandb @@ -34,7 +33,7 @@ sse-starlette==1.6.5 tiktoken # llama-cpp-python (CPU only, no AVX2) -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" diff --git a/requirements_noavx2.txt b/requirements_noavx2.txt index fc253f5c27..d34a598d65 100644 --- a/requirements_noavx2.txt +++ b/requirements_noavx2.txt @@ -1,13 +1,10 @@ accelerate==0.33.* -aqlm[gpu,cpu]==1.1.6; platform_system == "Linux" -auto-gptq==0.7.1 -bitsandbytes==0.43.* +bitsandbytes==0.44.* colorama datasets einops fastapi==0.112.4 gradio==4.26.* -hqq==0.1.7.post3 jinja2==3.1.4 lm_eval==0.3.0 markdown @@ -26,7 +23,7 @@ safetensors==0.4.* scipy sentencepiece tensorboard -transformers==4.44.* +transformers==4.45.* tqdm wandb @@ -37,38 +34,30 @@ sse-starlette==1.6.5 tiktoken # llama-cpp-python (CPU only, no AVX2) -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.89+cpuavx-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.1+cpuavx-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" # llama-cpp-python (CUDA, no tensor cores) -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.89+cu121avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.89+cu121avx-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.89+cu121avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.89+cu121avx-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.1+cu121avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.1+cu121avx-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.1+cu121avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.1+cu121avx-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" # llama-cpp-python (CUDA, tensor cores) -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.89+cu121avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.89+cu121avx-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.89+cu121avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.89+cu121avx-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.1+cu121avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.1+cu121avx-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.1+cu121avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.1+cu121avx-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" # CUDA wheels -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.2.2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.2.2-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.2.2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.2.2-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8-py3-none-any.whl; platform_system == "Linux" and platform_machine != "x86_64" -https://github.com/oobabooga/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu122torch2.2.2cxx11abiFALSE-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu122torch2.2.2cxx11abiFALSE-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" -https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/AutoAWQ/releases/download/0.2.6/autoawq-0.2.6-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/AutoAWQ/releases/download/0.2.6/autoawq-0.2.6-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" -https://github.com/oobabooga/AutoAWQ/releases/download/0.2.6/autoawq-0.2.6-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/AutoAWQ/releases/download/0.2.6/autoawq-0.2.6-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/oobabooga/AutoAWQ_kernels/releases/download/0.0.7/autoawq_kernels-0.0.7-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/AutoAWQ_kernels/releases/download/0.0.7/autoawq_kernels-0.0.7-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" -https://github.com/oobabooga/AutoAWQ_kernels/releases/download/0.0.7/autoawq_kernels-0.0.7-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/AutoAWQ_kernels/releases/download/0.0.7/autoawq_kernels-0.0.7-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu121.torch2.4.1-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu121.torch2.4.1-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu121.torch2.4.1-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3+cu121.torch2.4.1-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.3/exllamav2-0.2.3-py3-none-any.whl; platform_system == "Linux" and platform_machine != "x86_64" +https://github.com/oobabooga/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu122torch2.4.1cxx11abiFALSE-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu122torch2.4.1cxx11abiFALSE-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" diff --git a/requirements_nowheels.txt b/requirements_nowheels.txt index ec18464d0f..5c840f56de 100644 --- a/requirements_nowheels.txt +++ b/requirements_nowheels.txt @@ -4,7 +4,6 @@ datasets einops fastapi==0.112.4 gradio==4.26.* -hqq==0.1.7.post3 jinja2==3.1.4 lm_eval==0.3.0 markdown @@ -23,7 +22,7 @@ safetensors==0.4.* scipy sentencepiece tensorboard -transformers==4.44.* +transformers==4.45.* tqdm wandb