Add activation quantization support to per-channel quantized linear layers #105
@@ -25,10 +25,14 @@
 import torch_xla2
 from jax import lax
 from jetstream_pt import torchjax
+from jetstream_pt.environment import QuantizationConfig
 from jetstream_pt.quantize import (
     dequantize_tensor,
     load_q_weight_helper,
     quantize_tensor,
+    blockwise_jax_kernel,
+    blockwise_jax_kernel_dot_general,
+    blockwise_jax_kernel_einsum_flatten,
 )
 from torch import nn
 from . import attention_kernel as ak

@@ -68,8 +72,7 @@ def __init__(
       out_features,
       bias=False,
       device=None,
-      is_symmetric=True,
-      n_bit=8,
+      quant_config=QuantizationConfig(),
   ):
     super().__init__()
     self.in_features = in_features

@@ -85,8 +88,9 @@ def __init__(
     )
     self.register_buffer("weight_scaler", weight_scaler)

-    self.is_symmetric = is_symmetric
-    if not is_symmetric:
+    self.is_symmetric_weight = quant_config.is_symmetric_weight
+
+    if not self.is_symmetric_weight:
       zero_point = torch.ones(
           (out_features,), dtype=torch.bfloat16, device=device
       )

@@ -96,7 +100,12 @@ def __init__(

     assert not bias, "Quantized Linear doesn't support bias."

-    self.n_bit = n_bit
+    # Number of bits of weight tensor
+    self.n_bit = quant_config.num_bits_weight
+
+    # Quantize activation
+    self.quantize_activation = quant_config.enable_activation_quantization

     # Flag to enable dequantize weight first, then do matmul. Useful for debugging.
     self.run_fake_quantize = False

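For readers following the new constructor signature, here is a minimal usage sketch. It is not part of the PR: the config field names (is_symmetric_weight, num_bits_weight, enable_activation_quantization) come from this diff, but the keyword-argument construction of QuantizationConfig, the import path, and the layer class name below are illustrative assumptions.

    # Illustrative sketch, not part of the PR. Field names follow the diff;
    # keyword construction of QuantizationConfig, the import path, and the
    # layer class name are assumptions made for the example.
    from jetstream_pt.environment import QuantizationConfig
    from jetstream_pt.layers import WeightOnlyPerChannelQuantizedLinear  # assumed location/name

    quant_config = QuantizationConfig(
        is_symmetric_weight=True,             # symmetric per-channel weight quantization
        num_bits_weight=8,                    # 8-bit weights
        enable_activation_quantization=True,  # the feature added by this PR
    )
    layer = WeightOnlyPerChannelQuantizedLinear(
        in_features=4096,
        out_features=4096,
        bias=False,
        device="cpu",
        quant_config=quant_config,
    )
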
@@ -115,23 +124,40 @@ def quantize_weight_from_nn_linear(self, weight):
         self.in_features,
     ), f"Got unexpected weight of shape {weight.shape}, expected weight shape ({self.out_features}, {self.in_features})."
     w_q, scale, zp = quantize_tensor(
-        weight, (1,), self.n_bit, self.is_symmetric, block_size=-1
+        weight, (1,), self.n_bit, self.is_symmetric_weight, block_size=-1
     )
     w_dq = dequantize_tensor(w_q, scale, zp)
     self._load_quantized_weights(w_q, scale, zp)

   def forward(self, inputs):
     if not self.run_fake_quantize:
-      if self.is_symmetric:
-        return torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
+      if self.quantize_activation:
+        inputs, act_s, _ = quantize_tensor(inputs, reduce_axis=(2,))
+      if not self.quantize_activation:
+        result = F.linear(inputs, self.weight)
       else:
-        out = torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
+        # We have to call jax because we need to do dot(int8, int8)->int32.
+        # This semantic cannot be represented in torch. The inferred output dtype
+        # will be int8 in torch, causing the dot result to overflow.
+        result = torchjax.call_jax(
+            jax.lax.dot_general,
+            inputs,
+            self.weight,
+            (((2,), (1)), ((), ())),
+            None,
+            jnp.int32.dtype,
+        )
+      result = result * self.weight_scaler
+      if self.quantize_activation:
+        result = result * act_s
+      if not self.is_symmetric_weight:
         zp_out = torch.einsum("...c,z->...z", inputs, self.zero_point)
-        return out - zp_out
+        result = result - zp_out
+      return result
     else:
       # Fake quantization, debugging purpose.
       scaler = self.weight_scaler.unsqueeze(-1)
-      if not self.is_symmetric:
+      if not self.is_symmetric_weight:
         zero_point = self.zero_point.unsqueeze(-1) / scaler
       else:
         zero_point = None

Review thread on the torchjax.call_jax(...) line:
  Reviewer: Is it a bit confusing that when quantize_activation is not enabled the inputs and self.weight are torch tensors, while when it is enabled they are JAX arrays? At least we need more detailed comments here.
  Author: Here we have to call jax because we need to do dot(int8, int8) -> int32. Let me add a comment to make it clear.

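To illustrate the dot(int8, int8) -> int32 point discussed above, here is a small standalone JAX sketch. It is not part of the PR; the shapes and values are made up, and it uses the keyword form of jax.lax.dot_general rather than the positional form used in the diff.

    # Standalone sketch, not from the PR: int8 x int8 matmul accumulated in int32.
    import jax
    import jax.numpy as jnp

    x = jnp.full((1, 4, 16), 100, dtype=jnp.int8)   # quantized activations (batch, seq, in_features)
    w = jnp.full((8, 16), 100, dtype=jnp.int8)      # per-channel quantized weight (out_features, in_features)

    out = jax.lax.dot_general(
        x,
        w,
        dimension_numbers=(((2,), (1,)), ((), ())),  # contract the in_features axes
        preferred_element_type=jnp.int32,            # accumulate in int32 instead of int8
    )
    print(out.dtype, out.shape)  # int32 (1, 4, 8); each entry is 100 * 100 * 16 = 160000

Each output element here is 160000, which would overflow an int8 (or int16) result type; requesting an int32 output dtype is what the torchjax.call_jax path above achieves by passing jnp.int32.dtype as the preferred element type.
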
@@ -149,32 +175,37 @@ def __init__(
       out_features,
       bias=False,
       device=None,
-      is_symmetric=True,
-      use_dot_general=False,
-      block_size=128,
-      n_bit=8,
+      quant_config=QuantizationConfig(),
   ):
     super().__init__()
     self.in_features = in_features
     self.out_features = out_features

     # Use dot general instead of einsum
+    # Use dot general is slow now.
-    self.use_dot_general = use_dot_general
+    self.use_dot_general = False
     # Flatten einsum operands to 3D. XLA was slow if operands are 4D. But it's fixed now.
     # Same perf as non flattened one now.
     self.flatten = False

-    self.block_size = block_size
-    n_blocks = in_features // block_size
+    self.block_size = quant_config.block_size_weight
+    n_blocks = in_features // self.block_size
+
+    assert (
+        not quant_config.enable_activation_quantization
+    ), "Activation quantization not supported for blockwise quantized matmul."

     if self.use_dot_general:
       weight = torch.ones(
-          (n_blocks, out_features, block_size), dtype=torch.int8, device=device
+          (n_blocks, out_features, self.block_size),
+          dtype=torch.int8,
+          device=device,
       )
     else:
       weight = torch.ones(
-          (n_blocks, block_size, out_features), dtype=torch.int8, device=device
+          (n_blocks, self.block_size, out_features),
+          dtype=torch.int8,
+          device=device,
       )
     self.register_buffer("weight", weight)

Review thread on the "# Use dot general is slow now." comment:
  Reviewer: Known torch_xla2 issue? Is there a bug tracker for this?
  Author: This should be an XLA issue I think, using …

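As a reading aid for the buffer shapes above, here is a small sketch of the blockwise layout; it is not part of the PR, the sizes are made up, and the weight_scaler shape is inferred from the "bdsz,sz->bdz" einsum in blockwise_jax_kernel.

    # Illustrative shapes only; values are placeholders.
    import torch

    in_features, out_features, block_size = 256, 8, 128
    n_blocks = in_features // block_size                     # 2 blocks

    weight = torch.ones(n_blocks, block_size, out_features, dtype=torch.int8)
    weight_scaler = torch.ones(n_blocks, out_features, dtype=torch.bfloat16)  # one scale per (block, out_feature)

    # Activations of shape (..., in_features) are reshaped to (..., n_blocks, block_size)
    # before the einsums "scz,bdsc->bdsz" and "bdsz,sz->bdz" in blockwise_jax_kernel.
    x = torch.randn(2, 4, in_features)
    x_blocked = x.reshape(2, 4, n_blocks, block_size)
    print(x_blocked.shape, weight.shape, weight_scaler.shape)
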
@@ -183,16 +214,20 @@ def __init__(
     )
     self.register_buffer("weight_scaler", weight_scaler)

-    self.is_symmetric = is_symmetric
-    if not self.is_symmetric:
+    self.is_symmetric_weight = quant_config.is_symmetric_weight
+    if not self.is_symmetric_weight:
       zero_point = torch.ones(
           (n_blocks, out_features), dtype=torch.bfloat16, device=device
       )
       self.register_buffer("zero_point", zero_point)
     else:
       self.register_buffer("zero_point", None)

-    self.n_bit = n_bit
+    self.n_bit = quant_config.num_bits_weight
+
+    # Quantize activation
+    self.quantize_activation = quant_config.enable_activation_quantization

     # Flag to enable dequantize weight first, then do matmul. Useful for debugging.
     self.run_fake_quantize = False

@@ -211,112 +246,37 @@ def quantize_weight_from_nn_linear(self, weight):
         self.in_features,
     ), f"Unexpected weight shape ({self.out_features}, {self.in_features})."
     w_q, scale, zp = quantize_tensor(
-        weight, (1,), self.n_bit, self.is_symmetric, self.block_size
+        weight, (1,), self.n_bit, self.is_symmetric_weight, self.block_size
     )
     w_dq = dequantize_tensor(w_q, scale, zp)
     print("check qweight cosine dist: ", _calc_cosine_dist(weight, w_dq))
     # breakpoint()
     self._load_quantized_weights(w_q, scale, zp)

-  @staticmethod
-  def blockwise_jax_kernel(inputs, weight, weight_scaler, zero_point):
-    """Blockwise Matmul kernel impl in JAX using einsum"""
-    weight = weight.astype(jnp.int8)
-    block_size = weight.shape[1]
-    inputs_shape = inputs.shape
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    out = jnp.einsum("scz,bdsc->bdsz", weight, inputs)
-    out = jnp.einsum("bdsz,sz->bdz", out, weight_scaler)
-    if zero_point is not None:
-      zp_out = jnp.einsum("bdsc,sz->bdz", inputs, zero_point)
-      out = out - zp_out
-    return out
-
-  @staticmethod
-  def blockwise_jax_kernel_dot_general(
-      inputs, weight, weight_scaler, zero_point
-  ):
-    """Blockwise Matmul kernel impl in JAX using dot general"""
-    inputs_shape = inputs.shape
-    block_size = weight.shape[2]
-    bs = inputs_shape[0]
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    inputs = jax.lax.collapse(inputs, 0, 2)
-    out = jax.lax.dot_general(
-        inputs, weight, dimension_numbers=([(2), (2)], [(1), (0)])
-    )
-    out = jax.lax.dot_general(
-        out, weight_scaler, dimension_numbers=([(0), (0)], [(2), (1)])
-    )
-    out = jax.lax.transpose(out, [1, 0])
-    out = out.reshape((bs, -1) + out.shape[1:])
-    return out
-
-  @staticmethod
-  def blockwise_jax_kernel_einsum_flatten(
-      inputs, weight, weight_scaler, zero_point
-  ):
-    """Blockwise Matmul kernel impl in JAX using einsum, with operands flattened"""
-    weight = weight.astype(jnp.int8)
-    block_size = weight.shape[1]
-    inputs_shape = inputs.shape
-    bs = inputs_shape[0]
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    inputs = jax.lax.collapse(inputs, 0, 2)
-    out = jnp.einsum("scz,bsc->bsz", weight, inputs)
-    out = jnp.einsum("bsz,sz->bz", out, weight_scaler)
-    out = out.reshape((bs, -1) + out.shape[1:])
-    return out
-
   def forward(self, inputs):
     if not self.run_fake_quantize:
-      if self.use_dot_general:
+      if self.use_dot_general or self.flatten:
         assert (
             self.zero_point is None
-        ), "Blockwise quantized linear doesn't support zero_point in dot_general implementation."
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_dot_general,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
-      if self.flatten:
-        assert (
-            self.zero_point is None
-        ), "Blockwise quantized linear doesn't support zero_point in einsum (flattened) implementation."
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_einsum_flatten,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
-      else:
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
+        ), "Blockwise quantized linear doesn't support zero_point in dot_general or einsum flattened implementation."
+      blockwise_matmul_kernel = (
+          blockwise_jax_kernel
+          if not self.use_dot_general and not self.flatten
+          else blockwise_jax_kernel_dot_general
+          if self.use_dot_general
+          else blockwise_jax_kernel_einsum_flatten
+      )
+      result = torchjax.call_jax(
+          blockwise_matmul_kernel,
+          inputs,
+          self.weight,
+          self.weight_scaler,
+          self.zero_point,
+      )
+      return result
     else:
       # Fake quantization, debugging purpose.
       weight = self.weight.permute(2, 0, 1).to(torch.bfloat16)
       scaler = self.weight_scaler.unsqueeze(-1).transpose(1, 0)
-      if not self.is_symmetric:
+      if not self.is_symmetric_weight:
        zero_point = self.zero_point.unsqueeze(-1).transpose(1, 0) / scaler
      else:
        zero_point = None

Review thread on the blockwise_matmul_kernel selection:
  Reviewer: nit, maybe the following is a little simpler: …
  Author: Thanks, this will be cleaner, let me update.

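The reviewer's suggested snippet does not appear above, so the following is only one possible simplification in the same spirit, not the PR's final code: explicit branches instead of a nested conditional expression, using the kernel names imported at the top of the file.

      # Illustrative alternative, not necessarily what the reviewer proposed.
      if self.use_dot_general:
        blockwise_matmul_kernel = blockwise_jax_kernel_dot_general
      elif self.flatten:
        blockwise_matmul_kernel = blockwise_jax_kernel_einsum_flatten
      else:
        blockwise_matmul_kernel = blockwise_jax_kernel
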
@@ -554,12 +514,16 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
     self.hidden_size = hidden_size

     LinearLayer = get_quantized_linear_layer(env.quant_config)
+    linear_kwargs = {}
+    if LinearLayer != torch.nn.Linear:
+      linear_kwargs = {"quant_config": env.quant_config}

     self.wo = LinearLayer(
         n_heads * self.head_dim,
         hidden_size,
         bias=False,
         device=device,
+        **linear_kwargs,
     )

     Kernel = (

@@ -578,25 +542,29 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
           (n_heads + 2 * self.n_kv_heads) * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
       )
     else:
       self.wq = LinearLayer(
           hidden_size,
           n_heads * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
       )
       self.wk = LinearLayer(
           hidden_size,
           self.n_kv_heads * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
       )
       self.wv = LinearLayer(
           hidden_size,
           self.n_kv_heads * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
       )

   def load_hook(self, state_dict, prefix, *args):

Review thread:
  Reviewer: Can we move this code to else?
  Author: No, we cannot move this code to else. This is an extra step for activation quant.
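Tying this back to the per-channel forward pass: activations are quantized before the matmul path is chosen, and both the weight scale and the activation scale are folded back in afterwards. A standalone numeric sketch of that flow follows; it is not from the PR, and the shapes, values, and the simple max-based quantizer are assumptions made for illustration.

    # Standalone sketch, not from the PR: per-channel weight quantization combined
    # with per-token activation quantization, mirroring the order of operations in
    # the new forward(). Shapes and the max-based quantizer are illustrative.
    import torch

    x = torch.randn(1, 4, 16)   # (batch, seq, in_features)
    w = torch.randn(8, 16)      # (out_features, in_features)

    # Per-channel (per-output-feature) symmetric 8-bit weight quantization.
    w_scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    w_q = (w / w_scale).round().clamp(-128, 127)

    # Per-token activation quantization over the feature axis (reduce_axis=(2,)).
    x_scale = x.abs().amax(dim=2, keepdim=True) / 127.0
    x_q = (x / x_scale).round().clamp(-128, 127)

    # Integer-style matmul, then fold the weight scale and the activation scale back in.
    out = (x_q @ w_q.t()) * w_scale.t() * x_scale
    print(torch.allclose(out, x @ w.t(), atol=0.5))  # approximately recovers the fp matmul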