Stable Diffusion using aot.export and external parameters #217
Merged

Commits (28)
All commits are by aviator19941:

3322a70  Initial SD 1.5 inference script
781516e  Add CLIP test
a0e4a08  Use aot.export for CLIPTextModel inference module
e3641af  [WIP] Add Unet CompiledModule example
b24546f  Fix usage of vae to decode latents into real images
d9d7158  [WIP] Debug Unet and VAE nan values
cde951a  Add linalg mlir for debugging
5736574  [WIP] Use linalg for debugging
ceaf350  Change empty.memory_format to aten.zeros.default to fix VAE
33978b4  Fix unet and add torch tests
ced8dc6  Rename to sd1.4_inference
7c24f1f  [WIP] Update CLIP 1.4 example to export parameters/save "stripped" .mlir
1d1c1a1  Load weights at runtime for CLIP
c27efb4  [WIP] Fix batch size for encoder_hidden_states
7f4a455  Start cleaning up code
bba568c  Finish clip example
918d91c  Finish clip and unet scripts
dd072b1  Finish vae script
c83d56c  Add hf token flag and fix vae output comparison
cb199ea  Move scripts to turbine_models/custom_models
65cc5e6  Fix formatting
d8e3306  Move reusable functions to utils
c70db4c  Fix black formatting for utils
23cfead  Address Dan's comments
865cedd  Rename sd_inference files and add tests to turbine_models ci
df6ad67  Add 2.1 test and add requirements for SD
904f519  Add accelerate and diffusers to setup.py
990d3c9  Address comments, make tests run vmfb, add device support
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
python/turbine_models/custom_models/sd_inference/clip.py (new file, 201 additions, 0 deletions)
```python
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import sys
import re

from iree import runtime as ireert
import iree.compiler as ireec
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from turbine_models.custom_models.sd_inference import utils
import torch
import torch._dynamo as dynamo
from transformers import CLIPTextModel, CLIPTokenizer

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hf_auth_token", type=str, help="The Hugging Face auth token, required"
)
parser.add_argument(
    "--hf_model_name",
    type=str,
    help="HF model name",
    default="CompVis/stable-diffusion-v1-4",
)
parser.add_argument("--run_vmfb", action="store_true")
parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
parser.add_argument("--external_weight_file", type=str, default="")
parser.add_argument("--vmfb_path", type=str, default="")
parser.add_argument(
    "--external_weights",
    type=str,
    default=None,
    help="saves ir/vmfb without global weights for size and readability, options [safetensors]",
)
parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm")
# TODO: Bring in detection for target triple
parser.add_argument(
    "--iree_target_triple",
    type=str,
    default="",
    help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")

prompt = ["a photograph of an astronaut riding a horse"]


def export_clip_model(
    hf_model_name,
    hf_auth_token=None,
    compile_to="torch",
    external_weights=None,
    external_weight_file=None,
    device=None,
    target_triple=None,
    max_alloc=None,
):
    # Load the tokenizer and text encoder to tokenize and encode the text.
    tokenizer = CLIPTokenizer.from_pretrained(
        hf_model_name,
        subfolder="tokenizer",
        token=hf_auth_token,
    )
    text_encoder_model = CLIPTextModel.from_pretrained(
        hf_model_name,
        subfolder="text_encoder",
        token=hf_auth_token,
    )

    mapper = {}
    utils.save_external_weights(
        mapper, text_encoder_model, external_weights, external_weight_file
    )

    class CompiledClip(CompiledModule):
        if external_weights:
            params = export_parameters(
                text_encoder_model,
                external=True,
                external_scope="",
                name_mapper=mapper.get,
            )
        else:
            params = export_parameters(text_encoder_model)

        def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
            return jittable(text_encoder_model.forward)(inp)

    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
    inst = CompiledClip(context=Context(), import_to=import_to)

    module_str = str(CompiledModule.get_mlir_module(inst))
    safe_name = hf_model_name.split("/")[-1].strip()
    safe_name = re.sub("-", "_", safe_name)
    if compile_to != "vmfb":
        return module_str, tokenizer
    else:
        utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)


def run_clip_vmfb_comparison(args):
    config = ireert.Config(args.device)

    if args.external_weight_file:
        index = ireert.ParameterIndex()
        index.load(args.external_weight_file)

    safe_name = args.hf_model_name.split("/")[-1].strip()
    safe_name = re.sub("-", "_", safe_name)
    if args.vmfb_path:
        mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path)
    elif os.path.exists(f"{safe_name}.vmfb"):
        mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb")
    else:
        sys.exit("no vmfb_path provided, required for run_vmfb")

    vm_modules = [
        mod,
        ireert.create_hal_module(config.vm_instance, config.device),
    ]
    if args.external_weight_file:
        param_module = ireert.create_io_parameters_module(
            config.vm_instance, index.create_provider(scope="model")
        )
        vm_modules.insert(0, param_module)

    ctx = ireert.SystemContext(
        vm_modules=vm_modules,
        config=config,
    )
    tokenizer = CLIPTokenizer.from_pretrained(
        args.hf_model_name,
        subfolder="tokenizer",
        token=args.hf_auth_token,
    )
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    inp = text_input.input_ids
    device_inputs = [ireert.asdevicearray(config.device, inp)]

    # Turbine output
    ModuleCompiled = ctx.modules.compiled_clip
    turbine_outputs = ModuleCompiled["main"](*device_inputs)
    turbine_output = turbine_outputs[0]
    print(
        "TURBINE OUTPUT:",
        turbine_output.to_host(),
        turbine_output.to_host().shape,
        turbine_output.to_host().dtype,
    )

    # Torch output
    text_encoder_model = CLIPTextModel.from_pretrained(
        args.hf_model_name,
        subfolder="text_encoder",
        token=args.hf_auth_token,
    )
    torch_output = text_encoder_model.forward(inp)[0]
    np_torch_output = torch_output.detach().cpu().numpy()
    print(
        "TORCH OUTPUT:", np_torch_output, np_torch_output.shape, np_torch_output.dtype
    )

    err = utils.largest_error(np_torch_output, turbine_output)
    print("LARGEST ERROR:", err)
    assert err < 9e-5


if __name__ == "__main__":
    args = parser.parse_args()
    if args.run_vmfb:
        run_clip_vmfb_comparison(args)
    else:
        mod_str, _ = export_clip_model(
            args.hf_model_name,
            args.hf_auth_token,
            args.compile_to,
            args.external_weights,
            args.external_weight_file,
            args.device,
            args.iree_target_triple,
            args.vulkan_max_allocation,
        )
        safe_name = args.hf_model_name.split("/")[-1].strip()
        safe_name = re.sub("-", "_", safe_name)
        with open(f"{safe_name}.mlir", "w+") as f:
            f.write(mod_str)
        print("Saved to", safe_name + ".mlir")
```
Review conversation:

Reviewer: Do we want to provide the option to do quantization on the matmuls like we are for llama?

Author: Not sure if we want to provide that option. I can add it later if needed.

Reviewer: Quantized (int8) SD is a popular request but we don't have a proof-of-concept yet. Can be follow-up.
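For context on that follow-up, one possible pre-export experiment (not part of this PR, and untested against the Turbine export path) would be PyTorch's stock dynamic quantization, which swaps `nn.Linear` matmuls for int8 kernels. Whether a module quantized this way still traces through `jittable` is exactly the open proof-of-concept question:

```python
import torch
from transformers import CLIPTextModel

# Hypothetical sketch, NOT what this PR does: dynamically quantize the
# text encoder's Linear (matmul) layers to int8 before export.
model = CLIPTextModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="text_encoder"
)
quantized_model = torch.ao.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
# quantized_model could then, in principle, be passed to the export path
# in place of text_encoder_model, pending tracing/lowering support.
```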