[FIX] missing .py files for models using trust_remote (#302)
* add save_quantized: copy .py files from the origin model path

* mod: clean up

* Rename test_quant_unsupport_transformers.py to test_quant_trust_remote.py

* Update base.py

* Update test_quant_trust_remote.py

---------

Co-authored-by: Qubitium-ModelCloud <[email protected]>
PZS-ModelCloud and Qubitium authored Jul 26, 2024
1 parent a30b1f7 commit a046135
Showing 4 changed files with 89 additions and 4 deletions.
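
Why this matters in practice: models whose modeling/configuration code ships inside the checkpoint (loaded with trust_remote_code=True) previously lost those .py files on save, leaving the quantized output unloadable without the original repo. A minimal sketch of the now-working flow, with the model ID, calibration setup, and file names taken from the new test below (output path hypothetical):

    from datasets import load_dataset
    from transformers import AutoTokenizer
    from gptqmodel import GPTQModel
    from gptqmodel.quantization import FORMAT, QuantizeConfig

    MODEL_ID = "openbmb/MiniCPM-2B-dpo-bf16"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train").filter(lambda x: len(x["text"]) >= 512)
    calibration_dataset = [tokenizer(example["text"]) for example in traindata.select(range(1024))]

    quantize_config = QuantizeConfig(bits=4, group_size=128, format=FORMAT.GPTQ)
    # from_pretrained() now records trust_remote_code and the source path on the
    # returned wrapper so save_quantized() can copy the remote-code files later.
    model = GPTQModel.from_pretrained(MODEL_ID, quantize_config=quantize_config, trust_remote_code=True)
    model.quantize(calibration_dataset)
    model.save_quantized("MiniCPM-2B-dpo-bf16-gptq-4bit")
    # The save dir now also contains modeling_minicpm.py and configuration_minicpm.py.
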
22 changes: 19 additions & 3 deletions gptqmodel/models/base.py
@@ -3,6 +3,7 @@
 import logging
 import os
 import re
+import shutil
 from os.path import basename, isfile, join
 from typing import Dict, List, Optional, Union
 
@@ -31,7 +32,7 @@
     convert_gptq_v2_to_v1_format, find_layers, get_checkpoints, get_device,
     get_module_by_name_prefix, get_module_by_name_suffix, get_moe_layer_modules,
     gptqmodel_post_init, make_quant, move_to, nested_move_to, pack_model,
-    simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes)
+    simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes, copy_py_files)
 from ..version import __version__
 from ._const import CPU, CUDA_0, DEVICE, SUPPORTED_MODELS
 
@@ -80,6 +81,8 @@ def __init__(
         quantize_config: QuantizeConfig,
         qlinear_kernel: nn.Module = None,
         load_quantized_model: bool = False,
+        trust_remote_code: bool = False,
+        model_name_or_path: str = None,
     ):
         super().__init__()
 
@@ -91,6 +94,8 @@
 
         # compat: state to assist in checkpoint_format gptq(v1) to gptq_v2 conversion
         self.qlinear_kernel = qlinear_kernel
+        self.trust_remote_code = trust_remote_code
+        self.model_name_or_path = model_name_or_path
 
     @property
     def quantized(self):
@@ -770,6 +775,10 @@ def save_quantized(
         quantize_config.model_file_base_name = model_base_name
         quantize_config.save_pretrained(save_dir)
 
+        # need to copy .py files for model/tokenizers not yet merged to HF transformers
+        if self.trust_remote_code:
+            copy_py_files(save_dir, model_id_or_path=self.model_name_or_path)
+
     def get_model_with_quantize(self, quantize_config):
         config = AutoConfig.from_pretrained(
             quantize_config.model_name_or_path,
@@ -913,8 +922,13 @@ def skip(*args, **kwargs):
             logger.warning("can't get model's sequence length from model config, will set to 4096.")
             model.seqlen = 4096
         model.eval()
-
-        return cls(model, quantized=False, quantize_config=quantize_config)
+        return cls(
+            model,
+            quantized=False,
+            quantize_config=quantize_config,
+            trust_remote_code=trust_remote_code,
+            model_name_or_path=pretrained_model_name_or_path
+        )
 
     @classmethod
     def from_quantized(
@@ -1331,6 +1345,8 @@ def skip(*args, **kwargs):
             quantize_config=quantize_config,
             qlinear_kernel=qlinear_kernel,
             load_quantized_model=True,
+            trust_remote_code=trust_remote_code,
+            model_name_or_path=model_name_or_path,
        )
 
    def __getattr__(self, item):
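
Condensing the base.py changes into a standalone sketch: both from_pretrained and from_quantized now stash load-time context on the wrapper, and save_quantized consumes it. RemoteCodeAwareWrapper below is a hypothetical stand-in for BaseGPTQModel, which takes more arguments:

    from torch import nn
    from gptqmodel.utils.model import copy_py_files

    class RemoteCodeAwareWrapper(nn.Module):  # hypothetical stand-in for BaseGPTQModel
        def __init__(self, model, trust_remote_code=False, model_name_or_path=None):
            super().__init__()
            self.model = model
            # remember where the weights came from and whether remote code was trusted
            self.trust_remote_code = trust_remote_code
            self.model_name_or_path = model_name_or_path

        def save_quantized(self, save_dir):
            # ... weights and quantize_config are written here, as in the real method ...
            if self.trust_remote_code:
                copy_py_files(save_dir, model_id_or_path=self.model_name_or_path)
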
18 changes: 17 additions & 1 deletion gptqmodel/utils/model.py
@@ -5,7 +5,8 @@
 import os
 from logging import getLogger
 from typing import List, Optional
-
+from huggingface_hub import HfApi, hf_hub_download
+import shutil
 import accelerate
 import threadpoolctl as tctl
 import torch
@@ -633,3 +634,18 @@ def check_to_quantized(config):
     if config.bits > 8 or "fp" in config.data_type or "float" in config.data_type:
         return False
     return True
+
+def copy_py_files(save_dir, file_extension=".py", model_id_or_path=""):
+    os.makedirs(save_dir, exist_ok=True)
+
+    if os.path.isdir(model_id_or_path):
+        py_files = [f for f in os.listdir(model_id_or_path) if f.endswith('.py')]
+        for file in py_files:
+            shutil.copy2(os.path.join(model_id_or_path, file), save_dir)
+    else:
+        api = HfApi()
+        model_info = api.model_info(model_id_or_path)
+        for file in model_info.siblings:
+            if file.rfilename.endswith(file_extension):
+                _ = hf_hub_download(repo_id=model_id_or_path, filename=file.rfilename,
+                                    local_dir=save_dir)
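
copy_py_files covers both places a model may have been loaded from. A usage sketch (output directories hypothetical):

    from gptqmodel.utils.model import copy_py_files

    # Local checkout: copies every top-level *.py file into save_dir via shutil.copy2.
    copy_py_files("out-local", model_id_or_path="/path/to/MiniCPM-2B-dpo-bf16")

    # Hub repo id: lists the repo's files via HfApi.model_info() and downloads each
    # matching file with hf_hub_download(..., local_dir=save_dir).
    copy_py_files("out-hub", model_id_or_path="openbmb/MiniCPM-2B-dpo-bf16")

One quirk of the code as shown: the local-directory branch hardcodes '.py', so only the Hub branch actually honors the file_extension parameter.
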
1 change: 1 addition & 0 deletions requirements.txt
@@ -15,3 +15,4 @@ ninja>=1.11.1.1
 protobuf>=4.25.3
 intel_extension_for_transformers>=1.4.2
 auto-round==0.2
+huggingface-hub>=0.24.2
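
The new huggingface-hub floor presumably pins the hf_hub_download(..., local_dir=...) behavior that copy_py_files relies on, where the downloaded file lands directly in the target directory rather than behind a cache symlink:

    from huggingface_hub import hf_hub_download

    # With recent huggingface-hub releases, this places a real file at
    # out-hub/modeling_minicpm.py rather than a symlink into the HF cache.
    hf_hub_download(repo_id="openbmb/MiniCPM-2B-dpo-bf16",
                    filename="modeling_minicpm.py", local_dir="out-hub")
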
52 changes: 52 additions & 0 deletions tests/test_quant_trust_remote.py
@@ -0,0 +1,52 @@
+# -- do not touch
+import os
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+# -- end do not touch
+
+import tempfile  # noqa: E402
+import unittest  # noqa: E402
+
+from datasets import load_dataset  # noqa: E402
+from gptqmodel import GPTQModel  # noqa: E402
+from gptqmodel.quantization import FORMAT, QuantizeConfig  # noqa: E402
+from transformers import AutoTokenizer  # noqa: E402
+
+class TestQuantWithTrustRemoteTrue(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        self.MODEL_ID = "openbmb/MiniCPM-2B-dpo-bf16"
+        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_ID, use_fast=True, trust_remote_code=True)
+
+        if not self.tokenizer.pad_token_id:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+        traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train").filter(lambda x: len(x['text']) >= 512)
+        self.calibration_dataset = [self.tokenizer(example["text"]) for example in traindata.select(range(1024))]
+
+    def test_diff_batch(self):
+        quantize_config = QuantizeConfig(
+            bits=4,
+            group_size=128,
+            format=FORMAT.GPTQ,
+        )
+
+        model = GPTQModel.from_pretrained(
+            self.MODEL_ID,
+            quantize_config=quantize_config,
+            trust_remote_code=True,
+        )
+
+        model.quantize(self.calibration_dataset, batch_size=64)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_quantized(
+                tmp_dir,
+            )
+
+            del model
+            py_files = [f for f in os.listdir(tmp_dir) if f.endswith('.py')]
+            expected_files = ["modeling_minicpm.py", "configuration_minicpm.py"]
+            self.assertEqual(py_files, expected_files)
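
A caveat on the final assertion: os.listdir() returns entries in arbitrary order, so comparing the raw list against a fixed expected list can fail spuriously even when both files were copied. An order-insensitive check would be more robust, e.g.:

    # alternative to the strict-order assertEqual above:
    self.assertCountEqual(py_files, ["modeling_minicpm.py", "configuration_minicpm.py"])
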

