autoquant using AQT (AffineQuantizedTensor)
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
HDCharles committed Aug 6, 2024
1 parent de4a1fb commit d4d5abe
Showing 6 changed files with 120 additions and 50 deletions.
12 changes: 6 additions & 6 deletions test/integration/test_integration.py
@@ -1173,13 +1173,13 @@ def test_on_dummy_distilbert(self):
class TestAutoQuant(unittest.TestCase):
@parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
[
(16, 128, 128),
(64, 128, 128),
# (16, 128, 128),
# (64, 128, 128),
# (2**15, 128, 128), TODO: Runs out of shared memory on T4
(16, 128, 256),
(2, 128, 256),
# (64, 128, 256), # TODO: Runs out of shared memory on T4
(16, 256, 128),
(64, 256, 128),
# (16, 256, 128),
# (64, 256, 128),
# (256, 256, 128), TODO: Runs out of shared memory on T4
]))
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
@@ -1194,7 +1194,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
if m == 1:
self.skipTest(f"Shape {(m, k, n)} requires sm80+")
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.use_mixed_mm = True
# torch._inductor.config.use_mixed_mm = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._dynamo.config.automatic_dynamic_shapes = False

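The shapes left enabled above feed TestAutoQuant.test_autoquant_one_input on small linear layers. As a rough sketch of what that test exercises (assuming the public torchao.autoquant entry point; the test body is collapsed in this view):

import torch
import torchao

m, k, n = 16, 128, 256  # one of the shapes kept enabled above
model = torch.nn.Sequential(
    torch.nn.ReLU(),
    torch.nn.Linear(k, n),
    torch.nn.ReLU(),
).to("cuda").to(torch.bfloat16)
example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)

# autoquant swaps each Linear weight for an auto-quantizable wrapper, benchmarks the
# candidate AQ* weight classes on the first real input, and keeps the fastest one
model = torchao.autoquant(torch.compile(model))
out = model(example_input)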
3 changes: 3 additions & 0 deletions torchao/_models/llama/benchmark_results.txt
@@ -27,3 +27,6 @@ kv cache quantization:
20240801094415, tok/s= 87.20, mem/s=1308.88 GB/s, peak_mem=17.22 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
20240801095615, tok/s= 80.87, mem/s=1213.82 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
20240801100912, tok/s= 74.65, mem/s=1120.41 GB/s, peak_mem=19.29 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8

20240806071013, tok/s=172.58, mem/s=1161.55 GB/s, peak_mem= 8.90 GB, model_size= 6.73 GB quant: autoquant, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240806073549, tok/s=158.04, mem/s=1192.77 GB/s, peak_mem= 9.99 GB, model_size= 7.55 GB quant: autoquant, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
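The two new rows above record autoquant runs for Llama-2-7b-chat-hf and Meta-Llama-3-8B. As a hedged sketch, the --quantization autoquant flag in those repro commands presumably maps to something like the following inside generate.py (the script itself is not part of this diff, and maybe_autoquant is a hypothetical helper):

import torch
import torchao

def maybe_autoquant(model: torch.nn.Module, quantization: str) -> torch.nn.Module:
    if quantization == "autoquant":
        # kernel selection is deferred until the first prompt runs through the model
        model = torchao.autoquant(torch.compile(model, mode="max-autotune"))
    return model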
46 changes: 23 additions & 23 deletions torchao/_models/llama/benchmarks.sh
@@ -2,31 +2,31 @@ export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder


export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
# in readme
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
# # in readme
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt

export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
# in readme
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
# # in readme
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt

export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192
# export MODEL_REPO=meta-llama/Meta-Llama-3-8B
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192
6 changes: 5 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -155,6 +155,10 @@ def dequantize(self, output_dtype=None):
int_data, scale, zero_point = self.layout_tensor.get_plain()
return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)

@staticmethod
def _quantized_linear_op(input_tensor, weight_tensor, bias):
return _quantized_linear_op(input_tensor, weight_tensor, bias)

def __tensor_flatten__(self):
return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]

@@ -832,7 +836,7 @@ def _(func, types, args, kwargs):
# is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
# make the branches easier to understand in `_quantized_linear_op`
try:
return _quantized_linear_op(input_tensor, weight_tensor, bias)
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
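Routing the linear call through weight_tensor._quantized_linear_op (rather than the module-level function directly) lets tensor subclasses ship their own kernel while reusing the F.linear dispatch and the dequantize fallback above. A minimal sketch of that override pattern with a hypothetical subclass (not part of this diff):

import torch
from torchao.dtypes import AffineQuantizedTensor

class MyQuantizedWeight(AffineQuantizedTensor):  # hypothetical subclass, for illustration only
    @staticmethod
    def _quantized_linear_op(input_tensor, weight_tensor, bias):
        # stand-in kernel: dequantize and fall back to the reference linear
        return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)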
91 changes: 75 additions & 16 deletions torchao/quantization/autoquant.py
@@ -1,10 +1,16 @@
import torch
import torchao
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from .subclass import ( # noqa
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
QuantizedLinearWeightBase,
)
from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torch.utils._python_dispatch import return_and_correct_aliasing
from .quant_primitives import (
safe_int_mm,
@@ -252,9 +258,9 @@ class AQMixin():
def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
w_qtensor = cls.from_float(weight)
if _is_interpolate_mode(mode):
q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs")
q_c_op = torch.compile(cls._quantized_linear_op, mode="max-autotune-no-cudagraphs")
else:
func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c))
func = lambda a,b,c: F.relu(cls._quantized_linear_op(F.relu(a), b, c))
q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs")
res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100)
if res < best_time*1.1:
@@ -263,10 +269,53 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ")
return res

class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
###### TODO !!!!!!!!!!!!!!!
# 1) make class method from_float (just duplicate code)
# 2) undo changes to quant_api?
# 3) point to new quantized_op location
# 4) rewrite the dynamic autoquant test

class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
"""
AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
"""
@classmethod
def from_float(cls, weight):
in_features = weight.shape[1]
# int8 dynamic quantization only has benefit when in_feature > 16
# if in_features <= 16:
# return weight

# avoid circular dep
from torchao.dtypes import to_affine_quantized
# weight settings
mapping_type = MappingType.SYMMETRIC
def get_weight_block_size(x):
return (1, x.shape[1])
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# input settings
def get_per_token_block_size(x):
block_size = list(x.shape)
for i in range(len(block_size)-1):
block_size[i] = 1
return block_size

input_mapping_type = MappingType.SYMMETRIC
input_target_dtype = torch.int8
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
layout_type = PlainLayoutType()
input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

block_size = get_weight_block_size(weight)
weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
weight = super(AQInt8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
return weight

@classmethod
def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
"""
@@ -298,12 +347,13 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
)
q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
with torch.no_grad():
res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_qtensor.int_data)
w_vals_int8 = w_qtensor.original_weight_tensor.layout_tensor.int_data.contiguous().t()
res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_vals_int8)
print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

# if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
if res_matmul>=best_time:
return res_matmul
# if res_matmul>=best_time:
# return res_matmul

# calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul)
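A hedged reading of the breakeven arithmetic above (the blending step itself sits in collapsed lines, so its exact form is an assumption): if the estimated time is a weighted mix est = C*res_matmul + (1-C)*res_f with C = INTERPOLATION_CONSTANT, then est < best_time rearranges to res_f < best_time + C/(1-C)*(best_time - res_matmul), which is exactly the to_beat expression. With made-up numbers:

INTERPOLATION_CONSTANT = 0.6        # illustrative value only, not the library constant
best_time, res_matmul = 1.00, 0.40  # ms, made-up numbers
to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
print(f"full dynamic-quant op must finish in under {to_beat:0.3f}ms")  # 1.900ms here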
@@ -313,18 +363,27 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
return res_f

class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
class AQWeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
"""
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
"""
@classmethod
def from_float(cls, weight):
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
block_size = (1, weight.shape[1])
return super(AQWeightOnlyQuantizedLinearWeight, cls).from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)


class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
class AQWeightOnlyQuantizedLinearWeight2(AQWeightOnlyQuantizedLinearWeight, AQMixin):
"""
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
uses a different kernel
"""
@staticmethod
def _quantized_op(act_mat, w_qtensor, bias):
def _quantized_linear_op(act_mat, w_qtensor, bias):
"""
Performs the quantized linear operations
@@ -339,8 +398,8 @@ def _quantized_op(act_mat, w_qtensor, bias):
orig_dtype = act_mat.dtype
orig_shape = act_mat.shape
act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales
y = (act_mat*w_qtensor.layout_tensor.int_data.t().unsqueeze(0)).sum(dim=-2)
y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.layout_tensor.scale
if bias is not None:
y += bias
return y.to(orig_dtype)
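This kernel now pulls the raw tensors from w_qtensor.layout_tensor instead of the old subclass attributes; presumably the plain AQT layout stores int_data as (out_features, in_features), which is why a .t() was added before broadcasting against the activations. A quick shape check under that assumption:

import torch

batch, in_features, out_features = 2, 8, 4
act_mat = torch.randn(batch, in_features).reshape(-1, in_features, 1)
int_data = torch.randint(-128, 127, (out_features, in_features), dtype=torch.int8)
scale = torch.randn(out_features)

# (batch, in, 1) * (1, in, out) -> sum over in -> (batch, out), matching the op above
y = (act_mat * int_data.t().unsqueeze(0)).sum(dim=-2) * scale
assert y.shape == (batch, out_features)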
@@ -352,14 +411,14 @@ def _autoquant_test(cls, act_mat, *args):
return torch.inf
return super()._autoquant_test(act_mat, *args)

class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMixin):
"""
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
uses a different kernel
"""
def _quantized_op(act_mat, w_qtensor, bias):
def _quantized_linear_op(act_mat, w_qtensor, bias):
orig_shape = act_mat.shape
y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales)
y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t()*w_qtensor.layout_tensor.scale)
y=y.reshape(*orig_shape[:-1], y.shape[-1])
if bias is not None:
y += bias
@@ -377,7 +436,7 @@ def __init__(self):
super().__init__()

@staticmethod
def _quantized_op(act_mat, w_qtensor, bias):
def _quantized_linear_op(act_mat, w_qtensor, bias):
return torch.nn.functional.linear(act_mat, w_qtensor, bias)

@classmethod
@@ -389,7 +448,7 @@ def from_float(cls, weight):
AQWeightOnlyQuantizedLinearWeight,
AQWeightOnlyQuantizedLinearWeight2,
# AQWeightOnlyQuantizedLinearWeight3,
# TODO this gets picked in places where it makes perf worse, why?
# # TODO this gets picked in places where it makes perf worse, why?
AQInt8DynamicallyQuantizedLinearWeight,
]

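The list in the final hunk is the default set of candidate weight classes autoquant benchmarks per layer (AQWeightOnlyQuantizedLinearWeight3 stays disabled pending the TODO). A sketch of narrowing that search, assuming torchao.autoquant still accepts a qtensor_class_list argument:

import torch
import torchao
from torchao.quantization.autoquant import (
    AQInt8DynamicallyQuantizedLinearWeight,
    AQWeightOnlyQuantizedLinearWeight,
)

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda").to(torch.bfloat16)
model = torchao.autoquant(
    torch.compile(model),
    qtensor_class_list=[AQWeightOnlyQuantizedLinearWeight, AQInt8DynamicallyQuantizedLinearWeight],
)
model(torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16))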