Stable Diffusion using aot.export and external parameters #217
Conversation
aviator19941 commented Dec 2, 2023:
- Saves weights to .safetensors file
- Load weights at runtime with a "stripped" .mlir
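The externalized-weights flow above (save weights to a .safetensors file, load them at runtime against a "stripped" .mlir) relies on the safetensors container layout: a JSON header describing each tensor followed by one flat byte buffer. As a hedged illustration only, here is a minimal pure-Python writer/reader for float32 tensors that mimics that layout; it is not the code this PR uses (which goes through the safetensors library and turbine's parameter machinery):

```python
import json
import struct
from array import array

def save_safetensors(path, tensors):
    """Write {name: (shape, flat float32 values)} in a safetensors-style layout."""
    header, buffers, offset = {}, [], 0
    for name, (shape, values) in tensors.items():
        raw = array("f", values).tobytes()
        header[name] = {
            "dtype": "F32",
            "shape": list(shape),
            "data_offsets": [offset, offset + len(raw)],
        }
        offset += len(raw)
        buffers.append(raw)
    head = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(head)))  # 8-byte little-endian header size
        f.write(head)
        for raw in buffers:
            f.write(raw)

def load_safetensors(path):
    """Read the file back into {name: (shape, flat float32 values)}."""
    with open(path, "rb") as f:
        (head_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(head_len).decode("utf-8"))
        buf = f.read()
    out = {}
    for name, meta in header.items():
        lo, hi = meta["data_offsets"]
        vals = array("f")
        vals.frombytes(buf[lo:hi])
        out[name] = (tuple(meta["shape"]), list(vals))
    return out
```

The stripped .mlir then only needs the tensor names and shapes; the actual weight bytes stay in the external file and are bound at load time.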
self,
sample=AbstractTensor(1, 4, 64, 64, dtype=torch.float32),
timestep=AbstractTensor(1, dtype=torch.float32),
encoder_hidden_states=AbstractTensor(2, 77, 768, dtype=torch.float32),
need to find better way to change AbstractTensor size
What do you mean? You can use None for dynamic if needed
Hey @dan-garvey, the issue is that for SD 1.4/2.1 the last dimension of encoder_hidden_states is 768/1024, respectively. When I try to use None in the AbstractTensor with the constraint set to encoder_hidden_states.dynamic_dim(2), I get this error:
RuntimeError: a and b must have same reduction dim, but got [s2*s3, s2] X [768, 320].
I see, yeah, unless the encoder hidden state is also dynamic it won't work. In this case I'd just parameterize the last dim based on which model you're instantiating.
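The suggestion above (pick the last dim from the model variant instead of making it dynamic) can be sketched as a small lookup helper. The helper name and defaults here are hypothetical, not part of the PR; the 768/1024 widths come from the discussion above:

```python
# Hypothetical helper: choose the encoder_hidden_states shape per SD version.
# 768 is the CLIP text-embedding width for SD 1.4, 1024 for SD 2.1 (per the
# thread above); batch=2 and seq_len=77 match the snippet's static shape.
SD_HIDDEN_DIMS = {"1.4": 768, "2.1": 1024}

def encoder_hidden_states_shape(version, batch=2, seq_len=77):
    if version not in SD_HIDDEN_DIMS:
        raise ValueError(f"unknown SD version: {version}")
    return (batch, seq_len, SD_HIDDEN_DIMS[version])
```

The export signature could then build its AbstractTensor from `encoder_hidden_states_shape("2.1")` rather than hard-coding 768.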
Is this ready to review?
@dan-garvey Yes, please review, thanks!
The test files all do the right thing, but I think they warrant some refactoring (we are doing the same with stateless llama, I think). Try to deduplicate the code, and also add tests to the turbine_models CI.
@@ -628,6 +628,12 @@ def _import_torch_op_overload(
    elif target == torch.ops.aten.lift_fresh_copy.out:
        node.target = target = torch.ops.aten.clone.out
        node.args = (node.args[0], None, node.args[1])
    # TODO: generalize empty memory_format in the future
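The hunk above rewrites a `lift_fresh_copy.out` node to `clone.out`, inserting a `None` memory_format between the input and out arguments. Detached from FX, the rewrite can be sketched with op overloads as plain strings (a stand-in for the real `torch.ops.aten` objects):

```python
def rewrite_overload(target, args):
    # Mirrors the importer hunk: lift_fresh_copy.out(self, out) becomes
    # clone.out(self, memory_format, out) with memory_format defaulted to None.
    if target == "aten.lift_fresh_copy.out":
        return "aten.clone.out", (args[0], None, args[1])
    return target, args
```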
Can you add an else case with an explanation?
Not sure what the else case would be, but I added an explanation for the case.
Looks like you already deduped, thanks.
Force-pushed from 922e4e0 to 49d1dfb
Looks good to me. Can you add a full e2e test (vmfb and compare to the torch result) as a follow-up?
Looks like we are compiling and running the 3 models (clip, unet, and vae) independently? This won't be able to do a full e2e inference to generate an image from a prompt, will it?
return jittable(text_encoder_model.forward)(inp)

import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
inst = CompiledClip(context=Context(), import_to=import_to)
Do we want to provide the option to do quantization on the matmuls like we are for llama?
Not sure if we want to provide that option. I can add it later if needed.
Quantized (int8) SD is a popular request, but we don't have a proof of concept yet. It can be a follow-up.
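For readers unfamiliar with the int8 request discussed above: quantizing the matmul weights usually means storing int8 values plus a scale and dequantizing (or computing in int8) at runtime. The following is a generic sketch of symmetric per-tensor int8 quantization, not the llama path's actual implementation:

```python
def quantize_int8(values):
    """Symmetric per-tensor int8 quantization of a flat list of floats."""
    scale = max((abs(v) for v in values), default=0.0) / 127.0
    if scale == 0.0:
        scale = 1.0  # avoid division by zero for all-zero tensors
    q = [max(-128, min(127, round(v / scale))) for v in values]
    return q, scale

def dequantize_int8(q, scale):
    """Recover approximate float values from int8 codes and a scale."""
    return [x * scale for x in q]
```

The trade-off is the usual one: roughly 4x smaller weights than float32, at the cost of quantization error bounded by about half a quantization step per element.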
No, this won't generate an image from a prompt yet. The 3 models are mainly for checking that we can compile and run them without issues, as well as verifying that the results are similar to torch's results.
Ok, cool. Is there anything holding us back from generating an image, or are we planning on doing that when integrating with the web ui?
Yep, I was planning on doing that when integrating with the web ui.
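The per-model verification mentioned above (checking compiled outputs against torch's) boils down to an element-wise tolerance comparison. Here is a hand-rolled stand-in for `torch.allclose` over flat lists of floats, shown only to make the check concrete; the test suite's actual helper may differ:

```python
def allclose(a, b, rtol=1e-4, atol=1e-4):
    # Element-wise closeness in the spirit of torch.allclose:
    # |x - y| <= atol + rtol * |y| for every pair of elements.
    return len(a) == len(b) and all(
        abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b)
    )
```

A full e2e test would run the compiled vmfb and the torch model on the same inputs and assert this kind of closeness on the flattened outputs.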
Force-pushed from 49d1dfb to 990d3c9
Looks good to me