Add activation quantization support to per-channel quantized linear layers #105

Merged
merged 6 commits into AI-Hypercomputer:main from lsiyuan/act-quant on Jun 12, 2024

Conversation

@lsy323 (Collaborator) commented on May 25, 2024

Activation quantization is only supported with per-channel quantized models.

Enable activation quantization with per-channel quantization by passing the flag --quantize_activation=True.

The activation is quantized to int8, and then an int8 x int8 matmul is performed. We need to call lax.dot_general because torch matmul ops do not let us control the output dtype (it defaults to int8, which easily overflows). We use int32 as the accumulation dtype to avoid overflow.
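
For illustration, here is a minimal, hedged sketch of that int8 x int8 matmul with int32 accumulation using lax.dot_general and preferred_element_type. The function name, tensor shapes (activation [batch, in_features], weight [out_features, in_features]), and scale handling are assumptions for this example, not the PR's exact kernel:

import jax.numpy as jnp
from jax import lax

def int8_matmul_int32_acc(x_int8, w_int8, x_scale, w_scale):
  # x_int8: [batch, in_features] int8 activations, x_scale: [batch, 1] float
  # w_int8: [out_features, in_features] int8 weights, w_scale: [out_features] float
  # Contract the last dim of x with the last dim of w and accumulate in int32
  # so the int8 products cannot overflow.
  acc = lax.dot_general(
      x_int8,
      w_int8,
      dimension_numbers=(((1,), (1,)), ((), ())),
      preferred_element_type=jnp.int32,
  )  # [batch, out_features], int32
  # Dequantize: apply the per-example activation scale and the per-channel weight scale.
  return acc.astype(jnp.float32) * x_scale * w_scale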

Correctness is verified in unit tests and with the llama/gemma models. We currently get the same performance on the 7B int8 per-channel model at BS=32; an in-depth investigation is needed to understand the performance impact.

@lsy323 marked this pull request as draft May 25, 2024 02:47
@lsy323 marked this pull request as ready for review May 25, 2024 05:04
else:
out = torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
result = torchjax.call_jax(

Collaborator commented:

Is it a bit confusing that when quantize_activation is not enabled, inputs and self.weight are torch tensors, but when it is enabled they are JAX arrays? At least we need more detailed comments here.

@lsy323 (Collaborator, author) commented on Jun 11, 2024:

Here we have to call JAX because we need to do dot(int8, int8) -> int32. This semantic cannot be represented in torch right now: in torch, the inferred output dtype of two int8 operands is int8, which causes the dot result to overflow. dot_general in JAX supports specifying the output dtype, hence we use it here.

Let me add a comment to make it clear
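
As a hedged illustration of the dtype point above (example only, not the PR's code): JAX lets the caller request a wider accumulation dtype for the dot, which is the property torch matmul does not currently expose.

import jax.numpy as jnp
from jax import lax

a = jnp.full((4, 8), 100, dtype=jnp.int8)
b = jnp.full((8, 16), 100, dtype=jnp.int8)

# Requesting int32 output lets the int8 x int8 dot accumulate safely; without it
# the result dtype follows the int8 operands and the values overflow.
out = lax.dot_general(
    a, b, dimension_numbers=(((1,), (0,)), ((), ())),
    preferred_element_type=jnp.int32)
assert out.dtype == jnp.int32  # 100 * 100 * 8 = 80000, which fits in int32 but not int8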

use_dot_general=False,
block_size=128,
n_bit=8,
quant_config=QuantizationConfig(),
):
super().__init__()
self.in_features = in_features
self.out_features = out_features

# Use dot_general instead of einsum.
# Using dot_general is slow for now.

Collaborator commented:

Known torch xla2 issue? Is there a bug tracker for this?

@lsy323 (Collaborator, author) commented:

This should be an XLA issue, I think; using dot_general and einsum should have the same semantics.
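
To make the "same semantics" point concrete, a small illustrative example (not from the PR): the einsum contraction and the equivalent dot_general return identical values, so any speed difference comes from how XLA lowers and fuses the two forms.

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0).reshape(2, 3)
w = jnp.arange(12.0).reshape(3, 4)

# The same contraction expressed two ways; the results match element for element.
via_einsum = jnp.einsum("bi,io->bo", x, w)
via_dot_general = lax.dot_general(x, w, dimension_numbers=(((1,), (0,)), ((), ())))
assert jnp.allclose(via_einsum, via_dot_general)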

self.zero_point,
)
), "Blockwise quantized linear doesn't support zero_point in dot_general or einsum flattened implementation."
blockwise_matmul_kernel = (

Collaborator commented:

Nit: maybe the following is a little simpler:
blockwise_matmul_kernel = (
    blockwise_jax_kernel_dot_general
    if self.use_dot_general
    else blockwise_jax_kernel_einsum_flatten
    if self.flatten
    else blockwise_jax_kernel
)

@lsy323 (Collaborator, author) commented:

Thanks, this is cleaner; let me update.

return out


def blockwise_jax_kernel_dot_general(inputs, weight, weight_scaler, zero_point):

Collaborator commented:

Since torch xla2 has fixed the torch einsum lowering, do we still need this?

@lsy323 (Collaborator, author) commented:

No, we don't need to call JAX for it, thanks for the heads up. Since I'm just moving the existing kernel implementations into this new file here, I will switch to torch in a follow-up PR.
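
For context, a hedged sketch of what a torch-only blockwise kernel of this kind could look like once the einsum lowering is relied on. The layout here (weight [n_blocks, block_size, out_features], weight_scaler [n_blocks, out_features], symmetric quantization with no zero_point) is assumed for illustration and may not match the PR's actual shapes:

import torch

def blockwise_matmul_torch(inputs, weight, weight_scaler):
  # inputs:        [batch, in_features] float, with in_features = n_blocks * block_size
  # weight:        [n_blocks, block_size, out_features] int8
  # weight_scaler: [n_blocks, out_features] float, one scale per block and output channel
  n_blocks, block_size, _ = weight.shape
  x = inputs.reshape(inputs.shape[0], n_blocks, block_size)
  # Per-block partial matmuls, then rescale each block and sum over blocks.
  partial = torch.einsum("bnc,nco->bno", x, weight.to(inputs.dtype))
  return torch.einsum("bno,no->bo", partial, weight_scaler)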

)
w_dq = dequantize_tensor(w_q, scale, zp)
self._load_quantized_weights(w_q, scale, zp)

def forward(self, inputs):
if not self.run_fake_quantize:
if self.is_symmetric:
return torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
if self.quantize_activation:

Collaborator commented:

Can we move this code to else?

@lsy323 (Collaborator, author) commented:

No, we cannot move this code into the else branch. It is an extra step for activation quantization, not an alternative path.



def blockwise_jax_kernel_einsum_flatten(
inputs, weight, weight_scaler, zero_point

Collaborator commented:

Do you handle the case where zero_point is not None?

Commits:
- add debug print to debug
- remove print, add bias to asym quant tests
- lint
@lsy323 merged commit 8a125b6 into AI-Hypercomputer:main on Jun 12, 2024
4 checks passed
@lsy323 deleted the lsiyuan/act-quant branch on June 12, 2024 22:29