Add activation quantization support to per-channel quantized linear layers #105
Conversation
else:
  out = torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
  result = torchjax.call_jax(
Is it a bit confusing that when quantize_activation is not enabled, inputs and self.weight are torch tensors, but when it is enabled they are JAX arrays? At the least, we need more detailed comments here.
Here we have to call JAX because we need to do dot(int8, int8) -> int32. This semantic cannot be represented in torch right now: with two int8 operands, torch infers an int8 output dtype, causing the dot result to overflow. dot_general in JAX supports specifying the output dtype, hence we use it here. Let me add a comment to make it clear.
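For reference, a minimal JAX sketch of the int8 x int8 -> int32 contraction described above (the function name and shapes are illustrative; in the layer this would be invoked through torchjax.call_jax):

```python
import jax
import jax.numpy as jnp


def int8_matmul_int32_accum(x_int8, w_int8):
  """int8 x int8 -> int32 contraction via lax.dot_general."""
  # Contract the last dim of the activations with the last dim of the
  # [out_features, in_features] weight; no batch dimensions.
  dnums = (((x_int8.ndim - 1,), (1,)), ((), ()))
  # preferred_element_type=jnp.int32 requests int32 accumulation, which a
  # plain torch matmul on two int8 operands cannot express today.
  return jax.lax.dot_general(
      x_int8, w_int8, dnums, preferred_element_type=jnp.int32
  )
```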
      use_dot_general=False,
      block_size=128,
      n_bit=8,
      quant_config=QuantizationConfig(),
  ):
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features

    # Use dot_general instead of einsum.
    # Using dot_general is slow now.
Known torch xla2 issue? Is there a bug tracker for this?
This should be an XLA issue, I think; using dot_general and einsum should have the same semantics.
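To illustrate (shapes made up for the example), the same contraction expressed with jnp.einsum and with lax.dot_general, which should be semantically equivalent and lower to the same XLA dot:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((2, 4, 8), dtype=jnp.float32)  # [batch, seq, in_features]
w = jnp.ones((16, 8), dtype=jnp.float32)    # [out_features, in_features]

# Same contraction expressed two ways.
out_einsum = jnp.einsum("bsi,oi->bso", x, w)
out_dot = jax.lax.dot_general(x, w, (((2,), (1,)), ((), ())))

assert out_einsum.shape == out_dot.shape == (2, 4, 16)
assert jnp.allclose(out_einsum, out_dot)
```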
        self.zero_point,
      )
    ), "Blockwise quantized linear doesn't support zero_point in dot_general or einsum flattened implementation."
    blockwise_matmul_kernel = (
nit, maybe the following is a little simpler:
blockwise_matmul_kernel = (
    blockwise_jax_kernel_dot_general
    if self.use_dot_general
    else blockwise_jax_kernel_einsum_flatten
    if self.flatten
    else blockwise_jax_kernel
)
Thanks, this will be cleaner; let me update.
  return out


def blockwise_jax_kernel_dot_general(inputs, weight, weight_scaler, zero_point):
Since torch xla2 has fixed the torch einsum lowering, do we still need this?
No, we don't need to call jax for it. Thanks for the heads up. Since I'm moving the existing kernel implementation to this new file, I will switch to torch in the following PR.
    )
    w_dq = dequantize_tensor(w_q, scale, zp)
    self._load_quantized_weights(w_q, scale, zp)

  def forward(self, inputs):
    if not self.run_fake_quantize:
      if self.is_symmetric:
        return torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
      if self.quantize_activation:
Can we move this code to else?
No, we cannot move this code into the else branch. This is an extra step for activation quantization.


def blockwise_jax_kernel_einsum_flatten(
    inputs, weight, weight_scaler, zero_point
Do you handle zero_point is not None case?
Yes, there is an assertion on the caller side: https://github.com/google/jetstream-pytorch/blob/main/jetstream_pt/layers.py#L297-L299
Activation quantization is only supported with per-channel quantized models. Enable activation quantization together with per-channel quantization by passing the flag --quantize_activation=True.

The activations are quantized to int8 and then an int8 x int8 matmul is performed. We need to call lax.dot_general because with torch matmul ops we cannot control the output dtype (int8 by default, which easily overflows); dot_general lets us use int32 as the accumulation dtype to avoid overflow. Correctness is verified in unit tests and with the llama/gemma models. We currently see the same performance on 7B int8 per-channel at BS=32; an in-depth investigation is needed to understand the performance impact.
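As a rough illustration of that flow, here is a self-contained JAX sketch (the helper names and the per-token scaling choice are assumptions for this example, not the exact implementation in jetstream_pt/layers.py):

```python
import jax
import jax.numpy as jnp


def quantize_activation_per_token(x):
  # Assumed helper: symmetric int8 quantization with a per-token scale.
  scale = jnp.maximum(jnp.max(jnp.abs(x), axis=-1, keepdims=True), 1e-6) / 127.0
  x_int8 = jnp.clip(jnp.round(x / scale), -128, 127).astype(jnp.int8)
  return x_int8, scale


def int8_per_channel_linear(x, w_int8, w_scaler):
  # x: [batch, seq, in] float activations
  # w_int8: [out, in] per-channel quantized weights; w_scaler: [out] scales
  x_int8, x_scale = quantize_activation_per_token(x)
  acc = jax.lax.dot_general(
      x_int8,
      w_int8,
      (((2,), (1,)), ((), ())),
      preferred_element_type=jnp.int32,  # accumulate in int32 to avoid overflow
  )
  # Rescale the int32 accumulator back to the activation dtype.
  return acc.astype(x.dtype) * x_scale * w_scaler
```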