
Support optimum-quanto #1997

Open
BenjaminBossan opened this issue Aug 9, 2024 · 13 comments · May be fixed by #2000

Comments

@BenjaminBossan
Member

Feature request

Let's add support for a new quantization method to LoRA, namely optimum-quanto.

There is some more context in this diffusers issue.

Motivation

First of all, the more quantization methods we support, the better. But notably, quanto also works with MPS (PyTorch's Apple silicon backend), which distinguishes it from other quantization methods.

Your contribution

I did some preliminary testing, and quanto already partly works with PEFT: quanto's QLinear layer is a subclass of nn.Linear, so PEFT applies lora.Linear to it. Some features, such as inference, already appear to work. Others, such as merging, don't work correctly. Here is a very quick test:

import torch
from peft import LoraConfig, set_peft_model_state_dict, get_peft_model
from optimum.quanto import quantize, freeze, qint8
from transformers import AutoModelForCausalLM

torch.manual_seed(0)
inputs = torch.arange(5).view(-1, 1)
print("loading model")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").eval()
with torch.inference_mode():
    output_base = model(inputs).logits

# Quantize the weights to int8, then freeze the model to materialize the quantized weights
print("quantizing model")
quantize(model, weights=qint8)
print("freezing model")
freeze(model)

with torch.inference_mode():
    output_quantized = model(inputs).logits

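# init_lora_weights=False initializes the LoRA weights randomly, so the adapter
# actually changes the output (the default init makes LoRA a no-op at first)
config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, init_lora_weights=False)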
print("adding adapter (random)")
model = get_peft_model(model, config)
model.eval()

with torch.inference_mode():
    output_lora = model(inputs).logits

    with model.disable_adapter():
        output_disabled = model(inputs).logits

    output_after_disabled = model(inputs).logits

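# Merging folds the LoRA weights into the quantized base weights; from here on,
# the outputs no longer match expectations
model.merge_adapter()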
with torch.inference_mode():
    output_merged = model(inputs).logits

model.unmerge_adapter()
with torch.inference_mode():
    output_unmerged = model(inputs).logits

unloaded = model.merge_and_unload()
with torch.inference_mode():
    output_unloaded = unloaded(inputs).logits

print("output_base")
print(output_base[0, 0, :5])
print("output_quantized")
print(output_quantized[0, 0, :5])
print("output_lora")
print(output_lora[0, 0, :5])
print("output_disabled")
print(output_disabled[0, 0, :5])
print("output_after_disabled")
print(output_after_disabled[0, 0, :5])
print("output_merged")
print(output_merged[0, 0, :5])
print("output_unmerged")
print(output_unmerged[0, 0, :5])
print("output_unloaded")
print(output_unloaded[0, 0, :5])

Note that none of the outputs involving merging (output_merged, output_unmerged, output_unloaded) are as expected.
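One way to see why merging needs dedicated handling: merging a LoRA adapter computes W' = W + (lora_alpha / r) · B A, but a quantized layer stores integer codes plus a scale instead of W itself. The sketch below uses a hypothetical symmetric per-tensor int8 quantizer in plain PyTorch (not quanto's actual scheme, and not PEFT's implementation) to compare a dequantize-add-requantize merge with naively adding the LoRA delta to the stored integer codes:

```python
import torch

torch.manual_seed(0)

out_features, in_features, r, lora_alpha = 16, 32, 8, 16
scaling = lora_alpha / r  # 2.0, same r / lora_alpha as the LoraConfig in the test above


def quantize_int8(w):
    """Hypothetical symmetric per-tensor int8 quantizer (NOT quanto's scheme)."""
    scale = w.abs().max() / 127
    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return q, scale


def dequantize(q, scale):
    return q.to(torch.float32) * scale


W = torch.randn(out_features, in_features)
A = torch.randn(r, in_features) * 0.5   # random LoRA weights, like init_lora_weights=False
B = torch.randn(out_features, r) * 0.5
x = torch.randn(4, in_features)

q, s = quantize_int8(W)
W_hat = dequantize(q, s)

# Unmerged forward pass: quantized base layer plus the separate LoRA path.
y_unmerged = x @ W_hat.T + scaling * (x @ A.T @ B.T)

delta = scaling * (B @ A)

# Merge that respects the quantization: dequantize, add the delta, requantize.
q_merged, s_merged = quantize_int8(W_hat + delta)
y_merged = x @ dequantize(q_merged, s_merged).T

# Naive merge: add the delta straight onto the integer codes, ignoring the scale.
q_naive = (q.to(torch.float32) + delta).round()
y_naive = x @ dequantize(q_naive, s).T

print((y_merged - y_unmerged).abs().max())  # expected small: requantization error only
print((y_naive - y_unmerged).abs().max())   # expected large: the LoRA delta is mostly lost
```

The design point is that any quantized backend needs its own merge path that knows how to get from stored codes back to real-valued weights and back again; treating the stored data as if it were a plain float weight silently drops or distorts the adapter.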

I can certainly take this on when I have time, but contributions are highly welcome. For inspiration, check out past PRs that added new quantization methods.
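The earlier observation that QLinear subclasses nn.Linear explains why LoRA already attaches to quanto layers. A minimal sketch, assuming a dispatcher that targets layers with an `isinstance(module, nn.Linear)` check (a simplification; PEFT's real dispatch logic is more involved). `HypotheticalQLinear` and `would_wrap_with_lora` are illustrative names, not part of quanto or PEFT:

```python
import torch
from torch import nn


class HypotheticalQLinear(nn.Linear):
    """Stand-in for a quantized linear layer; NOT quanto's real QLinear.

    It only mimics the property that matters here: it is a subclass of
    nn.Linear, yet it uses its weight differently from a plain nn.Linear.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Simulate low-precision weights by rounding to a fixed grid before use.
        w = torch.round(self.weight * 127) / 127
        return nn.functional.linear(x, w, self.bias)


def would_wrap_with_lora(module: nn.Module) -> bool:
    # Hypothetical dispatch rule: target any module that *is an* nn.Linear.
    # A subclass passes this check, so it gets the plain-Linear LoRA wrapper.
    return isinstance(module, nn.Linear)


model = nn.Sequential(nn.Linear(4, 4), HypotheticalQLinear(4, 4), nn.ReLU())
print([would_wrap_with_lora(m) for m in model])  # [True, True, False]
```

Because the wrapper chosen this way assumes a plain float `weight`, forward passes (which just call the base layer) can work while weight-level operations such as merging do not.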


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Member Author

not stale

@Lantianyou

This would be extremely helpful. Currently, Flux with LoRA mostly runs on an A100, like this, but Flux can comfortably run on a 4090, which is a consumer card. If this issue is resolved, it would be possible to run Flux with LoRA on a consumer card like the 4090. Thank you for your work @BenjaminBossan

[Screenshot, 2024-09-22]

@sayakpaul
Member

I am able to run this on a 4090:

from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipeline.load_lora_weights(
    "TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors"
)
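# Fuse the LoRA weights into the base weights, then drop the separate LoRA layers
pipeline.fuse_lora()
pipeline.unload_lora_weights()

# Keep each sub-model on the CPU and move it to the GPU only while it runs
pipeline.enable_model_cpu_offload()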

prompt = "jon snow eating pizza with ketchup"

out = pipeline(prompt, num_inference_steps=20, guidance_scale=4.0)
out.images[0].save("output.png")

What am I missing?

@Lantianyou

Thanks for the write-up. I will try it again later.

@Lantianyou

Thank you @sayakpaul. I was running the FP8 version, following the diffusers docs, and it did not load the LoRA successfully. Running the bfloat16 version makes my 4090 run out of CUDA memory:

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 90.00 MiB. GPU 0 has a total capacity of 23.65 GiB of which 38.06 MiB is free. Including non-PyTorch memory, this process has 23.60 GiB memory in use. Of the allocated memory 21.21 GiB is allocated by PyTorch, and 1.94 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel, CLIPTextModel
from optimum.quanto import freeze, qfloat8, quantize

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

# Load the FP8 single-file checkpoint (as bfloat16), then quantize its weights to qfloat8 and freeze them
transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
quantize(transformer, weights=qfloat8)
freeze(transformer)

# Quantize the large T5 text encoder the same way
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

# Build the pipeline without these two components, then plug in the quantized versions
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2

pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    guidance_scale=3.5,
    output_type="pil",
    num_inference_steps=20,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]

image.save("flux-fp8-dev.png")
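The OOM message itself suggests setting `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` to reduce fragmentation (here, 1.94 GiB was reserved by PyTorch but unallocated). A minimal sketch of doing that from Python, assuming the variable is set before PyTorch initializes CUDA; note that this only helps with fragmentation and cannot make weights that exceed the card's capacity fit:

```python
import os

# Set before PyTorch initializes CUDA (safest: before `import torch`), since the
# caching allocator reads this variable when it is created. setdefault keeps any
# value already exported in the shell.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import torch  # noqa: E402  (imported after setting the env var on purpose)

if torch.cuda.is_available():
    # mem_get_info returns (free, total) device memory in bytes.
    free, total = torch.cuda.mem_get_info()
    print(f"free: {free / 2**30:.2f} GiB / total: {total / 2**30:.2f} GiB")
```

Equivalently, the variable can be exported in the shell before launching the script.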

@sayakpaul
Member

Why are you quantizing when you can run it perfectly well without quantization? Also, we haven't landed support for loading LoRAs into a quantized base model yet, so I won't comment on that.

@Lantianyou

Lantianyou commented Sep 23, 2024

Understood, and I really appreciate your work. Let me reframe my problem: as of now, quantized Flux can comfortably run on a 4090, but a 4090 cannot run the BF16 version of Flux seamlessly, at least in my case. I guess the BF16 version of Flux really pushes the limits of 24 GB of GPU memory.
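The "pushing the limit" intuition can be checked with back-of-the-envelope arithmetic. A rough sketch, assuming the ~12B parameter count stated on the FLUX.1-dev model card, counting only the transformer's weights (no activations, text encoders, VAE, or CUDA overhead), and using the 23.65 GiB capacity reported in the OOM message above:

```python
# Weight memory for the Flux transformer alone.
# Assumption: ~12B parameters (FLUX.1-dev model card). Activations, text
# encoders, the VAE and CUDA context overhead are NOT included.
FLUX_TRANSFORMER_PARAMS = 12e9
BYTES_PER_PARAM = {"bfloat16": 2, "qfloat8 / fp8": 1}
GPU_GIB = 23.65  # total capacity reported in the OOM message above


def weight_gib(num_params: float, bytes_per_param: float) -> float:
    return num_params * bytes_per_param / 2**30


for dtype, nbytes in BYTES_PER_PARAM.items():
    gib = weight_gib(FLUX_TRANSFORMER_PARAMS, nbytes)
    print(f"{dtype:>14}: {gib:5.1f} GiB of weights ({gib / GPU_GIB:.0%} of the card)")
```

In bfloat16 the transformer weights alone come to roughly 22.4 GiB, leaving little headroom for activations even when `enable_model_cpu_offload()` keeps only one component on the GPU at a time; 8-bit weights halve that to about 11.2 GiB.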

@sayakpaul
Member

sayakpaul commented Sep 23, 2024

I have provided a working code snippet. If you run the example with the latest versions of peft, accelerate, transformers, and diffusers, it is expected to work. If it doesn't, please open a new issue on diffusers. If it does work, please let us know in this thread as well.
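When reporting back, it helps to include the exact library versions. A small sketch using only the standard library's `importlib.metadata`; the package list is an assumption based on the libraries named in this thread (optimum-quanto is only needed for the quantized examples):

```python
from importlib.metadata import PackageNotFoundError, version

# Distributions mentioned in this thread (assumed list; adjust as needed).
PACKAGES = ["torch", "peft", "accelerate", "transformers", "diffusers", "optimum-quanto"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:>15}: {ver or 'not installed'}")
```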

@Lantianyou
Copy link

Sure, will run the test again


@Lantianyou

Sorry, you are right. With the latest Hugging Face libraries, it does work on a 4090. Sorry for the trouble, @sayakpaul.

@sayakpaul
Member

No issues, glad things worked out. You might also be interested in huggingface/diffusers#9213 :)


3 participants