Initial support for 8da4w QAT (#138)
Summary: This commit adds support for QAT, where linear layers
are fake quantized with int8 per token dynamic activations (8da)
and int4 grouped per channel weights (4w). This initial
implementation uses the same module swap approach as 8da4w PTQ
for simplicity and code reuse. In the future, we may wish to
consider migrating both flows to use tensor subclasses for
better composability with other PyTorch features.
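
Conceptually, fake quantization here is a quantize-then-dequantize round trip that stays in floating point, so the network trains against quantization error while gradients keep flowing. A simplified per-tensor sketch (the real ops additionally handle the per-channel-group / per-token reshaping and gradient behavior):

import torch

def fake_quantize(x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor,
                  qmin: int, qmax: int) -> torch.Tensor:
    # Quantize: scale, shift, round, and clamp to the integer range...
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    # ...then dequantize immediately so downstream ops still see floats.
    return (q - zero_point) * scale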

Test Plan:
python test/quantization/test_qat.py -k test_fake_quantize_per_channel_group
python test/quantization/test_qat.py -k test_fake_quantize_per_token
python test/quantization/test_qat.py -k test_qat_8da4w_linear
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer
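
The last test above exercises the end-to-end quantizer flow. A minimal usage sketch based on the API used in these tests (the toy model and training step are illustrative only):

import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Toy model: linear layers whose in_features are divisible by the group size.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 32, bias=False),
    torch.nn.Linear(32, 64, bias=False),
)

# Swap each nn.Linear for a fake-quantized QAT linear
# (int8 per token dynamic activations, int4 grouped per channel weights).
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
model = quantizer.prepare(model)

# Fine-tune as usual; fake quantization runs in the forward pass and
# gradients flow back to the original fp32 weights.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
out = model(torch.randn(8, 64))
out.sum().backward()
optimizer.step()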

Reviewers: jerryzh168, cpuhrsch, HDCharles

Subscribers: jerryzh168, cpuhrsch, HDCharles, supriyar

Tasks: #86
andrewor14 authored Apr 18, 2024
1 parent 2bc1617 commit d3f4a70
Showing 5 changed files with 460 additions and 12 deletions.
177 changes: 177 additions & 0 deletions test/quantization/test_qat.py
@@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# mypy: ignore-errors
# This test takes a long time to run

import copy
import unittest

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torchao.quantization.prototype.qat import (
_choose_qparams_per_token_asymmetric,
fake_quantize_per_channel_group,
fake_quantize_per_token,
)
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3


# TODO: put this in a common test utils file
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, 64).to(torch.float),)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x


class TestQAT(unittest.TestCase):
SEED = 123

def _get_qmin_qmax(self, n_bit: int):
qmin = -(2 ** (n_bit - 1))
qmax = 2 ** (n_bit - 1) - 1
return (qmin, qmax)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
def test_fake_quantize_per_channel_group(self):
n_bit = 4
(qmin, qmax) = self._get_qmin_qmax(n_bit)
group_size = 128

torch.manual_seed(self.SEED)
x = torch.randn(100, 256).requires_grad_()
(s, zp) = get_group_qparams_symmetric(x, n_bit, group_size)
x2 = copy.deepcopy(x)

# fake quant op
out = fake_quantize_per_channel_group(
x, s, zp, qmin, qmax, group_size,
)
out.sum().backward()

# compare against PTQ ops
out_ptq = torch.ops.quantized_decomposed.quantize_per_channel_group(
x2, s, zp, qmin, qmax, torch.int8, group_size,
)
out_ptq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
out_ptq, s, zp, qmin, qmax, torch.int8, group_size, torch.float32,
)
torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
def test_fake_quantize_per_token(self):
(qmin, qmax) = self._get_qmin_qmax(8)

torch.manual_seed(self.SEED)
x = torch.randn(100, 256).requires_grad_()
x2 = copy.deepcopy(x)
# TODO: use torch.ops.aten.quantized_decomposed version instead
(s, zp) = _choose_qparams_per_token_asymmetric(
x,
torch.int8, # not used
)

# fake quant op
out = fake_quantize_per_token(x, s, zp, qmin, qmax)
out.sum().backward()

# compare against PTQ ops
out_ptq = torch.ops.quantized_decomposed.quantize_per_token(
x2, s, zp, qmin, qmax, torch.int8,
)
out_ptq = torch.ops.quantized_decomposed.dequantize_per_token(
out_ptq, s, zp, qmin, qmax, torch.int8, torch.float32,
)
torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)

def _set_ptq_weight(
self,
ptq_linear: "Int8DynActInt4WeightLinear",
fp32_weight: torch.Tensor,
group_size: int,
):
"""
Set the weight to the quantized version of the given fp32 weights,
for making linear outputs comparable with QAT.
"""
n_bit = 4
(qmin, qmax) = self._get_qmin_qmax(n_bit)
(s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
)
ptq_linear.weight = q_weight
ptq_linear.scales = s
ptq_linear.zeros = zp

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
def test_qat_8da4w_linear(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

group_size = 128
torch.manual_seed(self.SEED)
qat_linear = Int8DynActInt4WeightQATLinear(
256, 688, bias=False, groupsize=group_size,
)
ptq_linear = Int8DynActInt4WeightLinear(
256, 688, bias=False, groupsize=group_size,
)

# Force the weights to be the same
self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size)

# Compare linear values
torch.manual_seed(self.SEED)
x = torch.randn(100, 256)
x2 = copy.deepcopy(x)
qat_out = qat_linear(x)
ptq_out = ptq_linear(x2)
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
def test_qat_8da4w_quantizer(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer

group_size = 16
torch.manual_seed(self.SEED)
m = M()
m2 = copy.deepcopy(m)
qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
qat_model = qat_quantizer.prepare(m)
ptq_model = ptq_quantizer.quantize(m2)

# Force the weights to be the same
self._set_ptq_weight(
ptq_model.linear1, qat_model.linear1.weight, group_size,
)
self._set_ptq_weight(
ptq_model.linear2, qat_model.linear2.weight, group_size,
)

# Compare model values
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
qat_out = qat_model(*x)
ptq_out = ptq_model(*x2)
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
4 changes: 2 additions & 2 deletions test/quantization/test_quant_primitives.py
@@ -9,12 +9,12 @@
import unittest
import torch
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3

class TestQuantPrimitives(unittest.TestCase):
SEED = 123

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
def test_get_group_qparams_symmetric(self):
"""
Test that `get_group_qparams_symmetric` produces the exact same scales as
36 changes: 26 additions & 10 deletions torchao/quantization/GPTQ.py
Expand Up @@ -9,7 +9,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Optional, List
from typing import Optional, List, Type

import torch

@@ -1120,21 +1120,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.precision,
)


def replace_linear_8da4w(
module,
groupsize,
padding_allowed,
precision,
scales_precision,
def _replace_linear_8da4w(
module: torch.nn.Module,
groupsize: int,
padding_allowed: bool,
precision: torch.dtype,
scales_precision: torch.dtype,
linear_class: Type[torch.nn.Module],
):
for name, child in module.named_children():
if isinstance(child, nn.Linear):
if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed:
setattr(
module,
name,
Int8DynActInt4WeightLinear(
linear_class(
child.in_features,
child.out_features,
bias=False,
@@ -1144,14 +1144,30 @@ def replace_linear_8da4w(
),
)
else:
replace_linear_8da4w(
_replace_linear_8da4w(
child,
groupsize,
padding_allowed,
precision,
scales_precision,
)

def replace_linear_8da4w(
module: torch.nn.Module,
groupsize: int,
padding_allowed: bool,
precision: torch.dtype,
scales_precision: torch.dtype,
):
_replace_linear_8da4w(
module,
groupsize,
padding_allowed,
precision,
scales_precision,
Int8DynActInt4WeightLinear,
)

class Int8DynActInt4WeightQuantizer(Quantizer):
def __init__(
self,
