[Core] Support Lora lineage and base model metadata management #6315

Merged 7 commits on Sep 20, 2024
64 changes: 64 additions & 0 deletions docs/source/models/lora.rst
@@ -159,3 +159,67 @@ Example request to unload a LoRA adapter:
    -d '{
        "lora_name": "sql_adapter"
    }'


New format for `--lora-modules`
-------------------------------

In the previous version, users provided LoRA modules in a key-value format (`name=path`). For example:

.. code-block:: bash

    --lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/

This format only included the `name` and `path` for each LoRA module and provided no way to specify a `base_model_name`.
Now, you can specify a `base_model_name` alongside the `name` and `path` using JSON format. For example:

.. code-block:: bash

    --lora-modules '{"name": "sql-lora", "path": "/path/to/lora", "base_model_name": "meta-llama/Llama-2-7b"}'

For backward compatibility, you can still use the old key-value format (`name=path`), but the `base_model_name` will remain unspecified in that case.

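If the JSON is awkward to quote by hand, the argument can also be built programmatically. Below is a minimal sketch; the adapter name, path, and base model are placeholders, not values shipped with vLLM:

.. code-block:: python

    import json
    import shlex

    # Placeholder description of one LoRA adapter. base_model_name is optional
    # and is only used for lineage metadata in the /v1/models response.
    lora_module = {
        "name": "sql-lora",
        "path": "/path/to/lora",
        "base_model_name": "meta-llama/Llama-2-7b",
    }

    # Emit a shell-safe token that can be passed after --lora-modules.
    print(shlex.quote(json.dumps(lora_module)))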

LoRA model lineage in model card
--------------------------------

The new format of `--lora-modules` is mainly to support the display of parent model information in the model card. Here's an explanation of how the `/v1/models` response below supports this:

- The `parent` field of LoRA model `sql-lora` now links to its base model `meta-llama/Llama-2-7b-hf`. This correctly reflects the hierarchical relationship between the base model and the LoRA adapter.
- The `root` field points to the artifact location of the LoRA adapter.

.. code-block:: bash

    $ curl http://localhost:8000/v1/models

    {
        "object": "list",
        "data": [
            {
                "id": "meta-llama/Llama-2-7b-hf",
                "object": "model",
                "created": 1715644056,
                "owned_by": "vllm",
                "root": "~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9/",
                "parent": null,
                "permission": [
                    {
                        .....
                    }
                ]
            },
            {
                "id": "sql-lora",
                "object": "model",
                "created": 1715644056,
                "owned_by": "vllm",
                "root": "~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/",
                "parent": "meta-llama/Llama-2-7b-hf",
                "permission": [
                    {
                        ....
                    }
                ]
            }
        ]
    }
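
Once the server is running, these lineage fields can also be inspected with the official OpenAI client. A minimal sketch, assuming the server above is reachable on `localhost:8000` (`root` and `parent` are vLLM-specific extra fields on each model entry):

.. code-block:: python

    import openai

    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    for model in client.models.list().data:
        # `root` is the artifact location; `parent` names the base model
        # (None for the base model itself).
        print(model.id, model.root, model.parent)
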
91 changes: 91 additions & 0 deletions tests/entrypoints/openai/test_cli_args.py
@@ -0,0 +1,91 @@
import json
import unittest

from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.utils import FlexibleArgumentParser

LORA_MODULE = {
    "name": "module2",
    "path": "/path/to/module2",
    "base_model_name": "llama"
}


class TestLoraParserAction(unittest.TestCase):

    def setUp(self):
        # Setting up argparse parser for tests
        parser = FlexibleArgumentParser(
            description="vLLM's remote OpenAI server.")
        self.parser = make_arg_parser(parser)

    def test_valid_key_value_format(self):
        # Test old format: name=path
        args = self.parser.parse_args([
            '--lora-modules',
            'module1=/path/to/module1',
        ])
        expected = [LoRAModulePath(name='module1', path='/path/to/module1')]
        self.assertEqual(args.lora_modules, expected)

    def test_valid_json_format(self):
        # Test valid JSON format input
        args = self.parser.parse_args([
            '--lora-modules',
            json.dumps(LORA_MODULE),
        ])
        expected = [
            LoRAModulePath(name='module2',
                           path='/path/to/module2',
                           base_model_name='llama')
        ]
        self.assertEqual(args.lora_modules, expected)

    def test_invalid_json_format(self):
        # Test invalid JSON format input, missing closing brace
        with self.assertRaises(SystemExit):
            self.parser.parse_args([
                '--lora-modules',
                '{"name": "module3", "path": "/path/to/module3"'
            ])

    def test_invalid_type_error(self):
        # Test type error when values are not JSON or key=value
        with self.assertRaises(SystemExit):
            self.parser.parse_args([
                '--lora-modules',
                'invalid_format'  # This is not JSON or key=value format
            ])

    def test_invalid_json_field(self):
        # Test valid JSON format but missing required fields
        with self.assertRaises(SystemExit):
            self.parser.parse_args([
                '--lora-modules',
                '{"name": "module4"}'  # Missing required 'path' field
            ])

    def test_empty_values(self):
        # Test when no LoRA modules are provided
        args = self.parser.parse_args(['--lora-modules', ''])
        self.assertEqual(args.lora_modules, [])

    def test_multiple_valid_inputs(self):
        # Test multiple valid inputs (both old and JSON format)
        args = self.parser.parse_args([
            '--lora-modules',
            'module1=/path/to/module1',
            json.dumps(LORA_MODULE),
        ])
        expected = [
            LoRAModulePath(name='module1', path='/path/to/module1'),
            LoRAModulePath(name='module2',
                           path='/path/to/module2',
                           base_model_name='llama')
        ]
        self.assertEqual(args.lora_modules, expected)


if __name__ == '__main__':
    unittest.main()
83 changes: 83 additions & 0 deletions tests/entrypoints/openai/test_lora_lineage.py
@@ -0,0 +1,83 @@
import json

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
# downloading lora to test lora requests
from huggingface_hub import snapshot_download

from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"


@pytest.fixture(scope="module")
def zephyr_lora_files():
    return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="module")
def server_with_lora_modules_json(zephyr_lora_files):
    # Define the json format LoRA module configurations
    lora_module_1 = {
        "name": "zephyr-lora",
        "path": zephyr_lora_files,
        "base_model_name": MODEL_NAME
    }

    lora_module_2 = {
        "name": "zephyr-lora2",
        "path": zephyr_lora_files,
        "base_model_name": MODEL_NAME
    }

    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        # lora config below
        "--enable-lora",
        "--lora-modules",
        json.dumps(lora_module_1),
        json.dumps(lora_module_2),
        "--max-lora-rank",
        "64",
        "--max-cpu-loras",
        "2",
        "--max-num-seqs",
        "64",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client_for_lora_lineage(server_with_lora_modules_json):
    async with server_with_lora_modules_json.get_async_client(
    ) as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
                                  zephyr_lora_files):
    models = await client_for_lora_lineage.models.list()
    models = models.data
    served_model = models[0]
    lora_models = models[1:]
    assert served_model.id == MODEL_NAME
    assert served_model.root == MODEL_NAME
    assert served_model.parent is None
    assert all(lora_model.root == zephyr_lora_files
               for lora_model in lora_models)
    assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
    assert lora_models[0].id == "zephyr-lora"
    assert lora_models[1].id == "zephyr-lora2"
6 changes: 4 additions & 2 deletions tests/entrypoints/openai/test_models.py
@@ -51,12 +51,14 @@ async def client(server):
 
 
 @pytest.mark.asyncio
-async def test_check_models(client: openai.AsyncOpenAI):
+async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files):
     models = await client.models.list()
     models = models.data
     served_model = models[0]
     lora_models = models[1:]
     assert served_model.id == MODEL_NAME
-    assert all(model.root == MODEL_NAME for model in models)
+    assert served_model.root == MODEL_NAME
+    assert all(lora_model.root == zephyr_lora_files
+               for lora_model in lora_models)
     assert lora_models[0].id == "zephyr-lora"
     assert lora_models[1].id == "zephyr-lora2"
6 changes: 4 additions & 2 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -7,10 +7,12 @@
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 MODEL_NAME = "openai-community/gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
+BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 
 
 @dataclass
@@ -37,7 +39,7 @@ async def _async_serving_chat_init():
 
     serving_completion = OpenAIServingChat(engine,
                                            model_config,
-                                           served_model_names=[MODEL_NAME],
+                                           BASE_MODEL_PATHS,
                                            response_role="assistant",
                                            chat_template=CHAT_TEMPLATE,
                                            lora_modules=None,
@@ -58,7 +60,7 @@ def test_serving_chat_should_set_correct_max_tokens():
 
     serving_chat = OpenAIServingChat(mock_engine,
                                      MockModelConfig(),
-                                     served_model_names=[MODEL_NAME],
+                                     BASE_MODEL_PATHS,
                                      response_role="assistant",
                                      chat_template=CHAT_TEMPLATE,
                                      lora_modules=None,
5 changes: 3 additions & 2 deletions tests/entrypoints/openai/test_serving_engine.py
@@ -8,9 +8,10 @@
 from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               LoadLoraAdapterRequest,
                                               UnloadLoraAdapterRequest)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
 
 MODEL_NAME = "meta-llama/Llama-2-7b"
+BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 LORA_LOADING_SUCCESS_MESSAGE = (
     "Success: LoRA adapter '{lora_name}' added successfully.")
 LORA_UNLOADING_SUCCESS_MESSAGE = (
@@ -25,7 +26,7 @@ async def _async_serving_engine_init():
 
     serving_engine = OpenAIServing(mock_engine_client,
                                    mock_model_config,
-                                   served_model_names=[MODEL_NAME],
+                                   BASE_MODEL_PATHS,
                                    lora_modules=None,
                                    prompt_adapters=None,
                                    request_logger=None)
14 changes: 10 additions & 4 deletions vllm/entrypoints/openai/api_server.py
@@ -50,6 +50,7 @@
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.logger import init_logger
@@ -476,13 +477,18 @@ def init_app_state(
     else:
         request_logger = RequestLogger(max_log_len=args.max_log_len)
 
+    base_model_paths = [
+        BaseModelPath(name=name, model_path=args.model)
+        for name in served_model_names
+    ]
+
     state.engine_client = engine_client
     state.log_stats = not args.disable_log_stats
 
     state.openai_serving_chat = OpenAIServingChat(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         args.response_role,
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
@@ -494,7 +500,7 @@
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
@@ -503,13 +509,13 @@
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         request_logger=request_logger,
     )
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         lora_modules=args.lora_modules,
         request_logger=request_logger,
         chat_template=args.chat_template,