diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py
index 2780d9e8..8dedbb58 100644
--- a/gptqmodel/models/auto.py
+++ b/gptqmodel/models/auto.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Union
 
 from ..utils.model import check_and_get_model_type
 from .baichuan import BaiChuanGPTQ
@@ -116,6 +116,10 @@ def from_quantized(
        disable_exllama: Optional[bool] = None,
        disable_exllamav2: bool = False,
        use_marlin: bool = False,
+        # verify that weight files match the predefined hash during loading
+        # usage: hash_format:hash_value, example: md5:ugkdh232
+        # supports all hashlib hash methods
+        verify_hash: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> BaseGPTQModel:
        # If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones.
@@ -143,6 +147,7 @@ def from_quantized(
            disable_exllama=disable_exllama,
            disable_exllamav2=disable_exllamav2,
            use_marlin=use_marlin,
+            verify_hash=verify_hash,
            **kwargs,
        )
 
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 306ed57d..704b25f0 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -27,7 +27,8 @@ from ..utils.model import (auto_dtype_from_config, convert_gptq_v1_to_v2_format, convert_gptq_v2_to_v1_format,
                            find_layers, get_checkpoints, get_device, get_module_by_name_prefix,
                            get_module_by_name_suffix, get_moe_layer_modules, gptqmodel_post_init, make_quant,
-                           move_to, nested_move_to, pack_model, simple_dispatch_model)
+                           move_to, nested_move_to, pack_model, simple_dispatch_model, verify_model_hash,
+                           verify_sharded_model_hashes)
 from ..version import __version__
 from ._const import CPU, CUDA_0, SUPPORTED_MODELS
@@ -749,6 +750,7 @@ def from_quantized(
        disable_exllamav2: bool = False,
        format: Optional[FORMAT] = None,
        allow_unsafe_loading: bool = False,
+        verify_hash: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ):
        """load quantized model from local disk"""
@@ -873,7 +875,14 @@ def from_quantized(
            quantize_config.model_file_base_name = true_model_basename
 
        model_save_name = resolved_archive_file  # In case a model is sharded, this would be `model.safetensors.index.json` which may later break.
-
+        if verify_hash:
+            if is_sharded:
+                verified = verify_sharded_model_hashes(model_save_name, verify_hash)
+            else:
+                verified = verify_model_hash(model_save_name, verify_hash)
+            if not verified:
+                raise ValueError(f"Hash verification failed for {model_save_name}")
+            logger.info(f"Hash verification succeeded for {model_save_name}")
        # == step2: convert model to gptq-model (replace Linear with QuantLinear) ==
        def skip(*args, **kwargs):
            pass
diff --git a/gptqmodel/models/dbrx.py b/gptqmodel/models/dbrx.py
index 530f73ca..bf206e76 100644
--- a/gptqmodel/models/dbrx.py
+++ b/gptqmodel/models/dbrx.py
@@ -1,5 +1,6 @@
 from .base import BaseGPTQModel
 
+
 # placeholder only as dbrx original models are not supported
 # supported dbrx_converted models can be found on https://hf.co/ModelCloud
 class DbrxGPTQ(BaseGPTQModel):
diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py
index ca6480be..b4d5d09e 100644
--- a/gptqmodel/utils/model.py
+++ b/gptqmodel/utils/model.py
@@ -1,4 +1,5 @@
 import functools
+import hashlib
 import json
 import logging
 import os
@@ -299,6 +300,36 @@ def pack_model(
        QuantLinear.warmup(model.to(CUDA_0), seqlen=model.seqlen)
    return QuantLinear
 
+def verify_model_hash(file_path: str, verify_hash: str):
+    if not isinstance(verify_hash, str):
+        raise ValueError("model verify_hash must be a string")
+    if ':' not in verify_hash:
+        raise ValueError("verify_hash must be in the format 'hash_type:hash_value'")
+    hash_type, hash_value = verify_hash.split(':', 1)
+    hash_func = getattr(hashlib, hash_type, None)
+    if not hash_func:
+        raise ValueError(f"No hash function found for type: {hash_type}")
+    with open(file_path, "rb") as f:
+        file_hash = hash_func(f.read()).hexdigest()
+    return file_hash == hash_value
+
+
+def verify_sharded_model_hashes(json_path: str, verify_hash: List[str]):
+    if not isinstance(verify_hash, list):
+        raise ValueError("sharded model verify_hash must be a list")
+
+    with open(json_path, 'r') as f:
+        index_data = json.load(f)
+    weight_map = index_data['weight_map']
+    shard_files = set(weight_map.values())
+    if len(shard_files) != len(verify_hash):
+        raise ValueError("Number of shards and number of hash values do not match.")
+
+    for shard_file, expected_hash in zip(shard_files, verify_hash):
+        if not verify_model_hash(shard_file, expected_hash):
+            logger.info(f"Hash verification failed for {shard_file}")
+            return False
+    return True
 
 def check_and_get_model_type(model_dir, trust_remote_code=False):
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
diff --git a/tests/test_sharded.py b/tests/test_sharded.py
index c2109968..95f92262 100644
--- a/tests/test_sharded.py
+++ b/tests/test_sharded.py
@@ -4,7 +4,6 @@
 from gptqmodel import GPTQModel
 from gptqmodel.quantization import FORMAT, QuantizeConfig
-
 from transformers import AutoTokenizer
diff --git a/tests/test_verify_hash.py b/tests/test_verify_hash.py
new file mode 100644
index 00000000..8b534a5d
--- /dev/null
+++ b/tests/test_verify_hash.py
@@ -0,0 +1,24 @@
+import unittest
+
+from gptqmodel import GPTQModel
+
+
+class TestVerifyHashFunction(unittest.TestCase):
+    MODEL_ID = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
+    EXPECTED_MD5_HASH = "md5:7725c72bc217bcb57b3f1f31d452d871"
+    EXPECTED_SHA256_HASH = "sha256:2680bb4d5c977ee54f25dae584665641ea887e7bd8e8d7197ce8ffd310e93f2f"
+
+
+    def test_verify_md5_hash_function(self):
+        # Load the model with MD5 verify_hash parameter
+        model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", use_marlin=True,
+                                         verify_hash=self.EXPECTED_MD5_HASH)
+        self.assertIsNotNone(model)
+
+    def test_verify_sha256_hash_function(self):
+        # Load the model with SHA-256 verify_hash parameter
+        model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", use_marlin=True,
+                                         verify_hash=self.EXPECTED_SHA256_HASH)
+        # Add additional checks to ensure the model is loaded correctly
+        self.assertIsNotNone(model)
+
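Usage note (not part of the diff above): the new verify_hash argument accepts a "hash_type:hash_value" string, or a list of such strings for sharded checkpoints, where hash_type is any algorithm name exposed by hashlib. A string in that format can be produced ahead of time with a small standalone helper such as the sketch below; compute_verify_hash is an illustrative name, not an API added by this PR, and it simply mirrors the hashing performed in verify_model_hash.

import hashlib

def compute_verify_hash(file_path: str, hash_type: str = "sha256") -> str:
    # Illustrative helper (not part of this PR): build a "hash_type:hash_value"
    # string for a local weight file so it can be passed as verify_hash.
    hash_func = getattr(hashlib, hash_type, None)
    if hash_func is None:
        raise ValueError(f"hashlib has no hash function named {hash_type!r}")
    with open(file_path, "rb") as f:
        digest = hash_func(f.read()).hexdigest()
    return f"{hash_type}:{digest}"

# Example with hypothetical paths: pass the result straight to from_quantized()
# expected = compute_verify_hash("TinyLlama-GPTQ/model.safetensors", "md5")
# model = GPTQModel.from_quantized("TinyLlama-GPTQ", verify_hash=expected)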