Stable Diffusion using aot.export and external parameters #217

Merged (28 commits) on Dec 7, 2023
Commits
3322a70  Initial SD 1.5 inference script (aviator19941, Nov 9, 2023)
781516e  Add CLIP test (aviator19941, Nov 9, 2023)
a0e4a08  Use aot.export for CLIPTextModel inference module (aviator19941, Nov 9, 2023)
e3641af  [WIP] Add Unet CompiledModule example (aviator19941, Nov 9, 2023)
b24546f  Fix usage of vae to decode latents into real images (aviator19941, Nov 9, 2023)
d9d7158  [WIP] Debug Unet and VAE nan values (aviator19941, Nov 10, 2023)
cde951a  Add linalg mlir for debugging (aviator19941, Nov 15, 2023)
5736574  [WIP] Use linalg for debugging (aviator19941, Nov 22, 2023)
ceaf350  Change empty.memory_format to aten.zeros.default to fix VAE (aviator19941, Nov 23, 2023)
33978b4  Fix unet and add torch tests (aviator19941, Nov 23, 2023)
ced8dc6  Rename to sd1.4_inference (aviator19941, Nov 23, 2023)
7c24f1f  [WIP] Update CLIP 1.4 example to export parameters/save "stripped" .mlir (aviator19941, Nov 28, 2023)
1d1c1a1  Load weights at runtime for CLIP (aviator19941, Nov 29, 2023)
c27efb4  [WIP] Fix batch size for encoder_hidden_states (aviator19941, Nov 30, 2023)
7f4a455  Start cleaning up code (aviator19941, Dec 1, 2023)
bba568c  Finish clip example (aviator19941, Dec 1, 2023)
918d91c  Finish clip and unet scripts (aviator19941, Dec 1, 2023)
dd072b1  Finish vae script (aviator19941, Dec 1, 2023)
c83d56c  Add hf token flag and fix vae output comparison (aviator19941, Dec 2, 2023)
cb199ea  Move scripts to turbine_models/custom_models (aviator19941, Dec 2, 2023)
65cc5e6  Fix formatting (aviator19941, Dec 2, 2023)
d8e3306  Move reusable functions to utils (aviator19941, Dec 4, 2023)
c70db4c  Fix black formatting for utils (aviator19941, Dec 4, 2023)
23cfead  Address Dan's comments (aviator19941, Dec 5, 2023)
865cedd  Rename sd_inference files and add tests to turbine_models ci (aviator19941, Dec 5, 2023)
df6ad67  Add 2.1 test and add requirements for SD (aviator19941, Dec 5, 2023)
904f519  Add accelerate and diffusers to setup.py (aviator19941, Dec 6, 2023)
990d3c9  Address comments, make tests run vmfb, add device support (aviator19941, Dec 7, 2023)
python/shark_turbine/dynamo/passes.py (1 addition, 0 deletions)

@@ -46,6 +46,7 @@
    torch.ops.aten._to_copy,
    torch.ops.aten._log_softmax_backward_data,
    torch.ops.aten.lift_fresh_copy.default,
    torch.ops.aten._unsafe_index.Tensor,
]


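For context, a hedged sketch of how a list like the one above is typically consumed; the `DEFAULT_DECOMPOSITIONS` name and the surrounding wiring are assumptions for illustration, not shown in this diff:

```python
# Hedged sketch, not this repo's exact code: turning an op list like the
# one above into a decomposition table for FX tracing.
import torch
from torch._decomp import get_decompositions

# Assumed name for the list this diff extends.
DEFAULT_DECOMPOSITIONS = [
    torch.ops.aten._to_copy,
    torch.ops.aten._log_softmax_backward_data,
    torch.ops.aten.lift_fresh_copy.default,
    torch.ops.aten._unsafe_index.Tensor,  # the op added in this PR
]

# get_decompositions maps each op to a Python decomposition, so graphs
# containing these ops are rewritten into simpler aten ops that the
# importer already understands.
decomp_table = get_decompositions(DEFAULT_DECOMPOSITIONS)
```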
python/shark_turbine/importers/fx_importer.py (10 additions, 0 deletions)

@@ -628,6 +628,16 @@ def _import_torch_op_overload(
        elif target == torch.ops.aten.lift_fresh_copy.out:
            node.target = target = torch.ops.aten.clone.out
            node.args = (node.args[0], None, node.args[1])
        # TODO: generalize empty.memory_format in the future
        # Currently, the aten.baddbmm.default op for Unet multiplies an
        # empty.memory_format input by a constant, which creates NaN values
        # because empty.memory_format contains uninitialized data. Converting
        # aten.empty.memory_format -> aten.zeros.default fixes the
        # correctness issue.
        elif target == torch.ops.aten.empty.memory_format:
            if len(node.users) == 1:
                for key_node in node.users:
                    if key_node.target == torch.ops.aten.baddbmm.default:
                        node.target = target = torch.ops.aten.zeros.default

        schema = target._schema
        assert isinstance(schema, FunctionSchema)
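A minimal, hypothetical illustration of the failure mode the empty.memory_format rewrite above addresses; shapes and values are made up for the example. `torch.baddbmm(input, batch1, batch2)` computes `beta * input + alpha * (batch1 @ batch2)`, so an uninitialized bias tensor can leak garbage (possibly NaN) into the result:

```python
# Hedged illustration of the NaN issue, not code from this PR.
import torch

a = torch.randn(2, 4, 8)
b = torch.randn(2, 8, 4)

# torch.empty returns uninitialized memory; its contents are arbitrary
# and may include NaN/Inf, which baddbmm's bias term then propagates.
bias_empty = torch.empty(2, 4, 4)
out_bad = torch.baddbmm(bias_empty, a, b)  # result depends on garbage data

# Replacing empty with zeros makes the bias term deterministic, which is
# what the importer rewrite above achieves at the graph level.
bias_zero = torch.zeros(2, 4, 4)
out_good = torch.baddbmm(bias_zero, a, b)
assert not torch.isnan(out_good).any()
```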
python/turbine_models/custom_models/sd_inference/clip.py (new file, 201 additions)

@@ -0,0 +1,201 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import sys
import re

from iree import runtime as ireert
import iree.compiler as ireec
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from turbine_models.custom_models.sd_inference import utils
import torch
import torch._dynamo as dynamo
from transformers import CLIPTextModel, CLIPTokenizer

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hf_auth_token", type=str, help="The Hugging Face auth token, required"
)
parser.add_argument(
    "--hf_model_name",
    type=str,
    help="HF model name",
    default="CompVis/stable-diffusion-v1-4",
)
parser.add_argument("--run_vmfb", action="store_true")
parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
parser.add_argument("--external_weight_file", type=str, default="")
parser.add_argument("--vmfb_path", type=str, default="")
parser.add_argument(
    "--external_weights",
    type=str,
    default=None,
    help="saves ir/vmfb without global weights for size and readability, options [safetensors]",
)
parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm")
# TODO: Bring in detection for target triple
parser.add_argument(
    "--iree_target_triple",
    type=str,
    default="",
    help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")

prompt = ["a photograph of an astronaut riding a horse"]


def export_clip_model(
    hf_model_name,
    hf_auth_token=None,
    compile_to="torch",
    external_weights=None,
    external_weight_file=None,
    device=None,
    target_triple=None,
    max_alloc=None,
):
    # Load the tokenizer and text encoder to tokenize and encode the text.
    tokenizer = CLIPTokenizer.from_pretrained(
        hf_model_name,
        subfolder="tokenizer",
        token=hf_auth_token,
    )
    text_encoder_model = CLIPTextModel.from_pretrained(
        hf_model_name,
        subfolder="text_encoder",
        token=hf_auth_token,
    )

    mapper = {}
    utils.save_external_weights(
        mapper, text_encoder_model, external_weights, external_weight_file
    )

    class CompiledClip(CompiledModule):
        if external_weights:
            params = export_parameters(
                text_encoder_model,
                external=True,
                external_scope="",
                name_mapper=mapper.get,
            )
        else:
            params = export_parameters(text_encoder_model)

        def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
            return jittable(text_encoder_model.forward)(inp)

    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
    inst = CompiledClip(context=Context(), import_to=import_to)
Review thread:

Contributor: Do we want to provide the option to do quantization on the matmuls like we are for llama?

Contributor Author: Not sure if we want to provide that option. I can add it later if needed.

Contributor: Quantized (int8) SD is a popular request but we don't have a proof-of-concept yet. Can be follow-up.


    module_str = str(CompiledModule.get_mlir_module(inst))
    safe_name = hf_model_name.split("/")[-1].strip()
    safe_name = re.sub("-", "_", safe_name)
    if compile_to != "vmfb":
        return module_str, tokenizer
    else:
        utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)


def run_clip_vmfb_comparison(args):
    config = ireert.Config(args.device)

    if args.external_weight_file:
        index = ireert.ParameterIndex()
        index.load(args.external_weight_file)

    safe_name = args.hf_model_name.split("/")[-1].strip()
    safe_name = re.sub("-", "_", safe_name)
    if args.vmfb_path:
        mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path)
    elif os.path.exists(f"{safe_name}.vmfb"):
        mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb")
    else:
        sys.exit("no vmfb_path provided, required for run_vmfb")

    vm_modules = [
        mod,
        ireert.create_hal_module(config.vm_instance, config.device),
    ]
    if args.external_weight_file:
        param_module = ireert.create_io_parameters_module(
            config.vm_instance, index.create_provider(scope="model")
        )
        vm_modules.insert(0, param_module)

    ctx = ireert.SystemContext(
        vm_modules=vm_modules,
        config=config,
    )
    tokenizer = CLIPTokenizer.from_pretrained(
        args.hf_model_name,
        subfolder="tokenizer",
        token=args.hf_auth_token,
    )
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    inp = text_input.input_ids
    device_inputs = [ireert.asdevicearray(config.device, inp)]

    # Turbine output
    ModuleCompiled = ctx.modules.compiled_clip
    turbine_outputs = ModuleCompiled["main"](*device_inputs)
    turbine_output = turbine_outputs[0]
    print(
        "TURBINE OUTPUT:",
        turbine_output.to_host(),
        turbine_output.to_host().shape,
        turbine_output.to_host().dtype,
    )

    # Torch output
    text_encoder_model = CLIPTextModel.from_pretrained(
        args.hf_model_name,
        subfolder="text_encoder",
        token=args.hf_auth_token,
    )
    torch_output = text_encoder_model.forward(inp)[0]
    np_torch_output = torch_output.detach().cpu().numpy()
    print(
        "TORCH OUTPUT:", np_torch_output, np_torch_output.shape, np_torch_output.dtype
    )

    err = utils.largest_error(np_torch_output, turbine_output)
    print("LARGEST ERROR:", err)
    assert err < 9e-5


if __name__ == "__main__":
args = parser.parse_args()
if args.run_vmfb:
run_clip_vmfb_comparison(args)
else:
mod_str, _ = export_clip_model(
args.hf_model_name,
args.hf_auth_token,
args.compile_to,
args.external_weights,
args.external_weight_file,
args.device,
args.iree_target_triple,
args.vulkan_max_allocation,
)
safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
with open(f"{safe_name}.mlir", "w+") as f:
f.write(mod_str)
print("Saved to", safe_name + ".mlir")