[CORE] [FEATURE] Add verify_hash (#50)
* add verify_hash for quantized model

* use is_sharded to choose verify method

* mod clean up

* move to utils

* remove unused import

* mod clean up

* mod clean up

* Add usage doc

* mod name and type check

* mod support hashlib types

* add unit test

* Update auto.py

* clean up verify test code

* format code

---------

Co-authored-by: Qubitium-ModelCloud <[email protected]>
PZS-ModelCloud and Qubitium authored Jun 24, 2024
1 parent b8003ee commit 9a485ba
Showing 6 changed files with 73 additions and 4 deletions.
7 changes: 6 additions & 1 deletion gptqmodel/models/auto.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Union

from ..utils.model import check_and_get_model_type
from .baichuan import BaiChuanGPTQ
@@ -116,6 +116,10 @@ def from_quantized(
disable_exllama: Optional[bool] = None,
disable_exllamav2: bool = False,
use_marlin: bool = False,
# verify that weight files match a predefined hash during loading
# usage: hash_format:hash_value, example: md5:ugkdh232
# supports all hashlib hash methods
verify_hash: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> BaseGPTQModel:
# If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones.
@@ -143,6 +147,7 @@ def from_quantized(
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_marlin=use_marlin,
verify_hash=verify_hash,
**kwargs,
)

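For illustration only (not part of the commit), a minimal sketch of how the expected hash string described in the comment above might be produced and passed through the new parameter; the checkpoint path and file name are hypothetical:

import hashlib

from gptqmodel import GPTQModel

weight_file = "TinyLlama-GPTQ-4bit/model.safetensors"  # hypothetical single-file checkpoint

# Compute the digest and format it as "hash_format:hash_value", as documented above.
with open(weight_file, "rb") as f:
    digest = hashlib.md5(f.read()).hexdigest()

model = GPTQModel.from_quantized(
    "TinyLlama-GPTQ-4bit",      # hypothetical local model directory
    device="cuda:0",
    verify_hash=f"md5:{digest}",
)
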
13 changes: 11 additions & 2 deletions gptqmodel/models/base.py
@@ -27,7 +27,8 @@
from ..utils.model import (auto_dtype_from_config, convert_gptq_v1_to_v2_format, convert_gptq_v2_to_v1_format,
find_layers, get_checkpoints, get_device, get_module_by_name_prefix,
get_module_by_name_suffix, get_moe_layer_modules, gptqmodel_post_init, make_quant,
-move_to, nested_move_to, pack_model, simple_dispatch_model)
+move_to, nested_move_to, pack_model, simple_dispatch_model, verify_model_hash,
+verify_sharded_model_hashes)
from ..version import __version__
from ._const import CPU, CUDA_0, SUPPORTED_MODELS

@@ -749,6 +750,7 @@ def from_quantized(
disable_exllamav2: bool = False,
format: Optional[FORMAT] = None,
allow_unsafe_loading: bool = False,
verify_hash: Optional[Union[str, List[str]]] = None,
**kwargs,
):
"""load quantized model from local disk"""
@@ -873,7 +875,14 @@
quantize_config.model_file_base_name = true_model_basename

model_save_name = resolved_archive_file # In case a model is sharded, this would be `model.safetensors.index.json` which may later break.

if verify_hash:
    if is_sharded:
        verified = verify_sharded_model_hashes(model_save_name, verify_hash)
    else:
        verified = verify_model_hash(model_save_name, verify_hash)
    if not verified:
        raise ValueError(f"Hash verification failed for {model_save_name}")
    logger.info(f"Hash verification succeeded for {model_save_name}")
# == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
def skip(*args, **kwargs):
pass
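As a hedged sketch of the failure path added above (the local path and digest are placeholders, not from the commit): a digest mismatch surfaces as a ValueError at load time instead of silently loading a corrupted or tampered checkpoint.

from gptqmodel import GPTQModel

try:
    model = GPTQModel.from_quantized(
        "my-quantized-model",   # hypothetical local path
        device="cuda:0",
        verify_hash="md5:00000000000000000000000000000000",  # deliberately wrong digest
    )
except ValueError as err:
    # Message raised above: "Hash verification failed for <model_save_name>"
    print(f"Refusing to load: {err}")
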
1 change: 1 addition & 0 deletions gptqmodel/models/dbrx.py
@@ -1,5 +1,6 @@
from .base import BaseGPTQModel


# placeholder only as dbrx original models are not supported
# supported dbrx_converted models can be found on https://hf.co/ModelCloud
class DbrxGPTQ(BaseGPTQModel):
31 changes: 31 additions & 0 deletions gptqmodel/utils/model.py
@@ -1,4 +1,5 @@
import functools
import hashlib
import json
import logging
import os
@@ -299,6 +300,36 @@ def pack_model(
QuantLinear.warmup(model.to(CUDA_0), seqlen=model.seqlen)
return QuantLinear

def verify_model_hash(file_path: str, verify_hash: str):
    if not isinstance(verify_hash, str):
        raise ValueError("model verify_hash must be a string")
    if ':' not in verify_hash:
        raise ValueError("verify_hash must be in the format 'hash_type:hash_value'")
    hash_type, hash_value = verify_hash.split(':', 1)
    hash_func = getattr(hashlib, hash_type, None)
    if not hash_func:
        raise ValueError(f"No hash function found for type: {hash_type}")
    with open(file_path, "rb") as f:
        file_hash = hash_func(f.read()).hexdigest()
    return file_hash == hash_value


def verify_sharded_model_hashes(jsonPath: str, verify_hash: List[str]):
    if not isinstance(verify_hash, list):
        raise ValueError("sharded model verify_hash must be a list")

    with open(jsonPath, 'r') as f:
        index_data = json.load(f)
    weight_map = index_data['weight_map']
    shard_files = set(weight_map.values())
    if len(shard_files) != len(verify_hash):
        raise ValueError("Number of shards and number of hash values do not match.")

    for shard_file, expected_hash in zip(shard_files, verify_hash):
        if not verify_model_hash(shard_file, expected_hash):
            logger.info(f"Hash verification failed for {shard_file}")
            return False
    return True

def check_and_get_model_type(model_dir, trust_remote_code=False):
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
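A minimal usage sketch for the single-file helper (not part of the commit; the weight file path is a placeholder, and the import path simply follows the file location in this diff): it recomputes the digest with the named hashlib constructor and compares hex strings.

import hashlib

from gptqmodel.utils.model import verify_model_hash

weight_file = "model.safetensors"  # hypothetical local weight file

# Build the "hash_type:hash_value" string the helper parses.
with open(weight_file, "rb") as f:
    expected = "sha256:" + hashlib.sha256(f.read()).hexdigest()

assert verify_model_hash(weight_file, expected)                   # digest matches
assert not verify_model_hash(weight_file, "sha256:" + "0" * 64)   # digest mismatch
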
1 change: 0 additions & 1 deletion tests/test_sharded.py
@@ -4,7 +4,6 @@

from gptqmodel import GPTQModel
from gptqmodel.quantization import FORMAT, QuantizeConfig

from transformers import AutoTokenizer


24 changes: 24 additions & 0 deletions tests/test_verify_hash.py
@@ -0,0 +1,24 @@
import unittest

from gptqmodel import GPTQModel


class TestVerifyHashFunction(unittest.TestCase):
    MODEL_ID = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
    EXPECTED_MD5_HASH = "md5:7725c72bc217bcb57b3f1f31d452d871"
    EXPECTED_SHA256_HASH = "sha256:2680bb4d5c977ee54f25dae584665641ea887e7bd8e8d7197ce8ffd310e93f2f"

    def test_verify_md5_hash_function(self):
        # Load the model with MD5 verify_hash parameter
        model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", use_marlin=True,
                                         verify_hash=self.EXPECTED_MD5_HASH)
        self.assertIsNotNone(model)

    def test_verify_sha256_hash_function(self):
        # Load the model with SHA-256 verify_hash parameter
        model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", use_marlin=True,
                                         verify_hash=self.EXPECTED_SHA256_HASH)
        # Add additional checks to ensure the model is loaded correctly
        self.assertIsNotNone(model)
