
Refactor QModuleMixin and Calibration and fix streamline bug #249

Merged
merged 9 commits on Jul 19, 2024
52 changes: 38 additions & 14 deletions optimum/quanto/calibrate.py
@@ -84,6 +84,7 @@ def __init__(self, *args, momentum: float = 0.9, streamline=True, debug=False, *
self.streamline = streamline
if streamline:
self.modules_qactivations = {}
self.streamline_hooks = {}
self.debug = debug

def __torch_function__(self, func, types, args=(), kwargs=None):
@@ -112,8 +113,15 @@ def __exit__(self, exc_type, exc_val, exc_tb):
super().__exit__(exc_type, exc_val, exc_tb)
self.pre_handle.remove()
self.post_handle.remove()
if self.streamline:
for handle in self.streamline_hooks.values():
handle.remove()

def calibrate_input(self, module: torch.nn.Module, input, momentum: float = 0.9):
"""Calibrate a module input scale

This is registered as a global hook that is called before any module forward pre-hook.
"""
if isinstance(module, QModuleMixin) and module.activation_qtype is not None:
input = input[0]
if isinstance(input, QBytesTensor):
@@ -123,36 +131,37 @@ def calibrate_input(self, module: torch.nn.Module, input, momentum: float = 0.9)
# Evaluate the best scale
input_scale = absmax_scale(input, module.activation_qtype)
module.input_scale = _updated_scale(module.input_scale, input_scale, momentum)
if self.streamline and module not in self.streamline_hooks:
# Add a hook to tag the module outputs (after the module quantization hook in QModuleMixin)
self.streamline_hooks[module] = module.register_forward_hook(self.tag_outputs)
return input

def calibrate_output(
self,
module: torch.nn.Module,
input,
output,
input: torch.Tensor,
output: torch.Tensor,
):
"""Calibrate a module output scale

This is registered as a global hook that is called before any module forward hook.

When the module is a QModuleMixin, its outputs are not quantized yet because they
are only quantized in the QModuleMixin.quantize_output forward hook.
"""
if isinstance(module, (QModuleMixin)) and module.activation_qtype is not None:
# Re-evaluate raw module output
qoutput = module.qforward(input[0])
if isinstance(qoutput, QBytesTensor):
qoutput = qoutput.dequantize()
# Evaluate the optimal scale per-tensor and update output scale
output_scale = absmax_scale(qoutput, module.activation_qtype, axis=None)
output_scale = absmax_scale(output, module.activation_qtype, axis=None)
module.output_scale = _updated_scale(module.output_scale, output_scale, self.momentum)
# Re-evaluate output with the correct output scale
output = module.forward(input[0])
if isinstance(output, QBytesTensor):
# Mark the outputs as generated by this module
output.src_module = module
return output
else:
if self.streamline:
for name, child in module.named_children():
if isinstance(child, QModuleMixin) and child.activation_qtype is not None:
qactivations_required = self.modules_qactivations.get(child, False)
if not qactivations_required:
# Disable activations for this child as its outputs are only consumed by incompatible functions.
child.activation_qtype = None
# Disable output quantization for this child as its outputs are only consumed by incompatible functions.
child.disable_output_quantization()
if self.debug:
for name, child in module.named_children():
if isinstance(child, QModuleMixin):
@@ -163,3 +172,18 @@ def calibrate_output(
else:
trace += f" quantized to {child.activation_qtype} with scale {child.output_scale}."
print(trace)

def tag_outputs(
self,
module: torch.nn.Module,
input: torch.Tensor,
output: torch.Tensor,
):
"""Mark outputs as generated by a module

This is a module forward hook that runs after the QModuleMixin.quantize_output forward hook.

This is useful in streamline mode to identify the module that generated a specific QTensor.
"""
output.src_module = module
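
For orientation, here is a hedged sketch of how these calibration hooks are typically driven. The toy model, sample batches, and qint8 choices below are illustrative assumptions, not part of this diff: the Calibration context manager installs calibrate_input and calibrate_output as global hooks, attaches tag_outputs per module in streamline mode, and removes everything on exit.

```python
import torch
from optimum.quanto import Calibration, qint8, quantize

# Illustrative float model and calibration samples (assumptions, not from this PR).
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))
samples = [torch.randn(8, 64) for _ in range(4)]

# Quantize weights and activations, then calibrate the activation scales.
quantize(model, weights=qint8, activations=qint8)
with torch.no_grad(), Calibration(momentum=0.9, streamline=True):
    # calibrate_input / calibrate_output run as global hooks around every module call;
    # in streamline mode, tag_outputs is also attached to each quantized module so that
    # children whose outputs only feed incompatible functions get output quantization disabled.
    for batch in samples:
        model(batch)
# On exit, Calibration removes its global hooks and the per-module streamline hooks.
```

Modules created through QLinear.from_module / QConv2d.from_module, as in the tests below, go through the same path.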
8 changes: 2 additions & 6 deletions optimum/quanto/nn/qconv2d.py
@@ -16,7 +16,7 @@

import torch

from ..tensor import Optimizer, QBytesTensor, qtype, quantize_activation
from ..tensor import Optimizer, qtype
from .qmodule import QModuleMixin, register_qmodule


@@ -46,9 +46,5 @@ def qcreate(
optimizer=optimizer,
)

def qforward(self, input: torch.Tensor) -> torch.Tensor:
if self.activation_qtype is not None and not isinstance(input, QBytesTensor):
# Quantize tensor to be able to take advantage of accelerated conv2d
input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale)
# We always use quantized weights
def forward(self, input: torch.Tensor) -> torch.Tensor:
return self._conv_forward(input, self.qweight, self.bias)
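
Worth noting in this file: QConv2d.qcreate does not request input quantization, so after the refactor a QConv2d accepts plain float inputs and only its outputs are quantized, by the QModuleMixin forward hook. A minimal sketch mirroring the updated test (the shapes and the qint8 choice are assumptions for illustration):

```python
import torch
from optimum.quanto import Calibration, QBytesTensor, qint8
from optimum.quanto.nn import QConv2d

conv = torch.nn.Conv2d(3, 8, kernel_size=3)
qconv = QConv2d.from_module(conv, weights=qint8, activations=qint8)

# Plain float inputs: QConv2d no longer quantizes them in forward.
x = torch.randn(1, 3, 32, 32)
with torch.no_grad(), Calibration():
    y = qconv(x)

# Only the output is quantized, by the QModuleMixin.quantize_output forward hook.
assert isinstance(y, QBytesTensor)
```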
2 changes: 1 addition & 1 deletion optimum/quanto/nn/qlayernorm.py
@@ -47,5 +47,5 @@ def qcreate(
optimizer=None, # We never quantize QLayerNorm weights
)

def qforward(self, input: torch.Tensor) -> torch.Tensor:
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
9 changes: 3 additions & 6 deletions optimum/quanto/nn/qlinear.py
@@ -16,7 +16,7 @@

import torch

from ..tensor import Optimizer, QBytesTensor, qtype, quantize_activation
from ..tensor import Optimizer, qtype
from .qmodule import QModuleMixin, register_qmodule


@@ -38,11 +38,8 @@ def qcreate(
weights=weights,
activations=activations,
optimizer=optimizer,
quantize_input=True,
)

def qforward(self, input: torch.Tensor) -> torch.Tensor:
if self.activation_qtype is not None and not isinstance(input, QBytesTensor):
# Quantize activations to be able to take advantage of accelerated matmul
input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale)
# We always use quantized weights
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.linear(input, self.qweight, bias=self.bias)
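
Unlike QConv2d, QLinear passes quantize_input=True to QModuleMixin.__init__, so when activations are enabled a forward pre-hook quantizes float inputs (or validates already-quantized ones) before the matmul, and a forward hook quantizes the output. A hedged sketch under the same illustrative assumptions as above:

```python
import torch
from optimum.quanto import Calibration, QBytesTensor, qint8
from optimum.quanto.nn import QLinear

linear = torch.nn.Linear(64, 64)
qlinear = QLinear.from_module(linear, weights=qint8, activations=qint8)

x = torch.randn(8, 64)
with torch.no_grad(), Calibration():
    # quantize_input (pre-hook) quantizes x with input_scale before forward runs;
    # quantize_output (forward hook) quantizes the result with output_scale afterwards.
    y = qlinear(x)

assert isinstance(y, QBytesTensor)
```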
48 changes: 31 additions & 17 deletions optimum/quanto/nn/qmodule.py
@@ -44,7 +44,7 @@ def register_qmodule(module_cls):
The QModule must implement two abstract methods:

- qcreate: class method to instantiate a new QModule from an nn.Module, without copying its weights,
- qforward: instance method for quantized inference.
- forward: instance method for quantized inference.

The code to register a new module looks like:

@@ -61,7 +61,7 @@ def qcreate(cls,
optimizer: Optional[Optimizer] = None):
...

def qforward(self, input: torch.Tensor) -> torch.Tensor:
def forward(self, input: torch.Tensor) -> torch.Tensor:
...
```

@@ -94,6 +94,7 @@ def __init__(
weights: Optional[Union[qtype, str]] = None,
activations: Optional[Union[qtype, str]] = None,
optimizer: Optional[Optimizer] = None,
quantize_input: Optional[bool] = False,
**kwargs,
):
# The tests below are meant to help people writing their own quantized Module class
@@ -122,10 +123,19 @@ def __init__(
if in_features % group_size == 0:
self.weight_group_size = group_size
self.activation_qtype = activations
self._quantize_hooks = {}
if activations is not None:
if quantize_input:
self._quantize_hooks["input"] = self.register_forward_pre_hook(self.quantize_input)
self._quantize_hooks["output"] = self.register_forward_hook(self.quantize_output)
self.optimizer = optimizer
self.register_buffer("input_scale", torch.ones(()))
self.register_buffer("output_scale", torch.ones(()))

def disable_output_quantization(self):
if "output" in self._quantize_hooks:
self._quantize_hooks["output"].remove()

def _save_to_state_dict(self, destination, prefix, keep_vars):
if self.weight_qtype is None or not self.frozen:
# Save standard weight Tensor
@@ -230,21 +240,25 @@ def qweight(self):
def qforward(self, input: torch.Tensor) -> torch.Tensor:
raise NotImplementedError

def forward(self, input: torch.Tensor) -> torch.Tensor:
def maybe_requantize(t, scale):
if t.qtype == self.activation_qtype and t.axis is None:
return t
return quantize_activation(t.dequantize(), qtype=self.activation_qtype, scale=scale)

if self.activation_qtype is not None and isinstance(input, QBytesTensor):
input = maybe_requantize(input, self.input_scale)
output = self.qforward(input)
if self.activation_qtype is not None:
if isinstance(output, QBytesTensor):
output = maybe_requantize(output, self.output_scale)
else:
output = quantize_activation(output, qtype=self.activation_qtype, scale=self.output_scale)
return output
def quantize_input(self, module: torch.nn.Module, input: torch.Tensor) -> torch.Tensor:
input = input[0]
if isinstance(input, QBytesTensor):
if input.qtype != self.activation_qtype:
raise ValueError(
"Models with heterogeneous quantized activations are not supported:"
f" expected {self.activation_qtype.name} input but got {input.qtype.name} instead."
)
else:
input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale)
return input

def quantize_output(
self,
module: torch.nn.Module,
input: torch.Tensor,
output: torch.Tensor,
) -> torch.Tensor:
return quantize_activation(output, qtype=self.activation_qtype, scale=self.output_scale)

def freeze(self):
qweight = self.qweight
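The whole refactor rests on stock nn.Module hook semantics: a forward pre-hook may return replacement inputs, a forward hook may return a replacement output, and removing a handle (as disable_output_quantization does) detaches the hook. A self-contained sketch with no quanto types, just to make that mechanism explicit:

```python
import torch

linear = torch.nn.Linear(4, 4)

# A forward pre-hook may return replacement inputs for the wrapped forward call.
pre = linear.register_forward_pre_hook(lambda module, inputs: (inputs[0] * 2.0,))

# A forward hook may return a replacement output.
post = linear.register_forward_hook(lambda module, inputs, output: output + 1.0)

y = linear(torch.randn(2, 4))  # input doubled before the matmul, output shifted after it

# Removing the handle detaches the hook, which is all disable_output_quantization does.
post.remove()
```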
6 changes: 4 additions & 2 deletions test/nn/test_calibrate.py
@@ -23,7 +23,9 @@
def _test_calibrate_qlinear(batch_size, tokens, embeddings, use_bias, activations, device):
linear = torch.nn.Linear(embeddings, embeddings, bias=use_bias).to(device)
qlinear = QLinear.from_module(linear, weights=qint8, activations=activations)
qinputs = random_qactivation((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device)
qinputs = random_qactivation(
(batch_size, tokens, embeddings), qtype=activations, dtype=torch.float32, device=device
)
# Run a first inference without Calibration
with torch.no_grad():
qout = qlinear(qinputs)
@@ -73,7 +75,7 @@ def forward(self, input):
model = TwoLinearModel(embeddings).to(device)
model.linear1 = QLinear.from_module(model.linear1, weights=qint8, activations=activations)
model.linear2 = QLinear.from_module(model.linear2, weights=qint8, activations=activations)
qinputs = random_qactivation((1,) + (tokens, embeddings), dtype=torch.float32).to(device)
qinputs = random_qactivation((1, tokens, embeddings), qtype=activations, dtype=torch.float32, device=device)
with torch.no_grad(), Calibration():
qout = model(qinputs)
assert torch.any(model.linear1.input_scale != 1)
8 changes: 4 additions & 4 deletions test/nn/test_qconv2d.py
@@ -14,7 +14,7 @@

import pytest
import torch
from helpers import assert_similar, random_qactivation
from helpers import assert_similar, random_qactivation, random_tensor

from optimum.quanto import Calibration, QBytesTensor, qfloat8_e4m3fn, qfloat8_e5m2, qint4, qint8
from optimum.quanto.nn import QConv2d
@@ -24,16 +24,16 @@ def _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights
conv2d = torch.nn.Conv2d(img_shape[0], out_channels, kernel_size=3, bias=use_bias).to(dtype).to(device)
qconv2d = QConv2d.from_module(conv2d, weights=weights, activations=activations)
assert qconv2d.qweight.qtype == weights
qinputs = random_qactivation((batch_size,) + img_shape, dtype=dtype).to(device)
inputs = random_tensor((batch_size,) + img_shape, dtype=dtype, device=device)
# Run an inference with Calibration to get the correct output dtype
with torch.no_grad(), Calibration():
qout = qconv2d(qinputs)
qout = qconv2d(inputs)
if activations is not None:
assert isinstance(qout, QBytesTensor)
assert qout.qtype == activations
# Align weights with quantized conv2d weights for comparison
conv2d.weight = torch.nn.Parameter(qconv2d.qweight.dequantize())
out = conv2d(qinputs.dequantize())
out = conv2d(inputs)
# We need to increase atol for float16 dtype
dtype_atol = {torch.float32: 1e-4, torch.float16: 1e-3}[dtype]
# We also need to increase atol for float8 qtypes
16 changes: 10 additions & 6 deletions test/nn/test_qlinear.py
@@ -38,8 +38,12 @@ def _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, act
linear = torch.nn.Linear(embeddings, embeddings, bias=use_bias).to(dtype).to(device)
qlinear = QLinear.from_module(linear, weights=weights, activations=activations)
assert qlinear.qweight.qtype == weights
qinputs = random_qactivation((batch_size,) + (tokens, embeddings), dtype=dtype).to(device)
inputs = qinputs.dequantize()
input_shape = (batch_size, tokens, embeddings)
if activations is not None:
qinputs = random_qactivation(input_shape, qtype=activations, dtype=dtype).to(device)
inputs = qinputs.dequantize()
else:
inputs = random_tensor(input_shape, dtype=dtype, device=device)
# Run an inference with Calibration to get the correct output dtype
context = nullcontext if activations is None else Calibration
with torch.no_grad(), context():
@@ -79,8 +83,8 @@ def test_quantize_linear_float32_activations_int8(batch_size, tokens, embeddings
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"])
@pytest.mark.parametrize(
"activations",
[qfloat8_e5m2, qfloat8_e4m3fn],
ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3"],
[qfloat8_e4m3fn],
ids=["a-qfloat8-e4m3"],
)
@pytest.mark.skip_device("mps")
def test_quantize_linear_float16_activations_float8(
@@ -132,8 +136,8 @@ def test_qlinear_gradient(tokens, embeddings, activations, weights, device):
qlinear = QLinear.from_module(linear, weights=weights, activations=activations)
assert qlinear.weight.requires_grad is True
assert qlinear.bias.requires_grad is True
# Run an inference with quantized inputs
inputs = random_tensor((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device)
# Run an inference with dynamically quantized inputs
inputs = random_tensor((batch_size, tokens, embeddings), dtype=torch.float32, device=device)
inputs.requires_grad = True
qinputs = quantize_activation(inputs, qtype=qint8, scale=absmax_scale(inputs, qint8))
qout = qlinear(qinputs)