TST: Add tests for 4bit LoftQ (#1208)
Add GPU tests for LoftQ with 4bit quantization.

Notes

Tests for 8bit quantization already exist but are not run at the moment;
see this comment:

#1150 (comment)

In my testing, 8bit passes when using NFQuantizer, so if the original
author is fine with using that, I can make the adjustment.
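
For reference, below is a minimal sketch (not part of this commit) of the LoftQ initialization path that the new 4bit tests exercise. It mirrors the calls made in the get_errors helper in the diff, using the same tiny test model:

    from transformers import AutoModelForCausalLM
    from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model

    model_id = "hf-internal-testing/tiny-random-BloomForCausalLM"

    # LoftQ replaces the usual LoRA init: the base weights are quantized to
    # loftq_bits and the LoRA A/B matrices are initialized so that they
    # compensate for the quantization error.
    loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=1)
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        init_lora_weights="loftq",
        loftq_config=loftq_config,
    )
    base_model = AutoModelForCausalLM.from_pretrained(model_id).cuda().eval()
    peft_model = get_peft_model(base_model, lora_config)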

---------

Co-authored-by: Younes Belkada <[email protected]>
BenjaminBossan and younesbelkada authored Dec 11, 2023
1 parent 5c13ea3 commit b08e6fa
Showing 1 changed file with 135 additions and 0 deletions.

tests/test_gpu_examples.py
@@ -41,7 +41,9 @@

from peft import (
    AdaLoraConfig,
    LoftQConfig,
    LoraConfig,
    TaskType,
    get_peft_model,
    prepare_model_for_int8_training,
    prepare_model_for_kbit_training,
@@ -941,6 +943,139 @@ def test_causal_lm_training_multi_gpu(self):
        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])


@require_torch_gpu
class LoftQTests(unittest.TestCase):
    r"""
    Tests for LoftQ
    """

    def setUp(self):
        self.error_factor = 3
        self.model_id = "hf-internal-testing/tiny-random-BloomForCausalLM"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.inputs = self.tokenizer("All I want is", padding=True, return_tensors="pt").to("cuda")

    def get_errors(self, bits=4, loftq_iter=1):
        # Helper function that returns the quantization errors (MAE and MSE) when comparing the quantized LoRA model
        # to the base model, vs the LoftQ quantized model to the base model. We expect the LoftQ quantized model to
        # have less error than the normal LoRA quantized model. Since we compare logits, the observed error is
        # already somewhat dampened because of the softmax.
        model = AutoModelForCausalLM.from_pretrained(self.model_id).cuda().eval()
        torch.manual_seed(0)
        logits_base = model(**self.inputs).logits
        # clean up
        del model
        gc.collect()
        torch.cuda.empty_cache()

        # logits from the normal quantized LoRA model
        lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM)
        kwargs = {}
        if bits == 4:
            kwargs["load_in_4bit"] = True
        elif bits == 8:
            kwargs["load_in_8bit"] = True
        else:
            raise ValueError("bits must be 4 or 8")

        quantized_model = get_peft_model(
            AutoModelForCausalLM.from_pretrained(self.model_id, device_map="auto", **kwargs).eval(),
            lora_config,
        )
        torch.manual_seed(0)
        logits_quantized = quantized_model(**self.inputs).logits
        del quantized_model
        gc.collect()
        torch.cuda.empty_cache()

        # logits from quantized LoRA model using LoftQ
        loftq_config = LoftQConfig(loftq_bits=bits, loftq_iter=loftq_iter)
        lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights="loftq", loftq_config=loftq_config)
        loftq_model = get_peft_model(AutoModelForCausalLM.from_pretrained(self.model_id).cuda().eval(), lora_config)
        torch.manual_seed(0)
        logits_loftq = loftq_model(**self.inputs).logits
        del loftq_model
        gc.collect()
        torch.cuda.empty_cache()

        mae_quantized = torch.abs(logits_base - logits_quantized).mean()
        mse_quantized = torch.pow(logits_base - logits_quantized, 2).mean()
        mae_loftq = torch.abs(logits_base - logits_loftq).mean()
        mse_loftq = torch.pow(logits_base - logits_loftq, 2).mean()
        return mae_quantized, mse_quantized, mae_loftq, mse_loftq

    def test_bloomz_loftq_4bit(self):
        # In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model
        # using LoftQ. When quantizing, we expect a certain level of error. However, we expect the LoftQ quantized
        # model to have less error than the normal LoRA quantized model. Note that when using normal LoRA, the
        # quantization error is simply the error from quantization without LoRA, as LoRA is a no-op before training.
        # We still apply LoRA for the test for consistency.

        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4)
        # first, sanity check that all errors are > 0.0
        self.assertTrue(mae_quantized > 0.0)
        self.assertTrue(mse_quantized > 0.0)
        self.assertTrue(mae_loftq > 0.0)
        self.assertTrue(mse_loftq > 0.0)

        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
        self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
        self.assertTrue(mse_loftq < mse_quantized / self.error_factor)

    def test_bloomz_loftq_4bit_iter_5(self):
        # Same test as the previous one but with 5 iterations. We should expect the error to be even smaller with more
        # iterations, but in practice the difference is not that large, at least not for this small base model.
        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4, loftq_iter=5)
        # first, sanity check that all errors are > 0.0
        self.assertTrue(mae_quantized > 0.0)
        self.assertTrue(mse_quantized > 0.0)
        self.assertTrue(mae_loftq > 0.0)
        self.assertTrue(mse_loftq > 0.0)

        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
        self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
        self.assertTrue(mse_loftq < mse_quantized / self.error_factor)

    def test_bloomz_loftq_8bit(self):
        # 8bit LoftQ currently does not work, see:
        # https://github.com/huggingface/peft/pull/1150#issuecomment-1838891499
        # TODO: remove the skip as soon as the issue is fixed
        self.skipTest("8bit LoftQ does not work yet, see PR #1150")

        # Same test as test_bloomz_loftq_4bit but with 8 bits.
        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8)

        # first, sanity check that all errors are > 0.0
        self.assertTrue(mae_quantized > 0.0)
        self.assertTrue(mse_quantized > 0.0)
        self.assertTrue(mae_loftq > 0.0)
        self.assertTrue(mse_loftq > 0.0)

        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
        self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
        self.assertTrue(mse_loftq < mse_quantized / self.error_factor)

    def test_bloomz_loftq_8bit_iter_5(self):
        # 8bit LoftQ currently does not work, see:
        # https://github.com/huggingface/peft/pull/1150#issuecomment-1838891499
        # TODO: remove the skip as soon as the issue is fixed
        self.skipTest("8bit LoftQ does not work yet, see PR #1150")

        # Same test as test_bloomz_loftq_4bit_iter_5 but with 8 bits.
        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8, loftq_iter=5)

        # first, sanity check that all errors are > 0.0
        self.assertTrue(mae_quantized > 0.0)
        self.assertTrue(mse_quantized > 0.0)
        self.assertTrue(mae_loftq > 0.0)
        self.assertTrue(mse_loftq > 0.0)

        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
        self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
        self.assertTrue(mse_loftq < mse_quantized / self.error_factor)


@require_bitsandbytes
@require_torch_gpu
class MultiprocessTester(unittest.TestCase):