Это демо приложение на базе Gradio, демонстрирующее возможности горячей замены PEFT-адаптеров, а именно LoRA, над одной и той же LLM прямо в Runtime. Выполнено в рамках тестового задания на позицию ML Engineer.
Создать окружение Python 3.10.10 с помощью Conda или Pyenv:
conda create -n myenv python=3.10.10 && conda activate myenv
Установить необходимые пакеты:
pip install -r requirements.txt
И запустить:
python -m app
Собрать образ:
docker build -t llm-lora-hotswap .
И запустить приложение:
docker run --gpus all --name hotswap-app --net host --rm -it llm-lora-hotswap
После запуска приложение должно быть доступно на 7860 порту localhost (порт Gradio по умолчанию).
В данном проекте я набросал черновую структуру классов, которые могли бы быть прототипом для решения в продакшене. Поскольку у каждого адаптера могут быть свои нюансы токенизации, пре/пост-обработки сообщений, они инкапсулируются в классах-наследниках LLMAdapterBase
.
Для самих ответов LLM я реализовал поддержку стриминга токенов, чтобы можно было отдать первый токен как можно быстрее на целевой интерфейс (в данном случае — в UI чата).
- Llama2 7B GPTQ — LLM, квантизованная с помощью метода GPTQ до 4bit. Выбрал её вместо квантизации через
bitsandbytes
, поскольку по тестам GPTQ даёт выше качество итоговой модели. На моей локальной машине установлена RTX 3070 на 8Gb VRAM, поэтому нужна была хотя бы 4bit версия - Saiga2 LoRA — адаптер поверх Llama 2, дообученный на инструктивно-диалоговом датасете Сайга
- Llama 2 LoRA OpenAssistant Guanaco (блогпост) — адаптер, дообученный на очищенной части датасета OpenAssistant (OASTT)
Цель данного демо — показать с помощью простого кода реализацию горячей замены и предоставить интерактивный интерфейс для демонстрации работы. Если бы потребовалось реализовывать подобный функционал в виде REST API, я бы посмотрел такие решения как OpenLLM. Согласно вот этому обзору фреймворков для инференса и текущей документации, OpenLLM — единственный, который поддерживает адаптеры и их подмену в Runtime.
Однако OpenLLM сам по себе, и в особенности с адаптерами, будет давать низкий RPS и высокий Latency. Дело в том, что при сёрвинге в продакшене кучи адаптеров страдает батчинг запросов — следует задуматься над более эффективной утилизацией GPU и формированием батчей. Я нашёл пару многообещающих решений для этой проблемы:
Другой открытый вопрос — допустимо ли использовать адаптер, обученный поверх модели в половинной точности, с моделью квантизованной до 4bit. Допускаю, что может присутствовать деградация в качестве вывода такого адаптера. В продакшен разработке следовало бы проверить качество этой связки на downstream задачах.
Кроме того, при обучении своего адаптера под такой юзкейс я бы сразу смотрел в сторону quantization-aware методов:
- LoftQ: LoRA-Fine-Tuning-Aware Quantization for Large Language Models
- QA-LoRA: Quantization-Aware Low-Rank Adaptation of Large Language Models
- LQ-LoRA: Low-rank plus Quantized Matrix Decomposition for Efficient Language Model Finetuning
Основную массу времени я потратил на поиск базовой модели и подходящих адаптеров. Согласно заданию, я нацелился на использование Saiga2 от Ильи Гусева, однако она в свою очередь является файнтюном над базовой Llama2, а не Chat/Instruct версией — что, как оказалось, редкость, если хочется использовать модель в чатботе. Большинство LoRA-адаптеров для чата файнтюнятся именно от Chat/Instruct-модели.
Также оказалась не совсем прозрачной логика методов add_adapter()
, set_adapter()
, load_adapter()
из библиотеки PEFT. Так, к примеру, добавление адаптера с помощью конфига и метода add_adapter()
не инициализирует сами веса адаптера и, судя по всему, нацелено именно на юзкейс файнтюнинга модели.
Для инференса же необходимо вызывать именно load_adapter()
с указанием идентификатора модели-адаптера с хаба (или локальной папки). Чтобы разобраться с этим, пришлось посмотреть код соответвующих методов, поскольку документация на момент написания очень расплывчатая.
Было забавно наблюдать, как подключенный адаптер Сайги не работает и базовая модель при виде русских символов в промпте выдаёт код, причём на C/C++ и под платформу Windows...
Кроме того, пришлось дописать некоторую логику по пост-процессингу вывода модели OASTT и попотеть над подбором гиперпараметров для генерации. Так модель очевидно плохо уловила специфику чата и старается продолжать реплики за человека. Поэтому я отлавливаю токены, соответствующие началу реплики ### Human:
и останавливаю генерацию на них.