[Core] [!4Bit] Fix "cannot pickle 'module' object" for 8-bit (fix #47) #49

Merged
merged 21 commits on Jun 24, 2024
41 changes: 28 additions & 13 deletions gptqmodel/models/base.py
@@ -102,9 +102,9 @@ def hf_device_map(self):
return getattr(self.model, "hf_device_map", None)

def _prepare_dataset_for_quantization(
self,
calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
batch_size: int = 1,
self,
calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
batch_size: int = 1,
):
def _convert_tensor_to_list(tensor):
if isinstance(tensor, torch.Tensor):
@@ -138,7 +138,7 @@ def _convert_tensor_to_list(tensor):
pad_token_id = self.config.eos_token_id

new_calibration_dataset = [
collate_data(new_calibration_dataset[start : start + batch_size], pad_token_id)
collate_data(new_calibration_dataset[start: start + batch_size], pad_token_id)
for start in range(0, len(new_calibration_dataset), batch_size)
]
for new_example in new_calibration_dataset:
@@ -183,7 +183,7 @@ def quantize(

if len(calibration_dataset) < MIN_CALIBRATION_DATASET_SIZE:
logger.warning(f"Calibration dataset size should be greater than {MIN_CALIBRATION_DATASET_SIZE}. "
f"Current size: {len(calibration_dataset)}.")
f"Current size: {len(calibration_dataset)}.")

# Calculate the average length of the average input_ids
total_input_ids_length = 0
@@ -194,7 +194,7 @@

if avg < MIN_CALIBRATION_DATASET_INPUT_IDS_AVG_LENGTH:
logger.warning(f"The average length of input_ids of calibration_dataset should be greater than "
f"{MIN_CALIBRATION_DATASET_INPUT_IDS_AVG_LENGTH}! Current AVG is {avg}.")
f"{MIN_CALIBRATION_DATASET_INPUT_IDS_AVG_LENGTH}! Current AVG is {avg}.")

device_map = self.hf_device_map
if device_map:
@@ -239,10 +239,7 @@ def store_input_hook(_, args, kwargs):
if pos_ids is not None:
position_ids.append(move_to(pos_ids, data_device))
one_kwargs = {}
for (
k,
v,
) in kwargs.items(): # make sure other arguments also be captured
for (k, v) in kwargs.items():  # make sure other arguments are also captured
if k not in ["hidden_states", "attention_mask", "position_ids"]:
one_kwargs[k] = nested_move_to(v, data_device)
layer_input_kwargs.append(one_kwargs)
@@ -497,8 +494,8 @@ def save_quantized(

if model_base_name is None:
model_base_name = (
self.quantize_config.model_file_base_name or
f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
self.quantize_config.model_file_base_name or
f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
)

state_dict = self.model.state_dict()
@@ -542,7 +539,25 @@ def save_quantized(
if format is None and quantize_config.format == FORMAT.GPTQ:
# Model qzeros may be edited in place.
# TODO: avoid inplace modification of the weights
model = copy.deepcopy(self.model)
# fix ModelCloud/GPTQModel/issues/47
# fix: gptqmodel_cuda holds a compiled extension module, which copy.deepcopy cannot pickle
# no need to set it back on self.model; nothing below uses it for computation
if quantize_config.bits != 4:
cuda_name_modules = {}
from gptqmodel.nn_modules.qlinear.qlinear_cuda import BaseCudaQuantLinear
for name, module in model.named_modules():
if isinstance(module, BaseCudaQuantLinear):
cuda_name_modules[name] = module.gptqmodel_cuda
module.gptqmodel_cuda = None
model = copy.deepcopy(self.model)

for name, module in model.named_modules():
if isinstance(module, BaseCudaQuantLinear) and name in cuda_name_modules:
module.gptqmodel_cuda = cuda_name_modules[name]

del cuda_name_modules
else:
model = copy.deepcopy(self.model)
model = convert_gptq_v2_to_v1_format(
model, quantize_config=quantize_config, qlinear_kernel=self.qlinear_kernel
)
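For readers unfamiliar with the pattern in the hunk above: copy.deepcopy fails on any object graph that holds a reference to a Python module object (the "cannot pickle 'module' object" error from #47), so the change detaches the gptqmodel_cuda reference from each CUDA quant-linear, performs the copy, and then re-attaches the saved references. The standalone sketch below illustrates the same detach/deepcopy/restore idea with hypothetical names (LayerWithExtension, deepcopy_model_safely); it is not GPTQModel code.

# Standalone sketch (hypothetical names, not GPTQModel code): the detach/deepcopy/restore
# pattern applied in the hunk above. A module object stored on a layer makes copy.deepcopy
# raise "TypeError: cannot pickle 'module' object"; detaching it first avoids that.
import copy
import math  # a real module object, standing in for a compiled CUDA extension

import torch.nn as nn


class LayerWithExtension(nn.Module):
    """Stand-in for a quantized linear layer that keeps a reference to an extension module."""

    def __init__(self):
        super().__init__()
        self.ext = math  # deepcopy cannot pickle this module reference


def deepcopy_model_safely(model: nn.Module) -> nn.Module:
    saved = {}
    # 1) detach the unpicklable module reference from every affected submodule
    for name, module in model.named_modules():
        if isinstance(module, LayerWithExtension):
            saved[name] = module.ext
            module.ext = None
    # 2) the copy now succeeds because no 'module' objects are reachable
    model_copy = copy.deepcopy(model)
    # 3) restore the saved references on the copy
    for name, module in model_copy.named_modules():
        if isinstance(module, LayerWithExtension) and name in saved:
            module.ext = saved[name]
    return model_copy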
5 changes: 5 additions & 0 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -4,3 +4,8 @@
class BaseQuantLinear(nn.Module):
# override me
QUANT_TYPE = "base"


class BaseCudaQuantLinear(BaseQuantLinear):
# override me
QUANT_TYPE = "base-cuda"
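The new BaseCudaQuantLinear subclass acts as a marker base for the CUDA kernels in qlinear_cuda.py and qlinear_cuda_old.py, so callers can find every CUDA-backed layer with a single isinstance() check instead of matching QUANT_TYPE strings. A minimal sketch of that idea; CudaQuantLinear and cuda_quant_layers are hypothetical stand-ins, not part of the library.

import torch.nn as nn


class BaseQuantLinear(nn.Module):
    QUANT_TYPE = "base"


class BaseCudaQuantLinear(BaseQuantLinear):
    QUANT_TYPE = "base-cuda"


class CudaQuantLinear(BaseCudaQuantLinear):  # hypothetical stand-in for qlinear_cuda.QuantLinear
    QUANT_TYPE = "cuda"


def cuda_quant_layers(model: nn.Module):
    """Yield (name, module) for every CUDA-backed quantized linear in a model."""
    for name, module in model.named_modules():
        if isinstance(module, BaseCudaQuantLinear):
            yield name, module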
4 changes: 2 additions & 2 deletions gptqmodel/nn_modules/qlinear/qlinear_cuda.py
@@ -7,12 +7,12 @@
import torch
import torch.nn as nn
import transformers
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from gptqmodel.nn_modules.qlinear import BaseCudaQuantLinear

logger = getLogger(__name__)


class QuantLinear(BaseQuantLinear):
class QuantLinear(BaseCudaQuantLinear):
QUANT_TYPE = "cuda"

def __init__(
4 changes: 2 additions & 2 deletions gptqmodel/nn_modules/qlinear/qlinear_cuda_old.py
@@ -7,12 +7,12 @@
import torch
import torch.nn as nn
import transformers
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from gptqmodel.nn_modules.qlinear import BaseCudaQuantLinear

logger = getLogger(__name__)


class QuantLinear(BaseQuantLinear):
class QuantLinear(BaseCudaQuantLinear):
QUANT_TYPE = "cuda-old"

def __init__(
54 changes: 39 additions & 15 deletions tests/test_quant_formats.py
@@ -13,6 +13,18 @@


class TestQuantization(unittest.TestCase):

def setUp(self):
self.pretrained_model_dir = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"

self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_dir, use_fast=True)
self.calibration_dataset = [
self.tokenizer(
"auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
),
self.tokenizer("Today I am in Paris and it is a wonderful day."),
]

@parameterized.expand(
[
(False, True, FORMAT.GPTQ_V2),
@@ -21,16 +33,6 @@ class TestQuantization(unittest.TestCase):
]
)
def test_quantize(self, use_marlin: bool, sym: bool, format: FORMAT):
pretrained_model_dir = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
calibration_dataset = [
tokenizer(
"auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
),
tokenizer("Today I am in Paris and it is a wonderful day."),
]

quantize_config = QuantizeConfig(
bits=4,
group_size=128,
@@ -40,17 +42,15 @@ def test_quantize(self, use_marlin: bool, sym: bool, format: FORMAT):
)

model = GPTQModel.from_pretrained(
pretrained_model_dir,
self.pretrained_model_dir,
quantize_config=quantize_config,
use_flash_attention_2=False,
)

model.quantize(calibration_dataset)
model.quantize(self.calibration_dataset)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(
tmpdirname,
)
model.save_quantized(tmpdirname)

logging.info(f"Saved config mem: {model.quantize_config}")

@@ -117,3 +117,27 @@ def test_quantize(self, use_marlin: bool, sym: bool, format: FORMAT):
format=format,
)
assert isinstance(model.quantize_config, QuantizeConfig)

def test_gptq_8bit(self):
quantize_config = QuantizeConfig(
bits=8,
group_size=128,
format=FORMAT.GPTQ,
desc_act=True
)

model = GPTQModel.from_pretrained(
self.pretrained_model_dir,
quantize_config=quantize_config,
)

model.quantize(self.calibration_dataset)

with tempfile.TemporaryDirectory() as tmpdirname:
err = None
try:
model.save_quantized(tmpdirname)
except Exception as e:
print(e)
err = e
self.assertTrue(err is None)