diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index 354ed5d3..a274c04d 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -42,6 +42,11 @@ # Quantization related flags flags.DEFINE_bool("quantize_weights", False, "weight quantization") +flags.DEFINE_bool( + "quantize_activation", + False, + "Quantize Q,K,V projection and FeedForward activation.", +) flags.DEFINE_string( "quantize_type", "int8_per_channel", "Type of quantization." ) @@ -90,6 +95,9 @@ def create_quantization_config_from_flags(): config.enable_weight_quantization = True config.num_bits_weight = 8 if "int8" in quantize_type else 4 config.is_blockwise_weight = "blockwise" in quantize_type + + config.enable_activation_quantization = FLAGS.quantize_activation + config.enable_kv_quantization = FLAGS.quantize_kv_cache return config diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index 5ea8f3a3..005114ab 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -32,6 +32,10 @@ class QuantizationConfig: enable_weight_quantization: bool = False num_bits_weight: int = 8 is_blockwise_weight: bool = False + block_size_weight: int = 128 + is_symmetric_weight: bool = True + + enable_activation_quantization: bool = False enable_kv_quantization: bool = False diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index c5e305b8..8ef7f131 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -25,10 +25,14 @@ import torch_xla2 from jax import lax from jetstream_pt import torchjax +from jetstream_pt.environment import QuantizationConfig from jetstream_pt.quantize import ( dequantize_tensor, load_q_weight_helper, quantize_tensor, + blockwise_jax_kernel, + blockwise_jax_kernel_dot_general, + blockwise_jax_kernel_einsum_flatten, ) from torch import nn from . import attention_kernel as ak @@ -68,8 +72,7 @@ def __init__( out_features, bias=False, device=None, - is_symmetric=True, - n_bit=8, + quant_config=QuantizationConfig(), ): super().__init__() self.in_features = in_features @@ -85,8 +88,9 @@ def __init__( ) self.register_buffer("weight_scaler", weight_scaler) - self.is_symmetric = is_symmetric - if not is_symmetric: + self.is_symmetric_weight = quant_config.is_symmetric_weight + + if not self.is_symmetric_weight: zero_point = torch.ones( (out_features,), dtype=torch.bfloat16, device=device ) @@ -96,7 +100,12 @@ def __init__( assert not bias, "Quantized Linear doesn't support bias." - self.n_bit = n_bit + # Number of bits of weight tensor + self.n_bit = quant_config.num_bits_weight + + # Quantize activation + self.quantize_activation = quant_config.enable_activation_quantization + # Flag to enable dequantize weight first, then do matmul. Useful for debugging. self.run_fake_quantize = False @@ -115,23 +124,40 @@ def quantize_weight_from_nn_linear(self, weight): self.in_features, ), f"Got unexpected weight of shape {weight.shape}, expected weight shape ({self.out_features}, {self.in_features})." 
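# A minimal sketch (illustrative only, not the library's quantize_tensor, whose call
# appears on the next line) of what per-channel symmetric int8 quantization computes
# for a (out_features, in_features) weight: one scale per output channel, values
# rounded into [-127, 127]. The function name below is hypothetical.
import torch

def per_channel_symmetric_quantize(weight: torch.Tensor, n_bit: int = 8):
  qmax = 2 ** (n_bit - 1) - 1                      # 127 for int8, 7 for int4
  scale = weight.abs().amax(dim=1, keepdim=True) / qmax
  scale = scale.clamp(min=1e-5)                    # avoid division by zero
  w_q = torch.clamp(torch.round(weight / scale), -qmax, qmax).to(torch.int8)
  return w_q, scale.squeeze(-1)                    # int8 weight, per-row scale

# The weight-only matmul then becomes torch.mul(F.linear(x, w_q.to(x.dtype)), scale),
# which is the non-activation-quantized path in forward() below.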
w_q, scale, zp = quantize_tensor( - weight, (1,), self.n_bit, self.is_symmetric, block_size=-1 + weight, (1,), self.n_bit, self.is_symmetric_weight, block_size=-1 ) w_dq = dequantize_tensor(w_q, scale, zp) self._load_quantized_weights(w_q, scale, zp) def forward(self, inputs): if not self.run_fake_quantize: - if self.is_symmetric: - return torch.mul(F.linear(inputs, self.weight), self.weight_scaler) + if self.quantize_activation: + inputs, act_s, _ = quantize_tensor(inputs, reduce_axis=(2,)) + if not self.quantize_activation: + result = F.linear(inputs, self.weight) else: - out = torch.mul(F.linear(inputs, self.weight), self.weight_scaler) + # We have to call jax because we need to do dot(int8, int8)->int32. + # This semantic cannot be represented in torch. The inferred output dtype + # will be int8 in torch, causing the dot result to overflow. + result = torchjax.call_jax( + jax.lax.dot_general, + inputs, + self.weight, + (((2,), (1)), ((), ())), + None, + jnp.int32.dtype, + ) + result = result * self.weight_scaler + if self.quantize_activation: + result = result * act_s + if not self.is_symmetric_weight: zp_out = torch.einsum("...c,z->...z", inputs, self.zero_point) - return out - zp_out + result = result - zp_out + return result else: # Fake quantization, debugging purpose. scaler = self.weight_scaler.unsqueeze(-1) - if not self.is_symmetric: + if not self.is_symmetric_weight: zero_point = self.zero_point.unsqueeze(-1) / scaler else: zero_point = None @@ -149,10 +175,7 @@ def __init__( out_features, bias=False, device=None, - is_symmetric=True, - use_dot_general=False, - block_size=128, - n_bit=8, + quant_config=QuantizationConfig(), ): super().__init__() self.in_features = in_features @@ -160,21 +183,29 @@ def __init__( # Use dot general instead of einsum # Use dot general is slow now. - self.use_dot_general = use_dot_general + self.use_dot_general = False # Flatten einsum operands to 3D. XLA was slow if operands are 4D. But it's fixed now. # Same perf as non flattened one now. self.flatten = False - self.block_size = block_size - n_blocks = in_features // block_size + self.block_size = quant_config.block_size_weight + n_blocks = in_features // self.block_size + + assert ( + not quant_config.enable_activation_quantization + ), "Activation quantization not supported for blockwise quantized matmul." if self.use_dot_general: weight = torch.ones( - (n_blocks, out_features, block_size), dtype=torch.int8, device=device + (n_blocks, out_features, self.block_size), + dtype=torch.int8, + device=device, ) else: weight = torch.ones( - (n_blocks, block_size, out_features), dtype=torch.int8, device=device + (n_blocks, self.block_size, out_features), + dtype=torch.int8, + device=device, ) self.register_buffer("weight", weight) @@ -183,8 +214,8 @@ def __init__( ) self.register_buffer("weight_scaler", weight_scaler) - self.is_symmetric = is_symmetric - if not self.is_symmetric: + self.is_symmetric_weight = quant_config.is_symmetric_weight + if not self.is_symmetric_weight: zero_point = torch.ones( (n_blocks, out_features), dtype=torch.bfloat16, device=device ) @@ -192,7 +223,11 @@ def __init__( else: self.register_buffer("zero_point", None) - self.n_bit = n_bit + self.n_bit = quant_config.num_bits_weight + + # Quantize activation + self.quantize_activation = quant_config.enable_activation_quantization + # Flag to enable dequantize weight first, then do matmul. Useful for debugging. 
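# The per-channel forward above drops into JAX for the quantized-activation path
# because it needs an int8 x int8 contraction that accumulates in int32. A minimal
# pure-JAX sketch of that point (values chosen so an int8 accumulator would overflow):
import jax
import jax.numpy as jnp

x = jnp.array([[100, 100]], dtype=jnp.int8)        # quantized activations
w = jnp.array([[100], [100]], dtype=jnp.int8)      # quantized weights
y = jax.lax.dot_general(
    x,
    w,
    dimension_numbers=(((1,), (0,)), ((), ())),    # contract x's last dim with w's first
    preferred_element_type=jnp.int32,              # accumulate and return int32
)
# y == [[20000]]; the int32 result is then rescaled, mirroring the
# result * self.weight_scaler and result * act_s steps in the code above.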
self.run_fake_quantize = False @@ -211,112 +246,37 @@ def quantize_weight_from_nn_linear(self, weight): self.in_features, ), f"Unexpected weight shape ({self.out_features}, {self.in_features})." w_q, scale, zp = quantize_tensor( - weight, (1,), self.n_bit, self.is_symmetric, self.block_size + weight, (1,), self.n_bit, self.is_symmetric_weight, self.block_size ) w_dq = dequantize_tensor(w_q, scale, zp) - print("check qweight cosine dist: ", _calc_cosine_dist(weight, w_dq)) - # breakpoint() self._load_quantized_weights(w_q, scale, zp) - @staticmethod - def blockwise_jax_kernel(inputs, weight, weight_scaler, zero_point): - """Blockwise Matmul kernel impl in JAX using einsum""" - weight = weight.astype(jnp.int8) - block_size = weight.shape[1] - inputs_shape = inputs.shape - inputs_new_shape = inputs_shape[:-1] + ( - inputs_shape[-1] // block_size, - block_size, - ) - inputs = inputs.reshape(inputs_new_shape) - out = jnp.einsum("scz,bdsc->bdsz", weight, inputs) - out = jnp.einsum("bdsz,sz->bdz", out, weight_scaler) - if zero_point is not None: - zp_out = jnp.einsum("bdsc,sz->bdz", inputs, zero_point) - out = out - zp_out - return out - - @staticmethod - def blockwise_jax_kernel_dot_general( - inputs, weight, weight_scaler, zero_point - ): - """Blockwise Matmul kernel impl in JAX using dot general""" - inputs_shape = inputs.shape - block_size = weight.shape[2] - bs = inputs_shape[0] - inputs_new_shape = inputs_shape[:-1] + ( - inputs_shape[-1] // block_size, - block_size, - ) - inputs = inputs.reshape(inputs_new_shape) - inputs = jax.lax.collapse(inputs, 0, 2) - out = jax.lax.dot_general( - inputs, weight, dimension_numbers=([(2), (2)], [(1), (0)]) - ) - out = jax.lax.dot_general( - out, weight_scaler, dimension_numbers=([(0), (0)], [(2), (1)]) - ) - out = jax.lax.transpose(out, [1, 0]) - out = out.reshape((bs, -1) + out.shape[1:]) - return out - - @staticmethod - def blockwise_jax_kernel_einsum_flatten( - inputs, weight, weight_scaler, zero_point - ): - """Blockwise Matmul kernel impl in JAX using einsum, with operands flattened""" - weight = weight.astype(jnp.int8) - block_size = weight.shape[1] - inputs_shape = inputs.shape - bs = inputs_shape[0] - inputs_new_shape = inputs_shape[:-1] + ( - inputs_shape[-1] // block_size, - block_size, - ) - inputs = inputs.reshape(inputs_new_shape) - inputs = jax.lax.collapse(inputs, 0, 2) - out = jnp.einsum("scz,bsc->bsz", weight, inputs) - out = jnp.einsum("bsz,sz->bz", out, weight_scaler) - out = out.reshape((bs, -1) + out.shape[1:]) - return out - def forward(self, inputs): if not self.run_fake_quantize: - if self.use_dot_general: + if self.use_dot_general or self.flatten: assert ( self.zero_point is None - ), "Blockwise quantized linear doesn't support zero_point in dot_general implementation." - return torchjax.call_jax( - WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_dot_general, - inputs, - self.weight, - self.weight_scaler, - self.zero_point, - ) - if self.flatten: - assert ( - self.zero_point is None - ), "Blockwise quantized linear doesn't support zero_point in einsum (flattened) implementation." - return torchjax.call_jax( - WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_einsum_flatten, - inputs, - self.weight, - self.weight_scaler, - self.zero_point, - ) - else: - return torchjax.call_jax( - WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel, - inputs, - self.weight, - self.weight_scaler, - self.zero_point, - ) + ), "Blockwise quantized linear doesn't support zero_point in dot_general or einsum flattened implementation." 
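# A numerical sketch (pure JAX/NumPy, not part of the change) of what the
# blockwise_jax_kernel selected below computes: the contraction dim is split into
# (n_blocks, block_size) with one scale per (block, out_channel), and the two einsums
# reproduce a plain matmul against the dequantized weight.
import jax.numpy as jnp
import numpy as np

s, c, z, b, d = 4, 8, 16, 2, 3                     # n_blocks, block_size, out, batch, seq
rng = np.random.default_rng(0)
w_q = rng.integers(-127, 128, size=(s, c, z)).astype(np.int8)
scale = rng.random((s, z)).astype(np.float32)
x = rng.random((b, d, s * c)).astype(np.float32)

xb = x.reshape(b, d, s, c)                          # split activations into blocks
out = jnp.einsum("scz,bdsc->bdsz", jnp.asarray(w_q), jnp.asarray(xb))
out = jnp.einsum("bdsz,sz->bdz", out, jnp.asarray(scale))   # apply per-block scales

w_dq = (w_q.astype(np.float32) * scale[:, None, :]).reshape(s * c, z)
ref = x @ w_dq                                      # reference: plain matmul
assert np.allclose(np.asarray(out), ref, rtol=1e-3, atol=1e-3)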
+ blockwise_matmul_kernel = ( + blockwise_jax_kernel + if not self.use_dot_general and not self.flatten + else blockwise_jax_kernel_dot_general + if self.use_dot_general + else blockwise_jax_kernel_einsum_flatten + ) + result = torchjax.call_jax( + blockwise_matmul_kernel, + inputs, + self.weight, + self.weight_scaler, + self.zero_point, + ) + return result else: # Fake quantization, debugging purpose. weight = self.weight.permute(2, 0, 1).to(torch.bfloat16) scaler = self.weight_scaler.unsqueeze(-1).transpose(1, 0) - if not self.is_symmetric: + if not self.is_symmetric_weight: zero_point = self.zero_point.unsqueeze(-1).transpose(1, 0) / scaler else: zero_point = None @@ -554,12 +514,16 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): self.hidden_size = hidden_size LinearLayer = get_quantized_linear_layer(env.quant_config) + linear_kwargs = {} + if LinearLayer != torch.nn.Linear: + linear_kwargs = {"quant_config": env.quant_config} self.wo = LinearLayer( n_heads * self.head_dim, hidden_size, bias=False, device=device, + **linear_kwargs, ) Kernel = ( @@ -578,6 +542,7 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): (n_heads + 2 * self.n_kv_heads) * self.head_dim, bias=False, device=device, + **linear_kwargs, ) else: self.wq = LinearLayer( @@ -585,18 +550,21 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): n_heads * self.head_dim, bias=False, device=device, + **linear_kwargs, ) self.wk = LinearLayer( hidden_size, self.n_kv_heads * self.head_dim, bias=False, device=device, + **linear_kwargs, ) self.wv = LinearLayer( hidden_size, self.n_kv_heads * self.head_dim, bias=False, device=device, + **linear_kwargs, ) def load_hook(self, state_dict, prefix, *args): diff --git a/jetstream_pt/quantize.py b/jetstream_pt/quantize.py index 30514c33..0e0663cf 100644 --- a/jetstream_pt/quantize.py +++ b/jetstream_pt/quantize.py @@ -14,6 +14,8 @@ from typing import Tuple, Union +import jax +import jax.numpy as jnp import torch EPS = 1e-5 @@ -95,3 +97,63 @@ def load_q_weight_helper(w_q, scale, zp=None, block_size=-1): zp = (zp * scale).transpose(1, 0).squeeze(-1).to(torch.bfloat16) scale = scale.transpose(1, 0).squeeze(-1).to(torch.bfloat16) return w_q, scale, zp + + +def blockwise_jax_kernel(inputs, weight, weight_scaler, zero_point): + """Blockwise Matmul kernel impl in JAX using einsum""" + weight = weight.astype(jnp.int8) + block_size = weight.shape[1] + inputs_shape = inputs.shape + inputs_new_shape = inputs_shape[:-1] + ( + inputs_shape[-1] // block_size, + block_size, + ) + inputs = inputs.reshape(inputs_new_shape) + out = jnp.einsum("scz,bdsc->bdsz", weight, inputs) + out = jnp.einsum("bdsz,sz->bdz", out, weight_scaler) + if zero_point is not None: + zp_out = jnp.einsum("bdsc,sz->bdz", inputs, zero_point) + out = out - zp_out + return out + + +def blockwise_jax_kernel_dot_general(inputs, weight, weight_scaler, zero_point): + """Blockwise Matmul kernel impl in JAX using dot general""" + inputs_shape = inputs.shape + block_size = weight.shape[2] + bs = inputs_shape[0] + inputs_new_shape = inputs_shape[:-1] + ( + inputs_shape[-1] // block_size, + block_size, + ) + inputs = inputs.reshape(inputs_new_shape) + inputs = jax.lax.collapse(inputs, 0, 2) + out = jax.lax.dot_general( + inputs, weight, dimension_numbers=([(2), (2)], [(1), (0)]) + ) + out = jax.lax.dot_general( + out, weight_scaler, dimension_numbers=([(0), (0)], [(2), (1)]) + ) + out = jax.lax.transpose(out, [1, 0]) + out = out.reshape((bs, -1) + 
out.shape[1:]) + return out + + +def blockwise_jax_kernel_einsum_flatten( + inputs, weight, weight_scaler, zero_point +): + """Blockwise Matmul kernel impl in JAX using einsum, with operands flattened""" + weight = weight.astype(jnp.int8) + block_size = weight.shape[1] + inputs_shape = inputs.shape + bs = inputs_shape[0] + inputs_new_shape = inputs_shape[:-1] + ( + inputs_shape[-1] // block_size, + block_size, + ) + inputs = inputs.reshape(inputs_new_shape) + inputs = jax.lax.collapse(inputs, 0, 2) + out = jnp.einsum("scz,bsc->bsz", weight, inputs) + out = jnp.einsum("bsz,sz->bz", out, weight_scaler) + out = out.reshape((bs, -1) + out.shape[1:]) + return out diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index 73d8e07e..1072dad9 100644 --- a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -97,29 +97,37 @@ def __init__( if env.quant_config.enable_weight_quantization else torch.nn.Linear ) + linear_kwargs = {} + if Linear != torch.nn.Linear: + linear_kwargs = {"quant_config": env.quant_config} + self.wq = Linear( hidden_size, num_heads * self.head_dim, bias=False, device=device, + **linear_kwargs, ) self.wk = Linear( hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, + **linear_kwargs, ) self.wv = Linear( hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, + **linear_kwargs, ) self.o_proj = Linear( self.num_heads * self.head_dim, self.hidden_size, bias=False, device=device, + **linear_kwargs, ) Kernel = ( @@ -227,14 +235,30 @@ def __init__( if env.quant_config.enable_weight_quantization else torch.nn.Linear ) + linear_kwargs = {} + if Linear != torch.nn.Linear: + linear_kwargs = {"quant_config": env.quant_config} + self.gate_proj = Linear( - hidden_size, intermediate_size, bias=False, device=device + hidden_size, + intermediate_size, + bias=False, + device=device, + **linear_kwargs, ) self.up_proj = Linear( - hidden_size, intermediate_size, bias=False, device=device + hidden_size, + intermediate_size, + bias=False, + device=device, + **linear_kwargs, ) self.down_proj = Linear( - intermediate_size, hidden_size, bias=False, device=device + intermediate_size, + hidden_size, + bias=False, + device=device, + **linear_kwargs, ) def forward(self, x): diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 124df690..c081b3cf 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -41,24 +41,30 @@ def __init__( hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) LinearLayer = get_quantized_linear_layer(env.quant_config) + linear_kwargs = {} + if LinearLayer != torch.nn.Linear: + linear_kwargs["quant_config"] = env.quant_config self.w1 = LinearLayer( dim, hidden_dim, bias=False, device=device, + **linear_kwargs, ) self.w2 = LinearLayer( hidden_dim, dim, bias=False, device=device, + **linear_kwargs, ) self.w3 = LinearLayer( dim, hidden_dim, bias=False, device=device, + **linear_kwargs, ) def forward(self, x): @@ -171,12 +177,16 @@ def __init__( self.norm = RMSNorm(params.dim, eps=params.norm_eps, device=params.device) LinearLayer = get_quantized_linear_layer(env.quant_config) + linear_kwargs = {} + if LinearLayer != torch.nn.Linear: + linear_kwargs["quant_config"] = env.quant_config self.output = LinearLayer( params.dim, params.vocab_size, bias=False, device=params.device, + **linear_kwargs, ) # TODO what to do 
with this freqs_cis = precompute_freqs_cis( diff --git a/tests/test_quantization.py b/tests/test_quantization.py index 98eb26a3..e2f2764e 100644 --- a/tests/test_quantization.py +++ b/tests/test_quantization.py @@ -22,6 +22,7 @@ import torch_xla2 from jax.experimental import mesh_utils from jetstream_pt import cache_manager, layers, quantize, torchjax +from jetstream_pt.environment import QuantizationConfig from jetstream_pt.layers import ( WeightOnlyBlockwiseQuantizedLinear, WeightOnlyPerChannelQuantizedLinear, @@ -46,6 +47,20 @@ def _calc_cosine_dist(self, x, y): y = y.flatten().to(torch.float32) return (torch.dot(x, y) / (x.norm() * y.norm())).item() + def _nn_linear_run_and_compare( + self, + nn_linear, + qlinear_layer, + arg, + ): + torch_result = nn_linear(arg) + qlinear_layer.quantize_weight_from_nn_linear(nn_linear.weight) + result = helpers.call_xla_model( + qlinear_layer, qlinear_layer.state_dict(), arg + ) + diff = result - torch_result + return result, torch_result, diff + def _print_diff(self, w, w_dq): print("Print diff:") print(" diff: ", w - w_dq) @@ -128,13 +143,12 @@ def quantize_dequantize_weight(w, n_bit): w_q_asym, s_asym, zp_asym = quantize_tensor( w, (1,), n_bit=n_bit, symmetric=False ) - # print(f"w_q_asym {w_q_asym}, s_asym {s_asym}, zp_asym {zp_asym}") w_dq_asym = dequantize_tensor(w_q_asym, s_asym, zp_asym) - # print(f"w_dq_asym {w_dq_asym}") - # self._print_diff(w, w_dq) - # self._print_diff(w, w_dq_asym) # Asymmetric is more accurate than symmetric. - self.assertLess((w - w_dq_asym).norm(), (w - w_dq).norm()) + self.assertLess( + (w - w_dq_asym).norm(), + (w - w_dq).norm(), + ) # Blockwise quant. w_block_q, s_block, _ = quantize_tensor( w, (1,), n_bit=n_bit, symmetric=True, block_size=2 @@ -154,31 +168,19 @@ def quantize_dequantize_weight(w, n_bit): # Blockwise asymmetric is more accurate than blockwise symmetric. self.assertLess((w - w_block_asym_dq).norm(), (w - w_block_dq).norm()) - w = torch.randn(2, 8) + w = ( + torch.randn(2, 8) + 2 + ) # Add a bias to normal dist to test asymmetric quant. for bit in [4, 8]: with self.subTest(bit=bit): quantize_dequantize_weight(w, bit) - def test_quant_linear(self): + def test_weight_only_quant(self): out_features = 2048 in_features = 2048 block_size = 128 - @torch.no_grad() - def run_and_compare( - nn_linear, - qlinear_layer, - arg, - ): - torch_result = nn_linear(arg) - qlinear_layer.quantize_weight_from_nn_linear(nn_linear.weight) - result = helpers.call_xla_model( - qlinear_layer, qlinear_layer.state_dict(), arg - ) - diff = result - torch_result - return result, torch_result, diff - arg = torch.randn(2, 16, in_features).to(torch.bfloat16) nn_linear = torch.nn.Linear( in_features, out_features, bias=False, dtype=torch.bfloat16 @@ -187,32 +189,38 @@ def run_and_compare( per_channel_q_linear = WeightOnlyPerChannelQuantizedLinear( in_features, out_features ) - res, torch_res, per_channel_diff = run_and_compare( + res, torch_res, per_channel_diff = self._nn_linear_run_and_compare( nn_linear, per_channel_q_linear, arg ) self.assertTrue(torch.allclose(res, torch_res, atol=2)) block_q_linear = WeightOnlyBlockwiseQuantizedLinear( in_features, out_features ) - res, torch_res, block_diff = run_and_compare(nn_linear, block_q_linear, arg) + res, torch_res, block_diff = self._nn_linear_run_and_compare( + nn_linear, block_q_linear, arg + ) # self.assertTrue(torch.allclose(res, torch_res, atol=1.5)) # Block quant is more accurate than per_channel quant. 
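# A sketch of why the test below shifts the weights (torch.randn(2, 8) + 2): asymmetric
# quantization keeps a per-channel zero point, so a distribution that is not centered at
# zero still spans the full integer range. This uses one common convention; the exact
# rounding and zero-point bookkeeping in quantize_tensor may differ.
import torch

def asymmetric_quantize_per_channel(w: torch.Tensor, n_bit: int = 8):
  qmin, qmax = -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1
  wmin = w.min(dim=1, keepdim=True).values
  wmax = w.max(dim=1, keepdim=True).values
  scale = ((wmax - wmin) / (qmax - qmin)).clamp(min=1e-5)
  zero_point = qmin - torch.round(wmin / scale)
  w_q = torch.clamp(torch.round(w / scale) + zero_point, qmin, qmax)
  w_dq = (w_q - zero_point) * scale                 # dequantize
  return w_q.to(torch.int8), scale, zero_point, w_dq

w = torch.randn(2, 8) + 2                           # shifted distribution, as in the test
_, _, _, w_dq = asymmetric_quantize_per_channel(w, n_bit=4)
# (w - w_dq).norm() is typically smaller than the symmetric-quantization error here,
# which is what the assertLess checks in quantize_dequantize_weight express.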
self.assertLess(block_diff.norm(), per_channel_diff.norm()) # Test asymmetric quant + quant_config = QuantizationConfig(is_symmetric_weight=False) per_channel_q_linear = WeightOnlyPerChannelQuantizedLinear( - in_features, out_features, is_symmetric=False + in_features, out_features, quant_config=quant_config ) - res, torch_res, per_channel_diff2 = run_and_compare( + res, torch_res, per_channel_diff2 = self._nn_linear_run_and_compare( nn_linear, per_channel_q_linear, arg ) # self._print_diff(res, torch_res) self.assertTrue(torch.allclose(res, torch_res, atol=2)) + quant_config = QuantizationConfig( + is_symmetric_weight=False, is_blockwise_weight=True + ) block_q_linear = WeightOnlyBlockwiseQuantizedLinear( - in_features, out_features, is_symmetric=False + in_features, out_features, quant_config=quant_config ) # block_q_linear.run_fake_quantize = True - res, torch_res, block_diff2 = run_and_compare( + res, torch_res, block_diff2 = self._nn_linear_run_and_compare( nn_linear, block_q_linear, arg ) # self._print_diff(res, torch_res) @@ -271,6 +279,28 @@ def shard_and_lower(f, layer, state_dict_jax, input, shardings): self.assertFalse("all-to-all" in opt_hlo) self.assertFalse("all-reduce-scatter" in opt_hlo) + def test_activation_quant_per_channel(self): + + out_features = 8 + in_features = 4 + block_size = 128 + + arg = torch.randn(2, 1, in_features).to(torch.bfloat16) + nn_linear = torch.nn.Linear( + in_features, out_features, bias=False, dtype=torch.bfloat16 + ) + quant_config = QuantizationConfig( + enable_weight_quantization=True, + enable_activation_quantization=True, + ) + per_channel_q_linear = WeightOnlyPerChannelQuantizedLinear( + in_features, out_features, quant_config=quant_config + ) + res, torch_res, _ = self._nn_linear_run_and_compare( + nn_linear, per_channel_q_linear, arg + ) + self.assertGreater(self._calc_cosine_dist(res, torch_res), 0.9999) + if __name__ == "__main__": unittest.main()
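For reference, a minimal sketch of how the new knobs are consumed downstream, based on the create_quantization_config_from_flags mapping and the LinearLayer/linear_kwargs pattern repeated in this diff. The import path for get_quantized_linear_layer and the layer sizes are assumptions for illustration.

import torch
from jetstream_pt.environment import QuantizationConfig
from jetstream_pt.layers import get_quantized_linear_layer

# int8 per-channel weights plus quantized activations, as built when
# --quantize_weights, --quantize_activation and --quantize_type=int8_per_channel are set.
quant_config = QuantizationConfig(
    enable_weight_quantization=True,
    num_bits_weight=8,
    is_blockwise_weight=False,
    enable_activation_quantization=True,
)

# Same selection pattern the model code above uses: only the quantized layer classes
# accept quant_config, so it is passed through **linear_kwargs.
LinearLayer = get_quantized_linear_layer(quant_config)
linear_kwargs = {}
if LinearLayer != torch.nn.Linear:
  linear_kwargs = {"quant_config": quant_config}
proj = LinearLayer(4096, 4096, bias=False, device="cpu", **linear_kwargs)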