From b4042ab2fd93430ed02043427f204183ca88a002 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 28 May 2024 10:49:28 -0700
Subject: [PATCH] Factor out the specific configurations to helper functions

Summary:
int4wo, int8wo, int8dyn and 8da4w are specific configurations for the
quantize function; this PR factors them out into helper functions so
that they are easy to use.

Test Plan:
python test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:
---
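Usage sketch for the new helpers, mirroring the updated tests below
(`m` is any eval-mode model with linear layers, e.g. the ToyLinearModel
used in the tests):

    from torchao.quantization.quant_api import quantize, get_apply_int4wo_quant

    # swaps each linear weight for an int4 weight-only AffineQuantizedTensor
    m = quantize(m, get_apply_int4wo_quant(groupsize=32))
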
 test/quantization/test_quant_api.py | 100 +++----------------------
 torchao/quantization/quant_api.py   | 110 ++++++++++++++++++++++++++++
 2 files changed, 119 insertions(+), 91 deletions(-)

diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 70c2562bb..9aae14dd8 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -37,6 +37,10 @@
     Quantizer,
     TwoStepQuantizer,
     quantize,
+    get_apply_8da4w_quant,
+    get_apply_int4wo_quant,
+    get_apply_int8wo_quant,
+    get_apply_int8dyn_quant,
 )
 from torchao.quantization.utils import (
     TORCH_VERSION_AFTER_2_3,
@@ -416,42 +420,11 @@ def test_eval_wrapper(self):
     # TODO: move to a separate test file
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     def test_quantized_tensor_subclass_8da4w(self):
-        # weight settings
         groupsize = 32
-        mapping_type = MappingType.SYMMETRIC
-        block_size = (1, groupsize)
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        quant_min = -8
-        quant_max = 7
-
-        # TODO: make a general helper function?
-        # input settings
-        def get_per_token_block_size(x):
-            block_size = []
-            for i in range(len(x.shape)-1):
-                block_size.append(1)
-            block_size.append(x.shape[-1])
-            return block_size
-
-        # input settings
-        input_mapping_type = MappingType.ASYMMETRIC
-        input_target_dtype = torch.int8
-        input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
-
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-
-        def apply_weight_quant(weight):
-            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
-
-        def apply_act_quant(weight):
-            return to_laq(weight, input_quant_func)
-
-        # note: order is important
-        m = quantize(m, apply_weight_quant)
-        m = quantize(m, apply_act_quant)
+        m = quantize(m, get_apply_8da4w_quant(groupsize=groupsize))
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -474,27 +447,13 @@ def apply_act_quant(weight):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int4(self):
-        # weight settings
-        groupsize = 32
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, groupsize)
-        target_dtype = torch.int32
-        quant_min = 0
-        quant_max = 15
-        eps = 1e-6
-        preserve_zero = False
-        zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT
-
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
 
-        def apply_weight_quant(weight):
-            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
-
-        m = quantize(m, apply_weight_quant)
+        groupsize = 32
+        m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize))
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -511,21 +470,11 @@ def apply_weight_quant(weight):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int8(self):
-        # weight settings
-        mapping_type = MappingType.SYMMETRIC
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        zero_point_dtype = torch.int64
-
         m = ToyLinearModel().eval().to(torch.bfloat16)
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
 
-        def apply_weight_quant(weight):
-            block_size = (1, weight.shape[1])
-            return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
-
-        m = quantize(m, apply_weight_quant)
+        m = quantize(m, get_apply_int8wo_quant())
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -543,43 +492,12 @@ def apply_weight_quant(weight):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int8_dyn_quant(self):
-        # weight settings
-        mapping_type = MappingType.SYMMETRIC
-        def get_weight_block_size(x):
-            return (1, x.shape[1])
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        zero_point_dtype = torch.int64
-
-        # input settings
-        def get_per_token_block_size(x):
-            block_size = list(x.shape)
-            for i in range(len(block_size)-1):
-                block_size[i] = 1
-            return block_size
-
-        input_mapping_type = MappingType.SYMMETRIC
-        input_target_dtype = torch.int8
-        input_eps = 1e-5
-        input_quant_min = -127
-        input_quant_max = 127
-        input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
-
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))
-
-        def apply_weight_quant(weight):
-            block_size = get_weight_block_size(weight)
-            return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
-
-        def apply_act_quant(weight):
-            return to_laq(weight, input_quant_func)
-
-        m = quantize(m, apply_weight_quant)
-        m = quantize(m, apply_act_quant)
+        m = quantize(m, get_apply_int8dyn_quant())
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index d9b731bac..02678ab2c 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -32,6 +32,12 @@
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
+    to_laq,
+)
+
+from .quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
 )
 from .weight_only import WeightOnlyInt8QuantLinear
 from .unified import Quantizer, TwoStepQuantizer
@@ -56,6 +62,10 @@
     "quantize",
     "autoquant",
     "_get_subclass_inserter",
"get_apply_8da4w_quant", + "get_apply_int4wo_quant", + "get_apply_int8wo_quant", + "get_apply_int8dyn_quant", ] if TORCH_VERSION_AFTER_2_3: @@ -287,3 +297,103 @@ def filter_fn(module, fqn): _is_linear if filter_fn is None else filter_fn, ) return model + +def get_apply_8da4w_quant(groupsize=32): + + def apply_8da4w_quant(weight): + # avoid circular dep + from torchao.dtypes.aqt import to_aq + + # weight settings + mapping_type = MappingType.SYMMETRIC + block_size = (1, groupsize) + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + quant_min = -8 + quant_max = 7 + + # TODO: make a general helper function? + # input settings + def get_per_token_block_size(x): + block_size = [] + for i in range(len(x.shape)-1): + block_size.append(1) + block_size.append(x.shape[-1]) + return block_size + + # input settings + input_mapping_type = MappingType.ASYMMETRIC + input_target_dtype = torch.int8 + input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype) + + weight = to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps) + weight = to_laq(weight, input_quant_func) + return weight + + return apply_8da4w_quant + + +def get_apply_int4wo_quant(groupsize=32): + def apply_int4wo_quant(weight): + # avoid circular dep + from torchao.dtypes.aqt import to_aq + + groupsize = 32 + mapping_type = MappingType.ASYMMETRIC + block_size = (1, groupsize) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + preserve_zero = False + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT + return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain) + + return apply_int4wo_quant + + +def get_apply_int8wo_quant(): + def apply_int8wo_quant(weight): + # avoid circular dep + from torchao.dtypes.aqt import to_aq + + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + block_size = (1, weight.shape[1]) + return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) + return apply_int8wo_quant + +def get_apply_int8dyn_quant(): + def apply_int8dyn_quant(weight): + # avoid circular dep + from torchao.dtypes.aqt import to_aq + # weight settings + mapping_type = MappingType.SYMMETRIC + def get_weight_block_size(x): + return (1, x.shape[1]) + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + + # input settings + def get_per_token_block_size(x): + block_size = list(x.shape) + for i in range(len(block_size)-1): + block_size[i] = 1 + return block_size + + input_mapping_type = MappingType.SYMMETRIC + input_target_dtype = torch.int8 + input_eps = 1e-5 + input_quant_min = -127 + input_quant_max = 127 + input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) + + block_size = get_weight_block_size(weight) + weight = to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) + weight = to_laq(weight, input_quant_func) + return weight + return apply_int8dyn_quant