Commit

Address the comments by amyeroberts
dtlzhuangz committed Apr 22, 2024
1 parent 9108dd9 commit c8fb808
Showing 3 changed files with 12 additions and 8 deletions.
8 changes: 5 additions & 3 deletions src/transformers/integrations/eetq.py
@@ -38,9 +38,10 @@ def _replace_with_eetq_linear(
     Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
     """
-    for name, module in model.named_children():
-        if current_key_name is None:
-            current_key_name = []
+    if current_key_name is None:
+        current_key_name = []
+
+    for name, module in model.named_children():
         current_key_name.append(name)
 
         if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert:

@@ -105,6 +106,7 @@ def replace_with_eetq_linear(
 
     if quantization_config.modules_to_not_convert is not None:
         modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
+    modules_to_not_convert = list(set(modules_to_not_convert))
     model, has_been_replaced = _replace_with_eetq_linear(
         model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
     )
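
For context, a small standalone sketch of the traversal pattern touched above (illustrative only, not the library implementation): the `current_key_name is None` check now runs once before the loop instead of being re-tested for every child, and duplicates in `modules_to_not_convert` are collapsed with `list(set(...))` before the per-layer check.

import torch.nn as nn

def walk(model, modules_to_not_convert, current_key_name=None):
    # Initialize the dotted key path once, before iterating over children,
    # rather than re-checking inside the loop as the pre-commit code did.
    if current_key_name is None:
        current_key_name = []

    for name, module in model.named_children():
        current_key_name.append(name)
        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            print("would convert:", ".".join(current_key_name))
        if list(module.children()):
            walk(module, modules_to_not_convert, current_key_name)
        current_key_name.pop()

# Duplicates in the skip list are harmless once collapsed, mirroring the new
# list(set(...)) call in replace_with_eetq_linear.
walk(nn.Sequential(nn.Linear(4, 4), nn.ReLU()), list(set(["lm_head", "lm_head"])))
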
7 changes: 7 additions & 0 deletions src/transformers/quantizers/quantizer_eetq.py
@@ -81,6 +81,13 @@ def validate_environment(self, *args, **kwargs):
     def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         if torch_dtype is None:
             torch_dtype = torch.float16
+            logger.info(
+                "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to "
+                "requirements of `eetq` to enable model loading in 8-bit. "
+                "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
+                " torch_dtype=torch.float16 to remove this warning.",
+                torch_dtype,
+            )
         elif torch_dtype != torch.float16:
             logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with EETQ.")
         return torch_dtype
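
As the new log message itself suggests, passing an explicit torch_dtype avoids the override notice entirely. A minimal usage sketch (the facebook/opt-125m id is reused from the test file and purely illustrative; eetq must be installed):

import torch
from transformers import AutoModelForCausalLM, EetqConfig

# torch_dtype=torch.float16 matches what EETQ expects for the remaining
# non-linear layers, so update_torch_dtype has nothing to override here.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    torch_dtype=torch.float16,
    quantization_config=EetqConfig(),
    device_map="auto",
)
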
5 changes: 0 additions & 5 deletions tests/quantization/eetq_integration/test_eetq.py
@@ -140,11 +140,6 @@ def test_quantized_model(self):
         output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
-    def test_raise_if_non_quantized(self):
-        model_id = "facebook/opt-125m"
-        quantization_config = EetqConfig()
-        _ = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=quantization_config)
-
     def test_save_pretrained(self):
         """
         Simple test that checks if the quantized model is working properly after being saved and loaded
