
StableDiffusion #152

Open

stellaraccident opened this issue Nov 3, 2023 · 9 comments

@stellaraccident
Contributor

No description provided.

@stellaraccident
Contributor Author

Sync with Ean

@monorimet
Contributor

monorimet commented Nov 3, 2023

  • lowerings in the spirit of stateless llama (@aviator19941) -- see the export sketch after this list
    • UNet
    • VAE
    • CLIP
  • support for arbitrary model formats -- HF model ID, civitai checkpoints, torch.nn.Module instances, .mlir input
  • support for loading LoRAs
  • support for ControlNet
  • support for fake weights / loading .ckpt at runtime
  • support for LoRA training
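
For the "lowerings in the spirit of stateless llama" item, here is a minimal sketch of what a UNet export entrypoint could look like. It assumes diffusers' UNet2DConditionModel and the shark_turbine.aot.export call used later in this thread; the wrapper class, shapes, and model ID are illustrative assumptions, not an existing turbine module:

import torch
import shark_turbine.aot as aot
from diffusers import UNet2DConditionModel

hf_model_name = "CompVis/stable-diffusion-v1-4"
unet = UNet2DConditionModel.from_pretrained(hf_model_name, subfolder="unet")


class UnetWrapper(torch.nn.Module):
    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, sample, timestep, encoder_hidden_states):
        # Predicted noise residual for one denoising step.
        return self.unet(sample, timestep, encoder_hidden_states).sample


sample = torch.randn(1, 4, 64, 64)       # SD 1.x latents at 512x512
timestep = torch.tensor(1)
text_embeds = torch.randn(1, 77, 768)    # CLIP text encoder hidden states
exported = aot.export(UnetWrapper(unet), sample, timestep, text_embeds)
# Print the imported MLIR (per the aot.export API used elsewhere in this thread).
exported.print_readable()

VAE and CLIP would follow the same wrapper-plus-export pattern with their own example inputs.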

@monorimet
Contributor

monorimet commented Dec 7, 2023

@aviator19941 there are a few more missing pieces, mostly small preprocessors/schedulers that I'd like to have turbine implementations ready for. Please take a look at the following list of tasks and let me know if you have any questions.

  • Support for (dynamic) SD schedulers -- refer to SHARK sd_schedulers.py. Currently we import all of these upon loading a model config so they are easily switchable at runtime, but most of them run on CPU (base PyTorch); only SharkEulerDiscrete and SharkEulerAncestralDiscrete are available as .mlir/vmfb at the moment, and we should support all of them as compiled modules. (A minimal step sketch follows this list.)
  • Support for ControlNet adapters -- see controlled UNet and StencilControlNetModel (do me a favor and don't use the term 'stencil' for this part, as it's misleading; these are just ControlNet adapters).
  • Support for ControlNet preprocessors -- see canny, openpose, zoedepth. These just take an input image and convert it to a mask/mapping, which is fed into the UNet latents through the ControlNet adapters.
  • LoRA support
  • Lycoris support
  • SDXL support -- probably do this separately from SD1.5/2.1, as the model arch is generally different.
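
For the scheduler item, a minimal sketch (my assumption of the shape such a module could take, not the SHARK implementation) of one Euler discrete step written as a torch.nn.Module, so it can be compiled like the other submodels instead of running on CPU:

import torch


class EulerDiscreteStep(torch.nn.Module):
    """One Euler discrete denoising step, epsilon-prediction variant."""

    def forward(self, sample, noise_pred, sigma, sigma_next):
        # x0 estimate recovered from the noise prediction.
        pred_original = sample - sigma * noise_pred
        # For epsilon prediction the derivative reduces to noise_pred.
        derivative = (sample - pred_original) / sigma
        # Step the sample from sigma to sigma_next.
        return sample + derivative * (sigma_next - sigma)

Keeping the sigma schedule outside the module and passing sigma/sigma_next as inputs is what would make the compiled step switchable at runtime.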

@gpetters94
Contributor

I can take ControlNet and the preprocessors @monorimet @aviator19941

@aviator19941
Contributor

aviator19941 commented Dec 7, 2023

Sounds good, I'll take the (dynamic) SD schedulers then @gpetters94

@monorimet
Contributor

@aviator19941 can we implement VAE encode for img2img? We can split into two files, have separate APIs (e.g. export_vae_encode_model and export_vae_decode_model), or just use a keyword: export_vae_model(hf_model_id, variant="encode", ...). A rough sketch of the keyword option is below.
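
Here is a rough sketch of the variant-keyword option, assuming diffusers' AutoencoderKL and the aot.export call used elsewhere in this thread; the function name and defaults follow the proposal above, nothing here exists in turbine yet, and the encode path is exactly the part under discussion below:

import torch
import shark_turbine.aot as aot
from diffusers import AutoencoderKL


class VaeModel(torch.nn.Module):
    def __init__(self, hf_model_id, variant="decode"):
        super().__init__()
        self.vae = AutoencoderKL.from_pretrained(hf_model_id, subfolder="vae")
        self.variant = variant

    def forward(self, inp):
        if self.variant == "encode":
            # img2img: image -> scaled latents
            return 0.18215 * self.vae.encode(inp).latent_dist.sample()
        # txt2img: scaled latents -> image
        return self.vae.decode(inp / 0.18215).sample


def export_vae_model(hf_model_id, variant="decode"):
    # Encode takes a 512x512 image; decode takes 64x64 latents.
    shape = (1, 3, 512, 512) if variant == "encode" else (1, 4, 64, 64)
    return aot.export(VaeModel(hf_model_id, variant), torch.randn(*shape))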

@aviator19941
Contributor

aviator19941 commented Dec 14, 2023

@stellaraccident @dan-garvey VAE encode fails at the out = function(*args_functional) call in _functionalize_callable() in functorch.py when sample() is called. I am not sure how to go about fixing this error. I have a small example that reproduces the same error that happens in latents = self.vae.encode(inp).latent_dist.sample(), but when I provide concrete integers in place of inp.shape in the repro's forward function, i.e. torch.randn(1, 4, 64, 64, ...), it is able to compile to Torch IR. Do you have any suggestions on how to fix this?

The traceback is below:

Traceback (most recent call last):
  File "/home/avinash/nod/SHARK-Turbine/python/turbine_models/custom_models/sd_inference/vae_encode.py", line 37, in <module>
    exported = aot.export(sample_model, example_x)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/exporter.py", line 199, in export
    cm = Exported(context=context, import_to="import")
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/compiled_module.py", line 538, in __new__
    do_export(proc_def)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/compiled_module.py", line 535, in do_export
    trace.trace_py_func(invoke_with_self)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/support/procedural/tracer.py", line 121, in trace_py_func
    return_py_value = _unproxy(py_f(*self.proxy_posargs, **self.proxy_kwargs))
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/compiled_module.py", line 516, in invoke_with_self
    return proc_def.callable(self, *args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/exporter.py", line 183, in main
    return jittable(mdl.forward)(*args)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/support/procedural/base.py", line 137, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/support/procedural/tracer.py", line 137, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/builtins/jittable.py", line 207, in resolve_call
    transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/passes/functorch.py", line 47, in functorch_functionalize
    new_gm = proxy_tensor.make_fx(
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 809, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 468, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 817, in trace
    (self.create_arg(fn(*args)),),
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 485, in wrapped
    out = f(*tensors)
  File "<string>", line 1, in <lambda>
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/passes/functorch.py", line 65, in wrapped
    out = function(*args_functional)
  File "/home/avinash/nod/SHARK-Turbine/python/shark_turbine/aot/builtins/jittable.py", line 202, in flat_wrapped_f
    return self.wrapped_f(*pytorch_args, **pytorch_kwargs)
  File "/home/avinash/nod/SHARK-Turbine/python/turbine_models/custom_models/sd_inference/vae_encode.py", line 30, in forward
    sample = torch.randn(inp.shape, generator=generator, device="cpu")
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 555, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 580, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 361, in proxy_call
    out = func(*args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
    return self._op(*args, **kwargs or {})
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1250, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1487, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 422, in constructors
    r = func(*args, **new_kwargs)
  File "/home/avinash/nod/SHARK-Turbine/turbine/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: aten/src/ATen/RegisterCompositeExplicitAutograd.cpp:3404: SymIntArrayRef expected to contain only concrete integers

@dan-garvey
Member

import torch
from shark_turbine.aot import *
from iree.compiler.ir import Context


class SampleModel(CompiledModule):
    # AbstractTensor pins the input to a concrete (1, 4, 64, 64) shape, so
    # tracing sees real integers rather than symbolic sizes.
    def run_forward(self, inp=AbstractTensor(1, 4, 64, 64, dtype=torch.float32)):
        sample = self.forward(inp)
        return sample

    @jittable
    def forward(inp) -> torch.FloatTensor:
        sample = torch.randn(inp.shape, device="cpu")
        # Make sure sample is on the same device as the parameters and has
        # the same dtype.
        sample = sample.to(device="cpu", dtype=torch.float32)
        return sample


sample_model = SampleModel(context=Context(), import_to="IMPORT")
print(str(CompiledModule.get_mlir_module(sample_model)))

This seems to fix your issue. Can you model the full thing this way?

@aviator19941
Contributor

aviator19941 commented Dec 14, 2023

import torch
from diffusers import AutoencoderKL
from shark_turbine.aot import *
from iree.compiler.ir import Context

hf_model_name = "CompVis/stable-diffusion-v1-4"

vae = AutoencoderKL.from_pretrained(
    hf_model_name,
    subfolder="vae",
)


class VaeModel(CompiledModule):
    def run_forward(self, inp=AbstractTensor(1, 3, 512, 512, dtype=torch.float32)):
        x = self.forward(inp)
        return x

    @jittable
    def forward(inp) -> torch.FloatTensor:
        # latent_dist.sample() calls torch.randn internally on a shape taken
        # from an intermediate tensor, which is where the SymIntArrayRef
        # check fails.
        latents = vae.encode(inp).latent_dist.sample()
        return 0.18215 * latents


vae_model = VaeModel(context=Context(), import_to="IMPORT")
print(str(CompiledModule.get_mlir_module(vae_model)))

I think the issue arises when the jittable function calls another function that needs to use the input shape directly. This example also fails the SymIntArrayRef check.
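
One possible workaround, purely an assumption on my part and not a verified fix: keep torch.randn out of the traced function entirely by passing the noise in as a second input and applying the reparameterization (mean + std * noise) by hand, using the mean/std attributes diffusers exposes on DiagonalGaussianDistribution:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="vae"
)


class VaeEncode(torch.nn.Module):
    def forward(self, inp, noise):
        # Reparameterized sample: mean + std * noise. The noise tensor is
        # generated outside the traced function, so torch.randn never runs
        # under the tracer with a non-concrete shape.
        latent_dist = vae.encode(inp).latent_dist
        latents = latent_dist.mean + latent_dist.std * noise
        return 0.18215 * latents


inp = torch.randn(1, 3, 512, 512)
noise = torch.randn(1, 4, 64, 64)
out = VaeEncode()(inp, noise)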
