forked from microsoft/Olive
Add OnnxIOFloat16ToFloat32 Pass (microsoft#1149)
## Describe your changes
Add OnnxIOFloat16ToFloat32 Pass

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR.

## (Optional) Issue link

---------

Co-authored-by: Devang Patel <[email protected]>
Co-authored-by: Emma <[email protected]>
1 parent: 8b87910
Commit: 187feef
Showing 12 changed files with 508 additions and 5 deletions.
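For context, the snippet below is a minimal, hypothetical sketch of how the new pass is attached to an Olive configuration; it is not part of this commit's diff. The `fp32_logits` key, the template file, and the output file name follow the conventions of the phi3 example added in this commit.

```python
# Minimal sketch (illustrative only): wire the new OnnxIOFloat16ToFloat32 pass into an
# Olive config so that fp16 model inputs/outputs are converted to fp32.
import json

with open("phi3_template.json") as f:  # template shipped with the phi3 example below
    config = json.load(f)

# the pass is registered under "passes" like any other Olive pass
config["passes"]["fp32_logits"] = {"type": "OnnxIOFloat16ToFloat32"}

with open("phi3_web_int4.json", "w") as f:  # file name follows the example's convention
    json.dump(config, f, indent=4)
```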
@@ -0,0 +1,48 @@
# Phi3 optimization with Olive
This folder contains an example of optimizing [the Phi-3-Mini-4K-Instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) model from Hugging Face for different hardware targets with Olive.

## Prerequisites
* einops
* PyTorch: >=2.2.0 \
  _The [official website](https://pytorch.org/) offers packages compatible with CUDA 11.8 and 12.1. Please select the appropriate version according to your needs._
* [Package onnxruntime](https://onnxruntime.ai/docs/install/#inference-install-table-for-all-languages): >=1.18.0
* [Package onnxruntime-genai](https://github.com/microsoft/onnxruntime-genai): >=0.2.0. If you target GPU, please install the onnxruntime and onnxruntime-genai GPU packages.
Install the dependencies:
```
pip install -r requirements.txt
```
## Usage
We use the `phi3.py` script to generate an optimized model for a chosen hardware target by running the following commands.

```
python phi3.py [--target HARDWARE_TARGET] [--precision DATA_TYPE] [--inference] [--prompt PROMPT] [--max_length LENGTH]
# Examples
python phi3.py --target web
python phi3.py --target mobile --inference --prompt "Write a story starting with once upon a time" --max_length 200
```
- `--target`: cpu, cuda, mobile, web
- `--precision`: optional. fp32 or int4 (default) for the cpu target; fp32, fp16, or int4 (default) for the gpu target; int4 (default) for mobile or web.
- `--inference`: optional. Run inference with the optimized model; not supported for web models.
- `--prompt`: optional. The prompt text fed into the model. Takes effect only when `--inference` is set.
- `--max_length`: optional. The maximum length of the output from the model. Takes effect only when `--inference` is set.
This script:
1. Generates the Olive configuration file for your needs, including the chosen hardware target and the preferred model precision.
2. Generates the optimized model with Olive, based on the configuration file, for the chosen hardware target.
3. (Optional) Runs inference on the optimized model with the ONNX Runtime Generation API (not supported for the web target); a condensed sketch of this flow follows the list.
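The sketch below is illustrative only: it reuses the APIs that `phi3.py` calls, but the configuration file name and the output model path are placeholders, not values produced by this example.

```python
# Condensed, hypothetical sketch of the three steps above.
import onnxruntime_genai as og
from olive.workflows import run as olive_run

config_file = "phi3_cpu_int4.json"   # 1. config generated from phi3_template.json
footprints = olive_run(config_file)  # 2. run the Olive workflow for that config

# 3. (optional) run the optimized model with onnxruntime-genai
model = og.Model("Opt_model/output_model")  # placeholder path to the optimized model
tokenizer = og.Tokenizer(model)
params = og.GeneratorParams(model)
params.set_search_options(max_length=200)
params.input_ids = tokenizer.encode("<|user|>\nWrite a joke<|end|>\n<|assistant|>")
generator = og.Generator(model, params)
stream = tokenizer.create_stream()
while not generator.is_done():
    generator.compute_logits()
    generator.generate_next_token()
    print(stream.decode(generator.get_next_tokens()[0]), end="", flush=True)
```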
If you have an Olive configuration file, you can also run the olive command directly for model generation:
```
olive run [--config CONFIGURATION_FILE]
# Examples
olive run --config phi3_mobile_int4.json
```
@@ -0,0 +1,212 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import argparse
import json
import time
from pathlib import Path

import onnxruntime_genai as og

from olive.workflows import run as olive_run

# flake8: noqa: T201


TARGETS = ["cpu", "cuda", "mobile", "web"]

TARGET_TO_EP = {
    "cpu": "CPUExecutionProvider",
    "mobile": "CPUExecutionProvider",
    "cuda": "CUDAExecutionProvider",
    "web": "JsExecutionProvider",
}


def get_args(raw_args):
    parser = argparse.ArgumentParser(description="phi3 optimization")

    parser.add_argument(
        "--target",
        type=str,
        default=None,
        required=True,
        choices=TARGETS,
        help="Choose from cpu, cuda, mobile or web",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="int4",
        choices=["fp32", "fp16", "int4"],
        help="Choose from fp32 or int4(default) for cpu target; "
        "fp32 or fp16 or int4(default) for gpu target; int4(default) for mobile or web",
    )
    parser.add_argument(
        "--inference",
        action="store_true",
        help="Run inference with optimized model",
    )
    parser.add_argument(
        "--prompt",
        nargs="*",
        type=str,
        default=["Write a joke"],
        help="The prompt text fed into the model. Not supported with Web target.",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=200,
        help="Max length for generation. Not supported with Web target.",
    )

    return parser.parse_args(raw_args)

def main(raw_args=None):
    args = get_args(raw_args)
    if args.target in ("mobile", "web") and args.precision != "int4":
        raise ValueError("mobile or web only supports int4(default)")
    elif args.target == "cpu" and args.precision == "fp16":
        raise ValueError("Choose from fp32 or int4(default) for cpu target")

    if args.inference and args.target == "web":
        raise ValueError("Web model inference is not supported in this script")

    # Generate Olive configuration file for specific target
    print("\nGenerating Olive configuration file...")
    config_file = generate_config(args)
    print("Olive configuration file is generated...\n")

    # Generate optimized model for specific target
    print("Generating optimized model for", args.target, " ...\n")
    footprints = olive_run(config_file)
    if footprints:
        print("\nOptimized model is generated...")

    if args.inference:
        prompts = "Write a joke" if not args.prompt else "".join(args.prompt)

        chat_template = "<|user|>\n{input}<|end|>\n<|assistant|>"
        prompts = f"{chat_template.format(input=prompts)}"

        max_length = 200 if not args.max_length else args.max_length

        output_model_path = get_output_model_path(footprints)
        genai_run(prompts, str(output_model_path), max_length)

def generate_config(args):

    json_file_template = "phi3_template.json"
    with open(json_file_template) as f:
        template_json = json.load(f)

    target = str(args.target)
    device = "GPU" if target in ("cuda", "web") else "CPU"
    execution_providers = [TARGET_TO_EP[target.lower()]]
    template_json["systems"]["local_system"]["config"]["accelerators"] = [
        {"device": device, "execution_providers": execution_providers}
    ]

    model_builder = {
        "type": "ModelBuilder",
        "config": {
            "precision": args.precision,
        },
    }
    template_json["passes"]["builder"] = model_builder

    if target == "mobile":
        template_json["passes"]["builder"]["config"]["int4_accuracy_level"] = 4

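    # for the web target, attach the new OnnxIOFloat16ToFloat32 pass so that fp16
    # model inputs/outputs (e.g. the logits) are exposed as fp32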
    elif target == "web":
        fl_type = {"type": "OnnxIOFloat16ToFloat32"}
        template_json["passes"]["fp32_logits"] = fl_type

    new_json_file = f"phi3_{target.lower()}_{args.precision}.json"
    with open(new_json_file, "w") as f:
        json.dump(template_json, f, indent=4)

    return new_json_file

def get_output_model_path(footprints):
    # only one model output in phi3 optimization
    for footprint in footprints.values():
        for model_id in footprint.nodes:
            model_path = Path(footprint.get_model_path(model_id))
            break
    return model_path

def genai_run(prompt, model_path, max_length):

    print("\nModel inference starts...")

    print("Loading model...")
    app_started_timestamp = time.time()
    model = og.Model(model_path)
    model_loaded_timestamp = time.time()
    print("Model loaded in {:.2f} seconds".format(model_loaded_timestamp - app_started_timestamp))

    print("Creating tokenizer...")
    tokenizer = og.Tokenizer(model)
    tokenizer_stream = tokenizer.create_stream()
    input_tokens = tokenizer.encode(prompt)
    started_timestamp = time.time()

    print("Creating generator ...")
    params = og.GeneratorParams(model)
    # optimal search options for Phi3
    search_options = {
        "max_length": max_length,
        "top_k": 40,
        "top_p": 0.95,
        "temperature": 0.8,
        "repetition_penalty": 1.0,
    }
    params.set_search_options(**search_options)
    params.input_ids = input_tokens
    generator = og.Generator(model, params)
    print("Generator created")

    first = True
    first_token_timestamp = None
    new_tokens = []

    print("\n", prompt)

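    # token-by-token generation loop: compute the logits, pick the next token,
    # and stream-decode it to stdout as soon as it is available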
    try:
        while not generator.is_done():
            generator.compute_logits()
            generator.generate_next_token()
            if first:
                first_token_timestamp = time.time()
                first = False

            new_token = generator.get_next_tokens()[0]
            print(tokenizer_stream.decode(new_token), end="", flush=True)
            new_tokens.append(new_token)
    except KeyboardInterrupt:
        print(" --control+c pressed, aborting generation--")

    del generator

    run_time = time.time() - started_timestamp
    if first_token_timestamp is None:
        print("\n\nNo tokens generated")
    else:
        print(
            "\n\n"
            f"Prompt tokens: {len(input_tokens)}, New tokens: {len(new_tokens)},"
            f" Time to first: {(first_token_timestamp - started_timestamp):.2f}s,"
            f" New tokens per second: {len(new_tokens)/run_time:.2f} tps"
        )


if __name__ == "__main__":
    main()
@@ -0,0 +1,38 @@
{
    "input_model": {
        "type": "PyTorchModel",
        "config": {
            "hf_config": {
                "model_name": "microsoft/Phi-3-mini-4k-instruct",
                "task": "text-generation",
                "from_pretrained_args": {
                    "trust_remote_code": true
                }
            }
        }
    },
    "systems": {
        "local_system": {
            "type": "LocalSystem",
            "config": {
                "accelerators": [
                    {
                        "device": "CPU",
                        "execution_providers": [
                            "CPUExecutionProvider"
                        ]
                    }
                ]
            }
        }
    },
    "passes": {
    },
    "engine": {
        "cache_dir": "cache",
        "output_dir": "Opt_model",
        "host": "local_system",
        "target": "local_system"
    }
}
@@ -0,0 +1,5 @@
einops
onnx>=1.15.0
onnxscript>=0.1.0.dev20240126
torch>=2.2.0
transformers>=4.36.2