
[CORE] [FEATURE] Add verify_hash #50

Merged · 14 commits · Jun 24, 2024
7 changes: 6 additions & 1 deletion gptqmodel/models/auto.py
@@ -1,4 +1,4 @@
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Union

from ..utils.model import check_and_get_model_type
from .baichuan import BaiChuanGPTQ
@@ -116,6 +116,10 @@ def from_quantized(
disable_exllama: Optional[bool] = None,
disable_exllamav2: bool = False,
use_marlin: bool = False,
# verify that weight files match a predefined hash during loading
# usage: hash_format:hash_value, example: md5:ugkdh232
# supports all hashlib hash methods
verify_hash: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> BaseGPTQModel:
# If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones.
@@ -143,6 +147,7 @@ def from_quantized(
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_marlin=use_marlin,
verify_hash=verify_hash,
**kwargs,
)
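
The verify_hash argument added here is forwarded from GPTQModel.from_quantized into BaseGPTQModel.from_quantized below. A minimal usage sketch, reusing the model id and MD5 hash from the test file at the end of this diff:

# Sketch: pass a single "hash_type:hash_value" string; any hashlib algorithm name is accepted.
from gptqmodel import GPTQModel

model = GPTQModel.from_quantized(
    "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
    device="cuda:0",
    use_marlin=True,
    verify_hash="md5:7725c72bc217bcb57b3f1f31d452d871",
)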

13 changes: 11 additions & 2 deletions gptqmodel/models/base.py
@@ -27,7 +27,8 @@
from ..utils.model import (auto_dtype_from_config, convert_gptq_v1_to_v2_format, convert_gptq_v2_to_v1_format,
find_layers, get_checkpoints, get_device, get_module_by_name_prefix,
get_module_by_name_suffix, get_moe_layer_modules, gptqmodel_post_init, make_quant,
move_to, nested_move_to, pack_model, simple_dispatch_model)
move_to, nested_move_to, pack_model, simple_dispatch_model, verify_model_hash,
verify_sharded_model_hashes)
from ..version import __version__
from ._const import CPU, CUDA_0, SUPPORTED_MODELS

@@ -749,6 +750,7 @@ def from_quantized(
disable_exllamav2: bool = False,
format: Optional[FORMAT] = None,
allow_unsafe_loading: bool = False,
verify_hash: Optional[Union[str, List[str]]] = None,
**kwargs,
):
"""load quantized model from local disk"""
@@ -873,7 +875,14 @@
quantize_config.model_file_base_name = true_model_basename

model_save_name = resolved_archive_file # In case a model is sharded, this would be `model.safetensors.index.json` which may later break.

if verify_hash:
    if is_sharded:
        verified = verify_sharded_model_hashes(model_save_name, verify_hash)
    else:
        verified = verify_model_hash(model_save_name, verify_hash)
    if not verified:
        raise ValueError(f"Hash verification failed for {model_save_name}")
    logger.info(f"Hash verification succeeded for {model_save_name}")
# == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
def skip(*args, **kwargs):
pass
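
For sharded checkpoints, verify_hash is a list with one "hash_type:hash_value" entry per shard file, and verify_sharded_model_hashes (added in gptqmodel/utils/model.py below) checks each shard listed in the safetensors index. A hedged sketch; the model id and hash values here are hypothetical placeholders:

# Sketch: one hash entry per shard of a hypothetical sharded model.
from gptqmodel import GPTQModel

model = GPTQModel.from_quantized(
    "some-org/sharded-gptq-model",   # hypothetical model id
    device="cuda:0",
    verify_hash=[
        "sha256:<hash-of-shard-1>",  # placeholder values
        "sha256:<hash-of-shard-2>",
    ],
)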
1 change: 1 addition & 0 deletions gptqmodel/models/dbrx.py
@@ -1,5 +1,6 @@
from .base import BaseGPTQModel


# placeholder only, as dbrx original models are not supported
# supported dbrx_converted models can be found on https://hf.co/ModelCloud
class DbrxGPTQ(BaseGPTQModel):
31 changes: 31 additions & 0 deletions gptqmodel/utils/model.py
@@ -1,4 +1,5 @@
import functools
import hashlib
import json
import logging
import os
@@ -299,6 +300,36 @@ def pack_model(
QuantLinear.warmup(model.to(CUDA_0), seqlen=model.seqlen)
return QuantLinear

def verify_model_hash(file_path: str, verify_hash: str):
    if not isinstance(verify_hash, str):
        raise ValueError("model verify_hash must be a string")
    if ':' not in verify_hash:
        raise ValueError("verify_hash must be in the format 'hash_type:hash_value'")
    hash_type, hash_value = verify_hash.split(':', 1)
    hash_func = getattr(hashlib, hash_type, None)
    if not hash_func:
        raise ValueError(f"No hash function found for type: {hash_type}")
    with open(file_path, "rb") as f:
        file_hash = hash_func(f.read()).hexdigest()
    return file_hash == hash_value


def verify_sharded_model_hashes(jsonPath: str, verify_hash: List[str]):
    if not isinstance(verify_hash, list):
        raise ValueError("sharded model verify_hash must be a list")

    with open(jsonPath, 'r') as f:
        index_data = json.load(f)
    weight_map = index_data['weight_map']
    shard_files = set(weight_map.values())
    if len(shard_files) != len(verify_hash):
        raise ValueError("Number of shards and number of hash values do not match.")

    for shard_file, expected_hash in zip(shard_files, verify_hash):
        if not verify_model_hash(shard_file, expected_hash):
            logger.info(f"Hash verification failed for {shard_file}")
            return False
    return True

def check_and_get_model_type(model_dir, trust_remote_code=False):
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
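
The expected hash strings can be generated with hashlib directly. A small sketch producing the "hash_type:hash_value" format that verify_model_hash compares against; the file path is a hypothetical local weight file:

import hashlib

# Sketch: compute "md5:<hexdigest>" for a local weight file (path is an assumption).
with open("model.safetensors", "rb") as f:
    digest = hashlib.md5(f.read()).hexdigest()
print(f"md5:{digest}")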
1 change: 0 additions & 1 deletion tests/test_sharded.py
@@ -4,7 +4,6 @@

from gptqmodel import GPTQModel
from gptqmodel.quantization import FORMAT, QuantizeConfig

from transformers import AutoTokenizer


24 changes: 24 additions & 0 deletions tests/test_verify_hash.py
@@ -0,0 +1,24 @@
import unittest

from gptqmodel import GPTQModel


class TestVerifyHashFunction(unittest.TestCase):
    MODEL_ID = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
    EXPECTED_MD5_HASH = "md5:7725c72bc217bcb57b3f1f31d452d871"
    EXPECTED_SHA256_HASH = "sha256:2680bb4d5c977ee54f25dae584665641ea887e7bd8e8d7197ce8ffd310e93f2f"

    def test_verify_md5_hash_function(self):
        # Load the model with the MD5 verify_hash parameter
        model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", use_marlin=True,
                                         verify_hash=self.EXPECTED_MD5_HASH)
        self.assertIsNotNone(model)

    def test_verify_sha256_hash_function(self):
        # Load the model with the SHA-256 verify_hash parameter
        model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", use_marlin=True,
                                         verify_hash=self.EXPECTED_SHA256_HASH)
        # Additional check to ensure the model is loaded correctly
        self.assertIsNotNone(model)