ExtractAdapters: Extract lora adapters and use them as model inputs or external initializers #1064

Merged · 10 commits · Apr 10, 2024
Changes from 8 commits
6 changes: 6 additions & 0 deletions docs/source/api/passes.rst
@@ -114,6 +114,12 @@ InsertBeamSearch
--------------------
.. autoconfigclass:: olive.passes.InsertBeamSearch

.. _extract_adapters:

ExtractAdapters
----------------
.. autoconfigclass:: olive.passes.ExtractAdapters

.. _lora:

LoRA
31 changes: 31 additions & 0 deletions docs/source/features/passes/onnx.md
@@ -462,3 +462,34 @@ b. Making the entire input shape fixed
Note: The `input_dim` and `dim_value` should have the same length, and the `input_name` and `input_shape` should have the same length. Also, `input_dim`/`dim_value` and `input_name`/`input_shape` are mutually exclusive; the user cannot specify both pairs at the same time.

More details about the pass and its config parameters can be found [here](https://onnxruntime.ai/docs/tutorials/mobile/helpers/make-dynamic-shape-fixed.html).

## Extract Adapters

LoRA, QLoRA and related techniques allow us to fine-tune a pre-trained model by adding a small number of trainable matrices called adapters. The same base model can be used for multiple tasks by adding a different adapter for each task. To support using multiple adapters with the same optimized ONNX model, the `ExtractAdapters` pass extracts the adapter weights from the model and saves them to a separate file. The model graph is then modified in one of the following ways:
- Adapter weights are set as external tensors pointing to a non-existent file. The ONNX model is thus invalid by itself and cannot be loaded on its own. To create an inference session with this model, the adapter weights must be added to a session options object using `add_initializer` or `add_external_initializers`.
- Adapter weights are converted into model inputs. The ONNX model is valid on its own. During inference, the adapter weights must be provided as part of the inputs. We call them constant inputs here since these weights don't change between runs when using one set of adapters.

### Example Configuration

a. As external initializers
```json
{
"type": "ExtractAdapters",
"config": {
"make_inputs": false
}
}
```
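
A minimal sketch (not part of this PR) of consuming a model produced with this option: the weights saved by the pass are loaded and registered on the session options before the session is created. The file names below are illustrative.

```python
import numpy as np
import onnxruntime as ort

# illustrative paths; point these at the outputs of the ExtractAdapters pass
model_path = "model.onnx"
weights_path = "adapter_weights.npz"

# load the extracted adapter weights and wrap them as OrtValues
adapter_weights = dict(np.load(weights_path))
names = list(adapter_weights.keys())
values = [ort.OrtValue.ortvalue_from_numpy(v) for v in adapter_weights.values()]

# register the weights as external initializers so the model, which references
# them as external tensors, can be loaded
sess_options = ort.SessionOptions()
sess_options.add_external_initializers(names, values)

session = ort.InferenceSession(model_path, sess_options, providers=["CPUExecutionProvider"])
```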

b. As constant inputs with packed weights
```json
{
"type": "ExtractAdapters",
"config": {
"make_inputs": true,
"pack_inputs": true
}
}
```
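
And a minimal sketch of running a model produced with this option: the adapter weights are ordinary graph inputs, so they are merged into the input feed on every run. The paths are illustrative, and it is assumed the names in the saved file match the corresponding model inputs.

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# adapter weights saved by the pass; these become "constant inputs" that are
# fed unchanged on every call
adapter_weights = dict(np.load("adapter_weights.npz"))

def run(input_feed):
    # merge the regular inputs with the constant adapter inputs
    return session.run(None, {**input_feed, **adapter_weights})
```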

Please refer to [ExtractAdapters](extract_adapters) for more details about the pass and its config parameters.
1 change: 1 addition & 0 deletions docs/source/overview/options.md
@@ -417,6 +417,7 @@ Please also find the detailed options in the following table for each pass:
| [IncStaticQuantization](inc_static_quantization) | Intel® Neural Compressor Static Quantization Pass. |
| [IncQuantization](inc_quantization) | Quantize ONNX model with Intel® Neural Compressor where we can search for the best parameters for static/dynamic quantization at the same time. |
| [DynamicToFixedShape](dynamic_to_fixed_shape) | Convert dynamic shape to fixed shape for ONNX model |
| [ExtractAdapters](extract_adapters) | Extract adapters from ONNX model |
| [QuantizationAwareTraining](onnx_quantization_aware_training) | Run quantization aware training on PyTorch model. |
| [OpenVINOConversion](openvino_conversion) | Converts PyTorch, ONNX or TensorFlow Model to OpenVino Model. |
| [OpenVINOQuantization](openvino_quantization) | Post-training quantization for OpenVINO model. |
66 changes: 65 additions & 1 deletion examples/llama2/README.md
@@ -36,7 +36,8 @@ For Llama2 inference with DirectML on GPUs, please refer to this [example](https://
### Fine-tune on a code generation dataset using QLoRA and optimize using ONNX Runtime Tools
This workflow fine-tunes Open LLaMA model using [QLoRA](https://arxiv.org/abs/2305.14314) to generate code given a prompt. The fine-tuned model is then optimized using ONNX Runtime Tools.
It performs the following optimization pipeline:
- GPU, NF4: *Pytorch Model -> Fine-tuned Pytorch Model -> Onnx Model -> Transformers Optimized Onnx Model fp16 -> Onnx Bitsandbytes 4bit Quantization*
- GPU, FP16: *Pytorch Model -> Fine-tuned Pytorch Model -> Onnx Model -> Transformers Optimized Onnx Model fp16 -> Extract Adapter*
<!-- TODO(jambayk): check if bnb quantization works between different adapters -->

**Note:**
- This workflow is only supported for GPU.
@@ -46,13 +47,76 @@ Supported languages are Python, TypeScript, JavaScript, Ruby, Julia, Rust, C++,

Requirements file: [requirements-qlora.txt](requirements-qlora.txt)

**Extracted Adapters**

The workflow above extracts the LoRA adapters from the fine-tuned model and converts them into inputs for the model. This way, different adapters can be used with the same base model for different tasks.
Pre-existing adapters can be exported directly using the following command:
```bash
# change the adapter_path to the path of the adapter you want to export
# ensure that the target modules are the same as those of the above fine-tuned model
python -m olive.scripts.export_adapters --adapter_path Mikael110/llama-2-7b-guanaco-qlora --dtype float16 --pack_weights --output_path models/guanaco_fp16_packed.npz
```

The snippet below shows example runs of the generated fine-tuned model using two different adapters.
**Reviewer comment (Contributor):** As a follow-up PR, we want to move this snippet into a working Jupyter notebook.

```python
import numpy as np
# optimum needs to be installed from git+https://github.com/jambayk/optimum.git@jambayk/constant-inputs
from onnxruntime import InferenceSession
from optimum.onnxruntime import ORTModelForCausalLM
import torch
from transformers import AutoConfig, AutoTokenizer

device = torch.device("cuda")

base_model = "meta-llama/Llama-2-7b-hf"
config = AutoConfig.from_pretrained(base_model)
tokenizer = AutoTokenizer.from_pretrained(base_model)

# the path to the optimized model
model_path = "models/qlora/qlora-conversion-transformers_optimization-extract/gpu-cuda_model/model.onnx"
tiny_codes_adapter_path = "models/qlora/qlora-conversion-transformers_optimization-extract/gpu-cuda_model/adapter_weights.npz"
guanaco_adapter_path = "models/guanaco_fp16_packed.npz"

# load the adapters and put them on the device
tiny_codes_weights = np.load(tiny_codes_adapter_path)
tiny_codes_weights = {k: torch.tensor(v).to(device) for k, v in tiny_codes_weights.items()}
guanaco_weights = np.load(guanaco_adapter_path)
guanaco_weights = {k: torch.tensor(v).to(device) for k, v in guanaco_weights.items()}

# load the model
# io-binding is recommended for optimal performance; the adapter weights are already on the device
# and don't change during the generation loop (they are called constant_inputs here)
session = InferenceSession(model_path, providers=["CUDAExecutionProvider"])
model = ORTModelForCausalLM(session, config=config, preprocessors=[tokenizer], use_cache=True, use_io_binding=True)

# prompt
prompt = "What time is it?"

# generate using tiny_codes adapters
model.constant_inputs = tiny_codes_weights
formatted_prompt = f"### Question: {prompt} \n### Answer:"
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
print("Tiny Codes Adapters:")
outputs = model.generate(inputs=inputs.input_ids, max_new_tokens=150)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# generate using guanaco adapters
model.constant_inputs = guanaco_weights
formatted_prompt = f"### Human: {prompt} ### Assistant:"
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
print("Guanaco Adapters:")
outputs = model.generate(inputs=inputs.input_ids, max_new_tokens=150)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

### Inference optimization using ONNX Runtime GenAI
To use ONNX Runtime GenAI for optimization, follow the build and installation instructions [here](https://github.com/microsoft/onnxruntime-genai).

Run the following command to execute the workflow:
```bash
python -m olive.workflows.run --config lamma2_genai.json
```

The snippet below shows an example run of the generated Llama2 model.
```python
import onnxruntime_genai as og
12 changes: 4 additions & 8 deletions examples/llama2/llama2_qlora.json
@@ -125,17 +125,13 @@
"opt_level": 0,
"only_onnxruntime": false,
"keep_io_types": false,
"float16": true,
"optimization_options": {
"enable_rotary_embeddings": false
}
"float16": true
}
},
"bnb_quantization": {
"type": "OnnxBnb4Quantization",
"extract": {
"type": "ExtractAdapters",
"config": {
"save_as_external_data": true,
"all_tensors_to_one_file": true
"make_inputs": true
}
}
},
1 change: 1 addition & 0 deletions examples/llama2/requirements-qlora.txt
@@ -1,6 +1,7 @@
-r requirements.txt
accelerate
bitsandbytes
git+https://github.com/jambayk/optimum.git@jambayk/constant-inputs
peft
scikit-learn
sentencepiece
53 changes: 47 additions & 6 deletions olive/common/ort_inference.py
@@ -14,6 +14,7 @@
import numpy as np

if TYPE_CHECKING:
from numpy.typing import NDArray
from onnxruntime import InferenceSession, IOBinding


@@ -32,6 +33,7 @@ def get_ort_inference_session(
inference_settings: Dict[str, Any],
use_ort_extensions: bool = False,
device_id: Optional[int] = None,
external_initializers: Optional[Dict[str, "NDArray"]] = None,
):
"""Get an ONNXRuntime inference session.

Expand All @@ -43,6 +45,8 @@ def get_ort_inference_session(
provider_options: list, optional. List of provider options for the execution providers.
:param use_ort_extensions: Whether to use onnxruntime-extensions. Default is False.
:param device_id: Optional device id to use for CUDA or DML execution providers.
:param external_initializers: Optional external initializers for the session. A dictionary of external initializer
names and numpy arrays.
"""
import onnxruntime as ort

@@ -52,6 +56,18 @@
from onnxruntime_extensions import get_library_path

sess_options.register_custom_ops_library(get_library_path())
if external_initializers:
from onnxruntime import OrtValue

# convert external initializers to OrtValue
initializer_names = []
initializer_values = []
for name, value in external_initializers.items():
initializer_names.append(name)
initializer_values.append(OrtValue.ortvalue_from_numpy(value))

# add external initializers to the session
sess_options.add_external_initializers(initializer_names, initializer_values)

logger.debug("inference_settings: %s", inference_settings)

@@ -216,14 +232,32 @@ def __init__(
device: str = "cpu",
shared_kv_buffer: bool = False,
use_fp16: bool = False,
input_feed: Optional[Dict[str, np.ndarray]] = None,
input_feed: Optional[Dict[str, "NDArray"]] = None,
constant_inputs: Optional[Dict[str, "NDArray"]] = None,
):
"""Initialize self.

:param session: ONNXRuntime InferenceSession
:param io_bind: Whether to use IO binding. Default is False.
:param device: Device to run inference on. Default is "cpu".
:param shared_kv_buffer: Whether to share the key/value buffer across multiple runs.
Default is False. Only valid if io_bind is True.
:param use_fp16: Whether to use fp16. Default is False. Both shared_kv_buffer and use_fp16 must be True
at the same time to use shared key/value buffer.
:param input_feed: Optional input feed for the session. Required when shared_kv_buffer and use_fp16 are True.
:param constant_inputs: Optional constant inputs for the session. These will be passed to the session every
inference run.
"""
# TODO(anyone): use_fp16 is redundant with shared_kv_buffer. Remove it.
self.session = session
self.io_bind = io_bind
self.device = device
self.shared_kv_buffer = shared_kv_buffer
self.use_fp16 = use_fp16
self.kv_cache_ortvalues = {} if (self.shared_kv_buffer and self.use_fp16) else None
# TODO(jambayk): investigate if io binding can be run without having to bind constant
# inputs every time.
self.constant_inputs = constant_inputs or {}

self.io_binding = None
if self.io_bind:
@@ -232,7 +266,7 @@ def __init__(
assert input_feed is not None, "input_feed is required when shared_kv_buffer and use_fp16 are True"
bind_input_data(
self.io_binding,
input_feed,
{**input_feed, **self.constant_inputs},
self.use_fp16,
self.device,
shared_kv_buffer=self.shared_kv_buffer,
@@ -247,8 +281,13 @@ def __init__(
kv_cache_ortvalues=self.kv_cache_ortvalues,
)

def run(self, input_feed: Dict[str, np.ndarray]) -> Sequence[np.ndarray]:
def get_full_input_feed(self, input_feed: Dict[str, "NDArray"]) -> Dict[str, "NDArray"]:
"""Get the full input feed including constant inputs."""
return {**input_feed, **self.constant_inputs}

def run(self, input_feed: Dict[str, "NDArray"]) -> Sequence["NDArray"]:
"""Run inference with the given input data."""
input_feed = self.get_full_input_feed(input_feed)
if self.io_bind and self.device == "gpu":
bind_input_data(
self.io_binding,
@@ -268,9 +307,10 @@ def run(self, input_feed: Dict[str, np.ndarray]) -> Sequence[np.ndarray]:
return res

def time_run(
self, input_feed: Dict[str, np.ndarray], num_runs: int, num_warmup: int = 0, sleep_time: int = 0
self, input_feed: Dict[str, "NDArray"], num_runs: int, num_warmup: int = 0, sleep_time: int = 0
) -> Sequence[float]:
"""Time inference runs with the given input data."""
input_feed = self.get_full_input_feed(input_feed)
latencies = []
if self.io_bind:
bind_input_data(
@@ -302,7 +342,7 @@ def time_run(

def bind_input_data(
io_bind_op: "IOBinding",
input_data: Dict[str, np.ndarray],
input_data: Dict[str, "NDArray"],
use_fp16: bool,
device: str,
device_id: int = 0,
@@ -351,7 +391,7 @@ def bind_output_data(

def prepare_io_bindings(
session: "InferenceSession",
input_data: Dict[str, np.ndarray],
input_data: Dict[str, "NDArray"],
device: str,
device_id: int = 0,
shared_kv_buffer: bool = False,
@@ -366,6 +406,7 @@ def prepare_io_bindings(
shared_kv_buffer: whether to share the key/value buffer across multiple runs, it is False by default,
and only used when we observe kv cache and fp16 is used.
TODO(trajep): how shared_kv_buffer works with generation task
kv_cache_ortvalues: dict of OrtValue for shared kv cache, it is None by default.
"""
use_fp16 = any(v.dtype == np.float16 for v in input_data.values())
io_bind_op = session.io_binding()
6 changes: 6 additions & 0 deletions olive/evaluator/olive_evaluator.py
@@ -447,6 +447,11 @@ def get_session_wrapper(
if io_bind and shared_kv_buffer and use_fp16:
input_feed = OnnxEvaluator.format_input(next(iter(dataloader))[0], io_config)

# load constant inputs if any
constant_inputs = None
if model.constant_inputs_path:
constant_inputs = OnnxEvaluator.format_input(dict(np.load(model.constant_inputs_path)), io_config)

# create session wrapper
session_wrapper = OrtInferenceSession(
session,
@@ -455,6 +460,7 @@
shared_kv_buffer=shared_kv_buffer,
use_fp16=use_fp16,
input_feed=input_feed,
constant_inputs=constant_inputs,
)

return session_wrapper, inference_settings