diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py
index d98142b7fb..000d8e9152 100644
--- a/tests/test_gpu_examples.py
+++ b/tests/test_gpu_examples.py
@@ -41,7 +41,9 @@ from peft import (
     AdaLoraConfig,
+    LoftQConfig,
     LoraConfig,
+    TaskType,
     get_peft_model,
     prepare_model_for_int8_training,
     prepare_model_for_kbit_training,
@@ -941,6 +943,139 @@ def test_causal_lm_training_multi_gpu(self):
         self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
 
 
+@require_torch_gpu
+class LoftQTests(unittest.TestCase):
+    r"""
+    Tests for LoftQ
+    """
+
+    def setUp(self):
+        self.error_factor = 3
+        self.model_id = "hf-internal-testing/tiny-random-BloomForCausalLM"
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        self.inputs = self.tokenizer("All I want is", padding=True, return_tensors="pt").to("cuda")
+
+    def get_errors(self, bits=4, loftq_iter=1):
+        # Helper function that returns the quantization errors (MAE and MSE) when comparing the quantized LoRA model
+        # to the base model, vs the LoftQ quantized model to the base model. We expect the LoftQ quantized model to
+        # have less error than the normal LoRA quantized model. Since we compare logits, the observed error is
+        # already somewhat dampened because of the softmax.
+        model = AutoModelForCausalLM.from_pretrained(self.model_id).cuda().eval()
+        torch.manual_seed(0)
+        logits_base = model(**self.inputs).logits
+        # clean up
+        del model
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # logits from the normal quantized LoRA model
+        lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM)
+        kwargs = {}
+        if bits == 4:
+            kwargs["load_in_4bit"] = True
+        elif bits == 8:
+            kwargs["load_in_8bit"] = True
+        else:
+            raise ValueError("bits must be 4 or 8")
+
+        quantized_model = get_peft_model(
+            AutoModelForCausalLM.from_pretrained(self.model_id, device_map="auto", **kwargs).eval(),
+            lora_config,
+        )
+        torch.manual_seed(0)
+        logits_quantized = quantized_model(**self.inputs).logits
+        del quantized_model
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # logits from quantized LoRA model using LoftQ
+        loftq_config = LoftQConfig(loftq_bits=bits, loftq_iter=loftq_iter)
+        lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights="loftq", loftq_config=loftq_config)
+        loftq_model = get_peft_model(AutoModelForCausalLM.from_pretrained(self.model_id).cuda().eval(), lora_config)
+        torch.manual_seed(0)
+        logits_loftq = loftq_model(**self.inputs).logits
+        del loftq_model
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        mae_quantized = torch.abs(logits_base - logits_quantized).mean()
+        mse_quantized = torch.pow(logits_base - logits_quantized, 2).mean()
+        mae_loftq = torch.abs(logits_base - logits_loftq).mean()
+        mse_loftq = torch.pow(logits_base - logits_loftq, 2).mean()
+        return mae_quantized, mse_quantized, mae_loftq, mse_loftq
+
+    def test_bloomz_loftq_4bit(self):
+        # In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model
+        # using LoftQ. When quantizing, we expect a certain level of error. However, we expect the LoftQ quantized
+        # model to have less error than the normal LoRA quantized model. Note that when using normal LoRA, the
+        # quantization error is simply the error from quantization without LoRA, as LoRA is a no-op before training.
+        # We still apply LoRA for the test for consistency.
+
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4)
+        # first, sanity check that all errors are > 0.0
+        self.assertTrue(mae_quantized > 0.0)
+        self.assertTrue(mse_quantized > 0.0)
+        self.assertTrue(mae_loftq > 0.0)
+        self.assertTrue(mse_loftq > 0.0)
+
+        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
+        factor = 3
+        self.assertTrue(mae_loftq < mae_quantized / factor)
+        self.assertTrue(mse_loftq < mse_quantized / factor)
+
+    def test_bloomz_loftq_4bit_iter_5(self):
+        # Same test as the previous one but with 5 iterations. We should expect the error to be even smaller with more
+        # iterations, but in practice the difference is not that large, at least not for this small base model.
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4, loftq_iter=5)
+        # first, sanity check that all errors are > 0.0
+        self.assertTrue(mae_quantized > 0.0)
+        self.assertTrue(mse_quantized > 0.0)
+        self.assertTrue(mae_loftq > 0.0)
+        self.assertTrue(mse_loftq > 0.0)
+
+        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
+        self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
+        self.assertTrue(mse_loftq < mse_quantized / self.error_factor)
+
+    def test_bloomz_loftq_8bit(self):
+        # this currently does not work:
+        # https://github.com/huggingface/peft/pull/1150#issuecomment-1838891499
+        if True:  # TODO: remove as soon as the issue is fixed
+            return
+
+        # Same test as test_bloomz_loftq_4bit but with 8 bits.
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8)
+
+        # first, sanity check that all errors are > 0.0
+        self.assertTrue(mae_quantized > 0.0)
+        self.assertTrue(mse_quantized > 0.0)
+        self.assertTrue(mae_loftq > 0.0)
+        self.assertTrue(mse_loftq > 0.0)
+
+        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
+        self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
+        self.assertTrue(mse_loftq < mse_quantized / self.error_factor)
+
+    def test_bloomz_loftq_8bit_iter_5(self):
+        # this currently does not work:
+        # https://github.com/huggingface/peft/pull/1150#issuecomment-1838891499
+        if True:  # TODO: remove as soon as the issue is fixed
+            return
+
+        # Same test as test_bloomz_loftq_4bit_iter_5 but with 8 bits.
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8, loftq_iter=5)
+
+        # first, sanity check that all errors are > 0.0
+        self.assertTrue(mae_quantized > 0.0)
+        self.assertTrue(mse_quantized > 0.0)
+        self.assertTrue(mae_loftq > 0.0)
+        self.assertTrue(mse_loftq > 0.0)
+
+        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
+        self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
+        self.assertTrue(mse_loftq < mse_quantized / self.error_factor)
+
+
 @require_bitsandbytes
 @require_torch_gpu
 class MultiprocessTester(unittest.TestCase):
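Note (not part of the patch): a minimal standalone sketch of the comparison these tests automate, using only the calls that appear in the diff above (LoftQConfig, LoraConfig with init_lora_weights="loftq", get_peft_model); it assumes a CUDA device and bitsandbytes are available.

    # sketch only: compares base-model logits against a 4-bit LoRA model and a LoftQ-initialized LoRA model
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model

    model_id = "hf-internal-testing/tiny-random-BloomForCausalLM"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    inputs = tokenizer("All I want is", return_tensors="pt").to("cuda")

    # reference logits from the unquantized base model
    base = AutoModelForCausalLM.from_pretrained(model_id).cuda().eval()
    logits_base = base(**inputs).logits

    # 4-bit quantized model with a default-initialized LoRA adapter (a no-op before training)
    quantized = get_peft_model(
        AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True).eval(),
        LoraConfig(task_type=TaskType.CAUSAL_LM),
    )
    logits_quantized = quantized(**inputs).logits

    # LoRA adapter initialized with LoftQ (4-bit) on the full-precision weights
    loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=1)
    loftq = get_peft_model(
        AutoModelForCausalLM.from_pretrained(model_id).cuda().eval(),
        LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights="loftq", loftq_config=loftq_config),
    )
    logits_loftq = loftq(**inputs).logits

    # LoftQ initialization is expected to track the base model more closely
    print("MAE quantized:", torch.abs(logits_base - logits_quantized).mean().item())
    print("MAE loftq:    ", torch.abs(logits_base - logits_loftq).mean().item())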