
Stable Diffusion using aot.export and external parameters #217

Merged: 28 commits merged into main from stateless_sd on Dec 7, 2023

Conversation

aviator19941 (Contributor):

  • Saves weights to a .safetensors file
  • Loads weights at runtime with a "stripped" .mlir
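The split the two bullets describe can be sketched in pure Python. The `externalize` function and the placeholder encoding below are hypothetical illustrations of the idea, not the turbine API:

```python
def externalize(state_dict):
    """Split a model's weights out of the compiled artifact.

    Returns (archive, refs): `archive` holds the actual tensors (what
    would be written to the .safetensors file), while `refs` keeps only
    named placeholders (what the "stripped" .mlir would resolve against
    the archive at runtime).
    """
    archive = dict(state_dict)
    refs = {name: ("external", name) for name in state_dict}
    return archive, refs
```

The point of the split is that the compiled program stays small and weight-free, and the same stripped module can be reused with different parameter archives.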

aviator19941 (Contributor, Author):

#152

self,
sample=AbstractTensor(1, 4, 64, 64, dtype=torch.float32),
timestep=AbstractTensor(1, dtype=torch.float32),
encoder_hidden_states=AbstractTensor(2, 77, 768, dtype=torch.float32),
aviator19941 (Contributor, Author) commented on the snippet above:

Need to find a better way to change the AbstractTensor size.

dan-garvey (Member):

What do you mean? You can use None for a dynamic dim if needed.

aviator19941 (Contributor, Author):

Hey @dan-garvey, the issue is that for SD 1.4/2.1 the last dimension of encoder_hidden_states is 768/1024, respectively. When I try to use None in the AbstractTensor with the constraint set to encoder_hidden_states.dynamic_dim(2), I get this error:

RuntimeError: a and b must have same reduction dim, but got [s2*s3, s2] X [768, 320].

dan-garvey (Member):

I see. Yeah, unless the encoder hidden state is also dynamic it won't work. In this case I'd just parameterize the last dim based on which model you're instantiating.
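The suggestion above can be sketched as picking the static shape from the model version at instantiation time, instead of making the last dim dynamic. The dict and function names here are illustrative, not from the PR:

```python
# Last dim of encoder_hidden_states per SD version (768 for 1.4, 1024 for 2.1).
SD_HIDDEN_DIMS = {"1.4": 768, "2.1": 1024}

def encoder_hidden_states_shape(version, batch=2, seq_len=77):
    """Return the static shape to pass to AbstractTensor for a given SD version."""
    return (batch, seq_len, SD_HIDDEN_DIMS[version])
```

The resulting tuple would then be spread into the AbstractTensor constructor for whichever model is being exported.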

dan-garvey (Member):

Is this ready to review?

aviator19941 (Contributor, Author):

> Is this ready to review?

@dan-garvey Yes, please review, thanks!

dan-garvey (Member) left a review comment:

The test files all do the right thing, but I think they warrant some refactoring (we are doing the same with stateless llama, I think). Try to deduplicate the code, and also add tests to the turbine_models CI.

@@ -628,6 +628,12 @@ def _import_torch_op_overload(
elif target == torch.ops.aten.lift_fresh_copy.out:
node.target = target = torch.ops.aten.clone.out
node.args = (node.args[0], None, node.args[1])
# TODO: generalize empty memory_format in the future
dan-garvey (Member) commented on the diff above:

Can you add an else case with an explanation?

aviator19941 (Contributor, Author):

Not sure what the else case would be, but I added an explanation for the case.


dan-garvey (Member):

Looks like you already deduped, thanks.

dan-garvey (Member) left a review comment:

Looks good to me. Can you add a full e2e test (vmfb, and compare to the torch result) as a follow-up?

IanNod (Contributor) left a review comment:

Looks like we are compiling and running the 3 models (clip, unet, and vae) independently? This won't be able to do a full e2e inference to generate an image from a prompt, will it?

python/turbine_models/custom_models/sd_inference/utils.py (review thread resolved)
return jittable(text_encoder_model.forward)(inp)

import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
inst = CompiledClip(context=Context(), import_to=import_to)
IanNod (Contributor) commented on the snippet above:

Do we want to provide the option to do quantization on the matmuls, like we are for llama?

aviator19941 (Contributor, Author):

Not sure if we want to provide that option. I can add it later if needed.

Contributor:

Quantized (int8) SD is a popular request, but we don't have a proof of concept yet. This can be a follow-up.

aviator19941 (Contributor, Author):

> Looks like we are compiling and running the 3 models (clip, unet, and vae) independently? This won't be able to do a full e2e inference to generate an image from a prompt will it?

No, this won't generate an image from a prompt yet. The 3 models are mainly for checking that we can compile and run them without issues, as well as verifying that the results are similar to torch's results.
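Checking that compiled results are "similar to torch's results" usually comes down to an elementwise tolerance comparison. A minimal pure-Python sketch of that kind of check (the function name and tolerance values are illustrative, not from the PR's test code):

```python
def outputs_close(a, b, rtol=1e-4, atol=1e-4):
    """Elementwise |x - y| <= atol + rtol * |y|, in the style of numpy.allclose."""
    if len(a) != len(b):
        return False
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))
```

In practice the real tests would compare flattened tensors from the compiled module against the torch module's outputs with tolerances chosen per model.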

IanNod (Contributor) commented Dec 7, 2023:

> No, this won't generate an image from a prompt yet. The 3 models are mainly for checking that we can compile and run them without issues as well as verify the results are similar to torch's results.

Ok, cool. Is there anything holding us back from generating an image, or are we planning on doing that when integrating with the web UI?

aviator19941 (Contributor, Author):

> Ok, cool. Is there anything holding us back from generating an image, or are we planning on doing that when integrating with the web ui?

Yep, I was planning on doing that when integrating with the web UI.

IanNod (Contributor) left a review comment:

Looks good to me.

aviator19941 merged commit b5a6192 into main on Dec 7, 2023 (3 checks passed) and deleted the stateless_sd branch on December 7, 2023 at 18:15.