diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 81fd9fcf5..1bcc44f08 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1173,13 +1173,13 @@ def test_on_dummy_distilbert(self):
 class TestAutoQuant(unittest.TestCase):
     @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
         [
-            (16, 128, 128),
-            (64, 128, 128),
+            # (16, 128, 128),
+            # (64, 128, 128),
             # (2**15, 128, 128), TODO: Runs out of shared memory on T4
-            (16, 128, 256),
+            (2, 128, 256),
             # (64, 128, 256), # TODO: Runs out of shared memory on T4
-            (16, 256, 128),
-            (64, 256, 128),
+            # (16, 256, 128),
+            # (64, 256, 128),
             # (256, 256, 128), TODO: Runs out of shared memory on T4
         ]))
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
@@ -1194,7 +1194,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
         if m == 1:
             self.skipTest(f"Shape {(m, k, n)} requires sm80+")
         torch._inductor.config.epilogue_fusion = False
-        torch._inductor.config.use_mixed_mm = True
+        # torch._inductor.config.use_mixed_mm = True
         torch._inductor.config.force_fuse_int_mm_with_mul = True
         torch._dynamo.config.automatic_dynamic_shapes = False

diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt
index e07a3e679..d26cb82a9 100644
--- a/torchao/_models/llama/benchmark_results.txt
+++ b/torchao/_models/llama/benchmark_results.txt
@@ -27,3 +27,6 @@ kv cache quantization:
 20240801094415, tok/s= 87.20, mem/s=1308.88 GB/s, peak_mem=17.22 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
 20240801095615, tok/s= 80.87, mem/s=1213.82 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
 20240801100912, tok/s= 74.65, mem/s=1120.41 GB/s, peak_mem=19.29 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
+
+20240806071013, tok/s=172.58, mem/s=1161.55 GB/s, peak_mem= 8.90 GB, model_size= 6.73 GB quant: autoquant, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240806073549, tok/s=158.04, mem/s=1192.77 GB/s, peak_mem= 9.99 GB, model_size= 7.55 GB quant: autoquant, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh
index 6dd9c10d9..402acff91 100644
--- a/torchao/_models/llama/benchmarks.sh
+++ b/torchao/_models/llama/benchmarks.sh
@@ -2,31 +2,31 @@ export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder
 
 export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
-# in readme
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
+# # in readme
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
 
 export MODEL_REPO=meta-llama/Meta-Llama-3-8B
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
-# in readme
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
+# # in readme
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
 
-export MODEL_REPO=meta-llama/Meta-Llama-3-8B
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192
+# export MODEL_REPO=meta-llama/Meta-Llama-3-8B
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 0d271776a..4591b0174 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -155,6 +155,10 @@ def dequantize(self, output_dtype=None):
         int_data, scale, zero_point = self.layout_tensor.get_plain()
         return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)
 
+    @staticmethod
+    def _quantized_linear_op(input_tensor, weight_tensor, bias):
+        return _quantized_linear_op(input_tensor, weight_tensor, bias)
+
     def __tensor_flatten__(self):
         return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
 
@@ -832,7 +836,7 @@ def _(func, types, args, kwargs):
     # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
    # make the branches easier to understand in `_quantized_linear_op`
     try:
-        return _quantized_linear_op(input_tensor, weight_tensor, bias)
+        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
     except:
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 83d7837d3..122b90950 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -1,10 +1,16 @@
 import torch
 import torchao
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+)
 from .subclass import ( # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
 )
+from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType
+from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from .quant_primitives import (
     safe_int_mm,
@@ -252,9 +258,9 @@ class AQMixin():
     def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         w_qtensor = cls.from_float(weight)
         if _is_interpolate_mode(mode):
-            q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs")
+            q_c_op = torch.compile(cls._quantized_linear_op, mode="max-autotune-no-cudagraphs")
         else:
-            func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c))
+            func = lambda a,b,c: F.relu(cls._quantized_linear_op(F.relu(a), b, c))
             q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs")
         res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100)
         if res < best_time*1.1:
@@ -263,10 +269,53 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
             print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ")
         return res
 
-class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
+###### TODO !!!!!!!!!!!!!!!
+# 1) make class method from_float (just duplicate code)
+# 2) undo changes to quant_api?
+# 3) point to new quantized_op location
+# 4) rewrite the dynamic autoquant test
+
+class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
     """
     AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
     """
+    @classmethod
+    def from_float(cls, weight):
+        in_features = weight.shape[1]
+        # int8 dynamic quantization only has benefit when in_feature > 16
+        # if in_features <= 16:
+        #     return weight
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size)-1):
+                block_size[i] = 1
+            return block_size
+
+        input_mapping_type = MappingType.SYMMETRIC
+        input_target_dtype = torch.int8
+        input_eps = 1e-5
+        input_quant_min = -127
+        input_quant_max = 127
+        layout_type = PlainLayoutType()
+        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
+        weight = super(AQInt8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
+        return weight
+
     @classmethod
     def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         """
@@ -298,12 +347,13 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         )
         q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
         with torch.no_grad():
-            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_qtensor.int_data)
+            w_vals_int8 = w_qtensor.original_weight_tensor.layout_tensor.int_data.contiguous().t()
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_vals_int8)
         print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
 
         # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
-        if res_matmul>=best_time:
-            return res_matmul
+        # if res_matmul>=best_time:
+        #     return res_matmul
 
         # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
         to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul)
@@ -313,18 +363,27 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
         return res_f
 
-class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+class AQWeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
     AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
     """
+    @classmethod
+    def from_float(cls, weight):
+        mapping_type = MappingType.SYMMETRIC
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+        block_size = (1, weight.shape[1])
+        return super(AQWeightOnlyQuantizedLinearWeight, cls).from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+
-class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+class AQWeightOnlyQuantizedLinearWeight2(AQWeightOnlyQuantizedLinearWeight, AQMixin):
     """
     AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that uses a different kernel
     """
     @staticmethod
-    def _quantized_op(act_mat, w_qtensor, bias):
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
         """
         Performs the quantized linear operations
@@ -339,8 +398,8 @@ def _quantized_op(act_mat, w_qtensor, bias):
         orig_dtype = act_mat.dtype
         orig_shape = act_mat.shape
         act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
-        y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
-        y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales
+        y = (act_mat*w_qtensor.layout_tensor.int_data.t().unsqueeze(0)).sum(dim=-2)
+        y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.layout_tensor.scale
         if bias is not None:
             y += bias
         return y.to(orig_dtype)
@@ -352,14 +411,14 @@ def _autoquant_test(cls, act_mat, *args):
             return torch.inf
         return super()._autoquant_test(act_mat, *args)
 
-class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMixin):
     """
     AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that uses a different kernel
     """
-    def _quantized_op(act_mat, w_qtensor, bias):
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
         orig_shape = act_mat.shape
-        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales)
+        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t()*w_qtensor.layout_tensor.scale)
         y=y.reshape(*orig_shape[:-1], y.shape[-1])
         if bias is not None:
             y += bias
@@ -377,7 +436,7 @@ def __init__(self):
         super().__init__()
 
     @staticmethod
-    def _quantized_op(act_mat, w_qtensor, bias):
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
         return torch.nn.functional.linear(act_mat, w_qtensor, bias)
 
     @classmethod
@@ -389,7 +448,7 @@ def from_float(cls, weight):
     AQWeightOnlyQuantizedLinearWeight,
     AQWeightOnlyQuantizedLinearWeight2,
     # AQWeightOnlyQuantizedLinearWeight3,
-    # TODO this gets picked in places where it makes perf worse, why?
+    # # TODO this gets picked in places where it makes perf worse, why?
     AQInt8DynamicallyQuantizedLinearWeight,
 ]
 
diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py
index dfe1f62de..de478ffc9 100644
--- a/torchao/quantization/linear_activation_quantized_tensor.py
+++ b/torchao/quantization/linear_activation_quantized_tensor.py
@@ -53,6 +53,13 @@ def __tensor_unflatten__(
             input_quant_func,
         )
 
+    @staticmethod
+    def _quantized_linear_op(input_tensor, weight_tensor, bias):
+        input_quant_func = weight_tensor.input_quant_func
+        original_weight_tensor = weight_tensor.original_weight_tensor
+        aqt = input_quant_func(input_tensor)
+        return torch.nn.functional.linear(aqt, original_weight_tensor, bias)
+
     @classmethod
     def from_float(cls, input_float, input_quant_func):
         return cls(input_float, input_quant_func)
@@ -98,10 +105,7 @@ def _(func, types, args, kwargs):
         args[2] if len(args) > 2 else None,
     )
     if isinstance(weight_tensor, LinearActivationQuantizedTensor):
-        input_quant_func = weight_tensor.input_quant_func
-        original_weight_tensor = weight_tensor.original_weight_tensor
-        aqt = input_quant_func(input_tensor)
-        return torch.nn.functional.linear(aqt, original_weight_tensor, bias)
+        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
     raise NotImplementedError("LinearActivationQuantizedTensor: No specialized dispatch found for linear op")
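
The dispatch change above (in affine_quantized_tensor.py and linear_activation_quantized_tensor.py) routes the F.linear override through a per-subclass `_quantized_linear_op` staticmethod instead of a shared module-level helper, which is why the autoquant classes now define `_quantized_linear_op` rather than `_quantized_op`. As a rough standalone illustration only, not part of this patch and not torchao code (the `FakeInt8Weight` class below is invented), the pattern looks like this in plain PyTorch tensor-subclass terms:

import torch
import torch.nn.functional as F

class FakeInt8Weight(torch.Tensor):
    # toy wrapper subclass holding int8 values plus a per-output-channel scale
    @staticmethod
    def __new__(cls, int_data, scale):
        return torch.Tensor._make_wrapper_subclass(cls, int_data.shape, dtype=scale.dtype)

    def __init__(self, int_data, scale):
        self.int_data = int_data
        self.scale = scale

    @staticmethod
    def _quantized_linear_op(input_tensor, weight_tensor, bias):
        # dequantize-then-matmul; a subclass overrides only this one method to swap
        # kernels, which is what AQWeightOnlyQuantizedLinearWeight2/3 do in the patch
        w = weight_tensor.int_data.to(weight_tensor.scale.dtype) * weight_tensor.scale.unsqueeze(-1)
        return F.linear(input_tensor, w, bias)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            input_tensor, weight_tensor = args[0], args[1]
            bias = args[2] if len(args) > 2 else None
            # same shape as the patched dispatch: ask the weight subclass for its kernel
            return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
        return super().__torch_function__(func, types, args, kwargs)

x = torch.randn(4, 8)
w = FakeInt8Weight(torch.randint(-128, 127, (16, 8), dtype=torch.int8), torch.full((16,), 0.02))
print(F.linear(x, w).shape)  # torch.Size([4, 16]), routed through _quantized_linear_op

Keeping the kernel behind a single staticmethod is what lets the autoquant subclasses reuse the AffineQuantizedTensor / LinearActivationQuantizedTensor machinery while each supplying its own linear implementation.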