Add activation quantization support to per-channel quantized linear layers #105

Merged: 6 commits, Jun 12, 2024
8 changes: 8 additions & 0 deletions jetstream_pt/config.py
@@ -42,6 +42,11 @@

# Quantization related flags
flags.DEFINE_bool("quantize_weights", False, "weight quantization")
flags.DEFINE_bool(
"quantize_activation",
False,
"Quantize Q,K,V projection and FeedForward activation.",
)
flags.DEFINE_string(
"quantize_type", "int8_per_channel", "Type of quantization."
)
@@ -90,6 +95,9 @@ def create_quantization_config_from_flags():
config.enable_weight_quantization = True
config.num_bits_weight = 8 if "int8" in quantize_type else 4
config.is_blockwise_weight = "blockwise" in quantize_type

config.enable_activation_quantization = FLAGS.quantize_activation

config.enable_kv_quantization = FLAGS.quantize_kv_cache
return config
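
For context, a hypothetical driver sketch (not part of this PR) showing how the new flag is intended to reach the config object. The flag names and create_quantization_config_from_flags come from the hunk above; the script itself, the example flag values, and the print statements are illustrative assumptions.

import sys
from absl import app
from jetstream_pt.config import create_quantization_config_from_flags

def main(argv):
    del argv  # unused
    # Run with: --quantize_weights --quantize_activation --quantize_type=int8_per_channel
    quant_config = create_quantization_config_from_flags()
    print(quant_config.enable_weight_quantization)      # True for the flags above
    print(quant_config.enable_activation_quantization)  # follows --quantize_activation

if __name__ == "__main__":
    app.run(main)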

4 changes: 4 additions & 0 deletions jetstream_pt/environment.py
@@ -32,6 +32,10 @@ class QuantizationConfig:
enable_weight_quantization: bool = False
num_bits_weight: int = 8
is_blockwise_weight: bool = False
block_size_weight: int = 128
is_symmetric_weight: bool = True

enable_activation_quantization: bool = False

enable_kv_quantization: bool = False
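
For reference, a minimal construction sketch using the fields shown above, assuming QuantizationConfig is a plain dataclass as the field defaults suggest; the chosen values are illustrative.

from jetstream_pt.environment import QuantizationConfig

# Per-channel int8 weight quantization with activation quantization enabled.
# Note: the blockwise layer below asserts that activation quantization is off,
# so is_blockwise_weight stays False for this combination.
quant_config = QuantizationConfig(
    enable_weight_quantization=True,
    num_bits_weight=8,
    is_blockwise_weight=False,
    is_symmetric_weight=True,
    enable_activation_quantization=True,
)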

202 changes: 85 additions & 117 deletions jetstream_pt/layers.py
@@ -25,10 +25,14 @@
import torch_xla2
from jax import lax
from jetstream_pt import torchjax
from jetstream_pt.environment import QuantizationConfig
from jetstream_pt.quantize import (
dequantize_tensor,
load_q_weight_helper,
quantize_tensor,
blockwise_jax_kernel,
blockwise_jax_kernel_dot_general,
blockwise_jax_kernel_einsum_flatten,
)
from torch import nn
from . import attention_kernel as ak
@@ -68,8 +72,7 @@ def __init__(
out_features,
bias=False,
device=None,
is_symmetric=True,
n_bit=8,
quant_config=QuantizationConfig(),
):
super().__init__()
self.in_features = in_features
@@ -85,8 +88,9 @@ def __init__(
)
self.register_buffer("weight_scaler", weight_scaler)

self.is_symmetric = is_symmetric
if not is_symmetric:
self.is_symmetric_weight = quant_config.is_symmetric_weight

if not self.is_symmetric_weight:
zero_point = torch.ones(
(out_features,), dtype=torch.bfloat16, device=device
)
@@ -96,7 +100,12 @@

assert not bias, "Quantized Linear doesn't support bias."

self.n_bit = n_bit
# Number of bits of weight tensor
self.n_bit = quant_config.num_bits_weight

# Quantize activation
self.quantize_activation = quant_config.enable_activation_quantization

# Flag to enable dequantize weight first, then do matmul. Useful for debugging.
self.run_fake_quantize = False

@@ -115,23 +124,40 @@ def quantize_weight_from_nn_linear(self, weight):
self.in_features,
), f"Got unexpected weight of shape {weight.shape}, expected weight shape ({self.out_features}, {self.in_features})."
w_q, scale, zp = quantize_tensor(
weight, (1,), self.n_bit, self.is_symmetric, block_size=-1
weight, (1,), self.n_bit, self.is_symmetric_weight, block_size=-1
)
w_dq = dequantize_tensor(w_q, scale, zp)
self._load_quantized_weights(w_q, scale, zp)

def forward(self, inputs):
if not self.run_fake_quantize:
if self.is_symmetric:
return torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
if self.quantize_activation:

Review comment (Collaborator): Can we move this code to else?

Author reply: No, we cannot move this code to else; it is an extra step for activation quantization.

inputs, act_s, _ = quantize_tensor(inputs, reduce_axis=(2,))
if not self.quantize_activation:
result = F.linear(inputs, self.weight)
else:
out = torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
# We have to call jax because we need to do dot(int8, int8)->int32.
# This semantic cannot be represented in torch. The inferred output dtype
# will be int8 in torch, causing the dot result to overflow.
result = torchjax.call_jax(

Review comment (Collaborator): Isn't it a bit confusing that when quantize_activation is not enabled, inputs and self.weight are torch tensors, but when it is enabled they are JAX arrays? At the least we need more detailed comments here.

Author reply (lsy323, Jun 11, 2024): Here we have to call JAX because we need dot(int8, int8) -> int32. That semantic cannot be represented in torch today: the inferred output dtype of two int8 operands is int8, so the dot result would overflow. dot_general in JAX supports specifying the output dtype, hence we use it here. Let me add a comment to make it clear.

jax.lax.dot_general,
inputs,
self.weight,
(((2,), (1)), ((), ())),
None,
jnp.int32.dtype,
)
result = result * self.weight_scaler
if self.quantize_activation:
result = result * act_s
if not self.is_symmetric_weight:
zp_out = torch.einsum("...c,z->...z", inputs, self.zero_point)
return out - zp_out
result = result - zp_out
return result
else:
# Fake quantization, debugging purpose.
scaler = self.weight_scaler.unsqueeze(-1)
if not self.is_symmetric:
if not self.is_symmetric_weight:
zero_point = self.zero_point.unsqueeze(-1) / scaler
else:
zero_point = None
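
To make the review discussion above concrete, here is a small standalone sketch, with made-up shapes and values, of the int8 x int8 -> int32 contraction that plain torch matmul cannot express, using jax.lax.dot_general with an explicit accumulation dtype.

import jax
import jax.numpy as jnp

# Illustrative shapes: activations (batch, seq, in_features), weight (out_features, in_features).
x_q = jnp.full((1, 4, 16), 127, dtype=jnp.int8)   # quantized activations
w_q = jnp.full((8, 16), 127, dtype=jnp.int8)      # per-channel quantized weights

out = jax.lax.dot_general(
    x_q,
    w_q,
    dimension_numbers=(((2,), (1,)), ((), ())),   # contract over in_features
    precision=None,
    preferred_element_type=jnp.int32,             # accumulate in int32 to avoid int8 overflow
)
print(out.dtype, out.shape)  # int32 (1, 4, 8); 127 * 127 * 16 would overflow int8 or int16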
@@ -149,32 +175,37 @@ def __init__(
out_features,
bias=False,
device=None,
is_symmetric=True,
use_dot_general=False,
block_size=128,
n_bit=8,
quant_config=QuantizationConfig(),
):
super().__init__()
self.in_features = in_features
self.out_features = out_features

# Use dot general instead of einsum
# Using dot_general is slow now.

Review comment (Collaborator): Known torch_xla2 issue? Is there a bug tracker for this?

Author reply: I think this is an XLA issue; using dot_general and einsum should have the same semantics.

self.use_dot_general = use_dot_general
self.use_dot_general = False
# Flatten einsum operands to 3D. XLA was slow if operands are 4D. But it's fixed now.
# Same perf as non flattened one now.
self.flatten = False

self.block_size = block_size
n_blocks = in_features // block_size
self.block_size = quant_config.block_size_weight
n_blocks = in_features // self.block_size

assert (
not quant_config.enable_activation_quantization
), "Activation quantization not supported for blockwise quantized matmul."

if self.use_dot_general:
weight = torch.ones(
(n_blocks, out_features, block_size), dtype=torch.int8, device=device
(n_blocks, out_features, self.block_size),
dtype=torch.int8,
device=device,
)
else:
weight = torch.ones(
(n_blocks, block_size, out_features), dtype=torch.int8, device=device
(n_blocks, self.block_size, out_features),
dtype=torch.int8,
device=device,
)
self.register_buffer("weight", weight)

@@ -183,16 +214,20 @@ def __init__(
)
self.register_buffer("weight_scaler", weight_scaler)

self.is_symmetric = is_symmetric
if not self.is_symmetric:
self.is_symmetric_weight = quant_config.is_symmetric_weight
if not self.is_symmetric_weight:
zero_point = torch.ones(
(n_blocks, out_features), dtype=torch.bfloat16, device=device
)
self.register_buffer("zero_point", zero_point)
else:
self.register_buffer("zero_point", None)

self.n_bit = n_bit
self.n_bit = quant_config.num_bits_weight

# Quantize activation
self.quantize_activation = quant_config.enable_activation_quantization

# Flag to enable dequantize weight first, then do matmul. Useful for debugging.
self.run_fake_quantize = False

@@ -211,112 +246,37 @@ def quantize_weight_from_nn_linear(self, weight):
self.in_features,
), f"Unexpected weight shape ({self.out_features}, {self.in_features})."
w_q, scale, zp = quantize_tensor(
weight, (1,), self.n_bit, self.is_symmetric, self.block_size
weight, (1,), self.n_bit, self.is_symmetric_weight, self.block_size
)
w_dq = dequantize_tensor(w_q, scale, zp)
print("check qweight cosine dist: ", _calc_cosine_dist(weight, w_dq))
# breakpoint()
self._load_quantized_weights(w_q, scale, zp)

@staticmethod
def blockwise_jax_kernel(inputs, weight, weight_scaler, zero_point):
"""Blockwise Matmul kernel impl in JAX using einsum"""
weight = weight.astype(jnp.int8)
block_size = weight.shape[1]
inputs_shape = inputs.shape
inputs_new_shape = inputs_shape[:-1] + (
inputs_shape[-1] // block_size,
block_size,
)
inputs = inputs.reshape(inputs_new_shape)
out = jnp.einsum("scz,bdsc->bdsz", weight, inputs)
out = jnp.einsum("bdsz,sz->bdz", out, weight_scaler)
if zero_point is not None:
zp_out = jnp.einsum("bdsc,sz->bdz", inputs, zero_point)
out = out - zp_out
return out

@staticmethod
def blockwise_jax_kernel_dot_general(
inputs, weight, weight_scaler, zero_point
):
"""Blockwise Matmul kernel impl in JAX using dot general"""
inputs_shape = inputs.shape
block_size = weight.shape[2]
bs = inputs_shape[0]
inputs_new_shape = inputs_shape[:-1] + (
inputs_shape[-1] // block_size,
block_size,
)
inputs = inputs.reshape(inputs_new_shape)
inputs = jax.lax.collapse(inputs, 0, 2)
out = jax.lax.dot_general(
inputs, weight, dimension_numbers=([(2), (2)], [(1), (0)])
)
out = jax.lax.dot_general(
out, weight_scaler, dimension_numbers=([(0), (0)], [(2), (1)])
)
out = jax.lax.transpose(out, [1, 0])
out = out.reshape((bs, -1) + out.shape[1:])
return out

@staticmethod
def blockwise_jax_kernel_einsum_flatten(
inputs, weight, weight_scaler, zero_point
):
"""Blockwise Matmul kernel impl in JAX using einsum, with operands flattened"""
weight = weight.astype(jnp.int8)
block_size = weight.shape[1]
inputs_shape = inputs.shape
bs = inputs_shape[0]
inputs_new_shape = inputs_shape[:-1] + (
inputs_shape[-1] // block_size,
block_size,
)
inputs = inputs.reshape(inputs_new_shape)
inputs = jax.lax.collapse(inputs, 0, 2)
out = jnp.einsum("scz,bsc->bsz", weight, inputs)
out = jnp.einsum("bsz,sz->bz", out, weight_scaler)
out = out.reshape((bs, -1) + out.shape[1:])
return out
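
As a worked shape example for the blockwise einsum kernel above (now imported from jetstream_pt.quantize, per the import hunk at the top of this file), with made-up dimensions:

import jax.numpy as jnp

# in_features=256, block_size=128 -> n_blocks s=2; out_features z=64; batch b=1, seq d=4.
inputs = jnp.ones((1, 4, 256), dtype=jnp.bfloat16)     # (b, d, in_features)
weight = jnp.ones((2, 128, 64), dtype=jnp.int8)        # (s, c, z): per-block int8 weights
weight_scaler = jnp.ones((2, 64), dtype=jnp.bfloat16)  # (s, z): per-block, per-channel scales

inputs = inputs.reshape(1, 4, 2, 128)                  # split the channel dim into blocks
out = jnp.einsum("scz,bdsc->bdsz", weight, inputs)     # per-block partial matmuls
out = jnp.einsum("bdsz,sz->bdz", out, weight_scaler)   # rescale and sum over blocks
print(out.shape)  # (1, 4, 64)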

def forward(self, inputs):
if not self.run_fake_quantize:
if self.use_dot_general:
if self.use_dot_general or self.flatten:
assert (
self.zero_point is None
), "Blockwise quantized linear doesn't support zero_point in dot_general implementation."
return torchjax.call_jax(
WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_dot_general,
inputs,
self.weight,
self.weight_scaler,
self.zero_point,
)
if self.flatten:
assert (
self.zero_point is None
), "Blockwise quantized linear doesn't support zero_point in einsum (flattened) implementation."
return torchjax.call_jax(
WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_einsum_flatten,
inputs,
self.weight,
self.weight_scaler,
self.zero_point,
)
else:
return torchjax.call_jax(
WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel,
inputs,
self.weight,
self.weight_scaler,
self.zero_point,
)
), "Blockwise quantized linear doesn't support zero_point in dot_general or einsum flattened implementation."
blockwise_matmul_kernel = (

Review comment (Collaborator): nit: maybe the following is a little simpler:

    blockwise_matmul_kernel = (
        blockwise_jax_kernel_dot_general
        if self.use_dot_general
        else blockwise_jax_kernel_einsum_flatten
        if self.flatten
        else blockwise_jax_kernel
    )

Author reply: Thanks, this will be cleaner; let me update.

blockwise_jax_kernel
if not self.use_dot_general and not self.flatten
else blockwise_jax_kernel_dot_general
if self.use_dot_general
else blockwise_jax_kernel_einsum_flatten
)
result = torchjax.call_jax(
blockwise_matmul_kernel,
inputs,
self.weight,
self.weight_scaler,
self.zero_point,
)
return result
else:
# Fake quantization, debugging purpose.
weight = self.weight.permute(2, 0, 1).to(torch.bfloat16)
scaler = self.weight_scaler.unsqueeze(-1).transpose(1, 0)
if not self.is_symmetric:
if not self.is_symmetric_weight:
zero_point = self.zero_point.unsqueeze(-1).transpose(1, 0) / scaler
else:
zero_point = None
@@ -554,12 +514,16 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
self.hidden_size = hidden_size

LinearLayer = get_quantized_linear_layer(env.quant_config)
linear_kwargs = {}
if LinearLayer != torch.nn.Linear:
linear_kwargs = {"quant_config": env.quant_config}

self.wo = LinearLayer(
n_heads * self.head_dim,
hidden_size,
bias=False,
device=device,
**linear_kwargs,
)

Kernel = (
Expand All @@ -578,25 +542,29 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
(n_heads + 2 * self.n_kv_heads) * self.head_dim,
bias=False,
device=device,
**linear_kwargs,
)
else:
self.wq = LinearLayer(
hidden_size,
n_heads * self.head_dim,
bias=False,
device=device,
**linear_kwargs,
)
self.wk = LinearLayer(
hidden_size,
self.n_kv_heads * self.head_dim,
bias=False,
device=device,
**linear_kwargs,
)
self.wv = LinearLayer(
hidden_size,
self.n_kv_heads * self.head_dim,
bias=False,
device=device,
**linear_kwargs,
)

def load_hook(self, state_dict, prefix, *args):