diff --git a/optimum/quanto/calibrate.py b/optimum/quanto/calibrate.py
index 077a5990..2d5bfc49 100644
--- a/optimum/quanto/calibrate.py
+++ b/optimum/quanto/calibrate.py
@@ -84,6 +84,7 @@ def __init__(self, *args, momentum: float = 0.9, streamline=True, debug=False, *
         self.streamline = streamline
         if streamline:
             self.modules_qactivations = {}
+            self.streamline_hooks = {}
         self.debug = debug
 
     def __torch_function__(self, func, types, args=(), kwargs=None):
@@ -112,8 +113,15 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         super().__exit__(exc_type, exc_val, exc_tb)
         self.pre_handle.remove()
         self.post_handle.remove()
+        if self.streamline:
+            for handle in self.streamline_hooks.values():
+                handle.remove()
 
     def calibrate_input(self, module: torch.nn.Module, input, momentum: float = 0.9):
+        """Calibrate a module input scale
+
+        This is registered as a global hook that is called before any module forward pre hook.
+        """
         if isinstance(module, QModuleMixin) and module.activation_qtype is not None:
             input = input[0]
             if isinstance(input, QBytesTensor):
@@ -123,27 +131,28 @@ def calibrate_input(self, module: torch.nn.Module, input, momentum: float = 0.9)
                 # Evaluate the best scale
                 input_scale = absmax_scale(input, module.activation_qtype)
                 module.input_scale = _updated_scale(module.input_scale, input_scale, momentum)
+            if self.streamline and module not in self.streamline_hooks:
+                # Add a hook to tag the module outputs (after the module quantization hook in QModuleMixin)
+                self.streamline_hooks[module] = module.register_forward_hook(self.tag_outputs)
             return input
 
     def calibrate_output(
         self,
         module: torch.nn.Module,
-        input,
-        output,
+        input: torch.Tensor,
+        output: torch.Tensor,
     ):
+        """Calibrate a module output scale
+
+        This is registered as a global hook that is called before any module forward hook.
+
+        When the module is a QModuleMixin, its outputs are not quantized yet because they
+        are only quantized in the QModuleMixin.quantize_output forward hook.
+        """
         if isinstance(module, (QModuleMixin)) and module.activation_qtype is not None:
-            # Re-evaluate raw module output
-            qoutput = module.qforward(input[0])
-            if isinstance(qoutput, QBytesTensor):
-                qoutput = qoutput.dequantize()
             # Evaluate the optimal scale per-tensor and update output scale
-            output_scale = absmax_scale(qoutput, module.activation_qtype, axis=None)
+            output_scale = absmax_scale(output, module.activation_qtype, axis=None)
             module.output_scale = _updated_scale(module.output_scale, output_scale, self.momentum)
-            # Re-evaluate output with the correct output scale
-            output = module.forward(input[0])
-            if isinstance(output, QBytesTensor):
-                # Mark the outputs as generated by this module
-                output.src_module = module
             return output
         else:
             if self.streamline:
@@ -151,8 +160,8 @@ def calibrate_output(
                     if isinstance(child, QModuleMixin) and child.activation_qtype is not None:
                         qactivations_required = self.modules_qactivations.get(child, False)
                         if not qactivations_required:
-                            # Disable activations for this child as its outputs are only consumed by incompatible functions.
-                            child.activation_qtype = None
+                            # Disable output quantization for this child as its outputs are only consumed by incompatible functions.
+                            child.disable_output_quantization()
             if self.debug:
                 for name, child in module.named_children():
                     if isinstance(child, QModuleMixin):
@@ -163,3 +172,18 @@ def calibrate_output(
                     else:
                         trace += f" quantized to {child.activation_qtype} with scale {child.output_scale}."
                        print(trace)
+
+    def tag_outputs(
+        self,
+        module: torch.nn.Module,
+        input: torch.Tensor,
+        output: torch.Tensor,
+    ):
+        """Mark outputs as generated by a module
+
+        This is called as a module forward hook that is called after the QModuleMixin.quantize_output
+        forward hook.
+
+        This is useful in streamline mode to identify the module that generated a specific QTensor.
+        """
+        output.src_module = module
diff --git a/optimum/quanto/nn/qconv2d.py b/optimum/quanto/nn/qconv2d.py
index fc5bcb0d..f2ddd892 100644
--- a/optimum/quanto/nn/qconv2d.py
+++ b/optimum/quanto/nn/qconv2d.py
@@ -16,7 +16,7 @@
 
 import torch
 
-from ..tensor import Optimizer, QBytesTensor, qtype, quantize_activation
+from ..tensor import Optimizer, qtype
 from .qmodule import QModuleMixin, register_qmodule
 
 
@@ -46,9 +46,5 @@ def qcreate(
             optimizer=optimizer,
         )
 
-    def qforward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.activation_qtype is not None and not isinstance(input, QBytesTensor):
-            # Quantize tensor to be able to take advantage of accelerated conv2d
-            input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale)
-        # We always use quantized weights
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         return self._conv_forward(input, self.qweight, self.bias)
diff --git a/optimum/quanto/nn/qlayernorm.py b/optimum/quanto/nn/qlayernorm.py
index cfa844b0..48bd6c0a 100644
--- a/optimum/quanto/nn/qlayernorm.py
+++ b/optimum/quanto/nn/qlayernorm.py
@@ -47,5 +47,5 @@ def qcreate(
             optimizer=None,  # We never quantize QLayerNorm weights
         )
 
-    def qforward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
diff --git a/optimum/quanto/nn/qlinear.py b/optimum/quanto/nn/qlinear.py
index 92b74cac..fc290ba4 100644
--- a/optimum/quanto/nn/qlinear.py
+++ b/optimum/quanto/nn/qlinear.py
@@ -16,7 +16,7 @@
 
 import torch
 
-from ..tensor import Optimizer, QBytesTensor, qtype, quantize_activation
+from ..tensor import Optimizer, qtype
 from .qmodule import QModuleMixin, register_qmodule
 
 
@@ -38,11 +38,8 @@ def qcreate(
             weights=weights,
             activations=activations,
             optimizer=optimizer,
+            quantize_input=True,
         )
 
-    def qforward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.activation_qtype is not None and not isinstance(input, QBytesTensor):
-            # Quantize activations to be able to take advantage of accelerated matmul
-            input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale)
-        # We always use quantized weights
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         return torch.nn.functional.linear(input, self.qweight, bias=self.bias)
diff --git a/optimum/quanto/nn/qmodule.py b/optimum/quanto/nn/qmodule.py
index d754444e..df84e299 100644
--- a/optimum/quanto/nn/qmodule.py
+++ b/optimum/quanto/nn/qmodule.py
@@ -44,7 +44,7 @@ def register_qmodule(module_cls):
     The QModule must implement two abstract methods:
 
     - qcreate: class method to instantiate a new QModule from an nn.Module, without copying its weights,
-    - qforward: instance method for quantized inference.
+    - forward: instance method for quantized inference.
 
     The code to register a new module looks like:
 
@@ -61,7 +61,7 @@ def qcreate(cls,
                     optimizer: Optional[Optimizer] = None):
             ...
 
-        def qforward(self, input: torch.Tensor) -> torch.Tensor:
+        def forward(self, input: torch.Tensor) -> torch.Tensor:
             ...
 
     ```
@@ -94,6 +94,7 @@ def __init__(
         weights: Optional[Union[qtype, str]] = None,
         activations: Optional[Union[qtype, str]] = None,
         optimizer: Optional[Optimizer] = None,
+        quantize_input: Optional[bool] = False,
         **kwargs,
     ):
         # The tests below are meant to help people writing their own quantized Module class
@@ -122,10 +123,19 @@ def __init__(
                 if in_features % group_size == 0:
                     self.weight_group_size = group_size
         self.activation_qtype = activations
+        self._quantize_hooks = {}
+        if activations is not None:
+            if quantize_input:
+                self._quantize_hooks["input"] = self.register_forward_pre_hook(self.quantize_input)
+            self._quantize_hooks["output"] = self.register_forward_hook(self.quantize_output)
         self.optimizer = optimizer
         self.register_buffer("input_scale", torch.ones(()))
         self.register_buffer("output_scale", torch.ones(()))
 
+    def disable_output_quantization(self):
+        if "output" in self._quantize_hooks:
+            self._quantize_hooks["output"].remove()
+
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         if self.weight_qtype is None or not self.frozen:
             # Save standard weight Tensor
@@ -230,21 +240,25 @@ def qweight(self):
     def qforward(self, input: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        def maybe_requantize(t, scale):
-            if t.qtype == self.activation_qtype and t.axis is None:
-                return t
-            return quantize_activation(t.dequantize(), qtype=self.activation_qtype, scale=scale)
-
-        if self.activation_qtype is not None and isinstance(input, QBytesTensor):
-            input = maybe_requantize(input, self.input_scale)
-        output = self.qforward(input)
-        if self.activation_qtype is not None:
-            if isinstance(output, QBytesTensor):
-                output = maybe_requantize(output, self.output_scale)
-            else:
-                output = quantize_activation(output, qtype=self.activation_qtype, scale=self.output_scale)
-        return output
+    def quantize_input(self, module: torch.nn.Module, input: torch.Tensor) -> torch.Tensor:
+        input = input[0]
+        if isinstance(input, QBytesTensor):
+            if input.qtype != self.activation_qtype:
+                raise ValueError(
+                    "Models with heterogeneous quantized activations are not supported:"
+                    f" expected {self.activation_qtype.name} input but got {input.qtype.name} instead."
+                )
+        else:
+            input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale)
+        return input
+
+    def quantize_output(
+        self,
+        module: torch.nn.Module,
+        input: torch.Tensor,
+        output: torch.Tensor,
+    ) -> torch.Tensor:
+        return quantize_activation(output, qtype=self.activation_qtype, scale=self.output_scale)
 
     def freeze(self):
         qweight = self.qweight
diff --git a/test/nn/test_calibrate.py b/test/nn/test_calibrate.py
index 25167239..12ca9252 100644
--- a/test/nn/test_calibrate.py
+++ b/test/nn/test_calibrate.py
@@ -23,7 +23,9 @@
 def _test_calibrate_qlinear(batch_size, tokens, embeddings, use_bias, activations, device):
     linear = torch.nn.Linear(embeddings, embeddings, bias=use_bias).to(device)
     qlinear = QLinear.from_module(linear, weights=qint8, activations=activations)
-    qinputs = random_qactivation((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device)
+    qinputs = random_qactivation(
+        (batch_size, tokens, embeddings), qtype=activations, dtype=torch.float32, device=device
+    )
     # Run a first inference without Calibration
     with torch.no_grad():
         qout = qlinear(qinputs)
@@ -73,7 +75,7 @@ def forward(self, input):
     model = TwoLinearModel(embeddings).to(device)
     model.linear1 = QLinear.from_module(model.linear1, weights=qint8, activations=activations)
     model.linear2 = QLinear.from_module(model.linear2, weights=qint8, activations=activations)
-    qinputs = random_qactivation((1,) + (tokens, embeddings), dtype=torch.float32).to(device)
+    qinputs = random_qactivation((1, tokens, embeddings), qtype=activations, dtype=torch.float32, device=device)
     with torch.no_grad(), Calibration():
         qout = model(qinputs)
     assert torch.any(model.linear1.input_scale != 1)
diff --git a/test/nn/test_qconv2d.py b/test/nn/test_qconv2d.py
index aed56cd8..7d4273b5 100644
--- a/test/nn/test_qconv2d.py
+++ b/test/nn/test_qconv2d.py
@@ -14,7 +14,7 @@
 import pytest
 import torch
-from helpers import assert_similar, random_qactivation
+from helpers import assert_similar, random_qactivation, random_tensor
 
 from optimum.quanto import Calibration, QBytesTensor, qfloat8_e4m3fn, qfloat8_e5m2, qint4, qint8
 from optimum.quanto.nn import QConv2d
 
@@ -24,16 +24,16 @@ def _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights
     conv2d = torch.nn.Conv2d(img_shape[0], out_channels, kernel_size=3, bias=use_bias).to(dtype).to(device)
     qconv2d = QConv2d.from_module(conv2d, weights=weights, activations=activations)
     assert qconv2d.qweight.qtype == weights
-    qinputs = random_qactivation((batch_size,) + img_shape, dtype=dtype).to(device)
+    inputs = random_tensor((batch_size,) + img_shape, dtype=dtype, device=device)
     # Run an inference with Calibration to get the correct output dtype
     with torch.no_grad(), Calibration():
-        qout = qconv2d(qinputs)
+        qout = qconv2d(inputs)
     if activations is not None:
         assert isinstance(qout, QBytesTensor)
         assert qout.qtype == activations
     # Align weights with quantized linear weights for comparison
     conv2d.weight = torch.nn.Parameter(qconv2d.qweight.dequantize())
-    out = conv2d(qinputs.dequantize())
+    out = conv2d(inputs)
     # We need to increase atol for float16 dtype
     dtype_atol = {torch.float32: 1e-4, torch.float16: 1e-3}[dtype]
     # We also need to increase atol for float8 itypes
diff --git a/test/nn/test_qlinear.py b/test/nn/test_qlinear.py
index 2ff5bffe..86c82ae5 100644
--- a/test/nn/test_qlinear.py
+++ b/test/nn/test_qlinear.py
@@ -38,8 +38,12 @@ def _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, act
     linear = torch.nn.Linear(embeddings, embeddings, bias=use_bias).to(dtype).to(device)
     qlinear = QLinear.from_module(linear, weights=weights, activations=activations)
     assert qlinear.qweight.qtype == weights
-    qinputs = random_qactivation((batch_size,) + (tokens, embeddings), dtype=dtype).to(device)
-    inputs = qinputs.dequantize()
+    input_shape = (batch_size, tokens, embeddings)
+    if activations is not None:
+        qinputs = random_qactivation(input_shape, qtype=activations, dtype=dtype).to(device)
+        inputs = qinputs.dequantize()
+    else:
+        inputs = random_tensor(input_shape, dtype=dtype, device=device)
     # Run an inference with Calibration to get the correct output dtype
     context = nullcontext if activations is None else Calibration
     with torch.no_grad(), context():
@@ -79,8 +83,8 @@ def test_quantize_linear_float32_activations_int8(batch_size, tokens, embeddings
 @pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"])
 @pytest.mark.parametrize(
     "activations",
-    [qfloat8_e5m2, qfloat8_e4m3fn],
-    ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3"],
+    [qfloat8_e4m3fn],
+    ids=["a-qfloat8-e4m3"],
 )
 @pytest.mark.skip_device("mps")
 def test_quantize_linear_float16_activations_float8(
@@ -132,8 +136,8 @@ def test_qlinear_gradient(tokens, embeddings, activations, weights, device):
     qlinear = QLinear.from_module(linear, weights=weights, activations=activations)
     assert qlinear.weight.requires_grad is True
     assert qlinear.bias.requires_grad is True
-    # Run an inference with quantized inputs
-    inputs = random_tensor((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device)
+    # Run an inference with dynamically quantized inputs
+    inputs = random_tensor((batch_size, tokens, embeddings), dtype=torch.float32, device=device)
     inputs.requires_grad = True
     qinputs = quantize_activation(inputs, qtype=qint8, scale=absmax_scale(inputs, qint8))
     qout = qlinear(qinputs)
diff --git a/test/quantize/test_quantize_mlp.py b/test/quantize/test_quantize_mlp.py
index 8518ba44..80275da9 100644
--- a/test/quantize/test_quantize_mlp.py
+++ b/test/quantize/test_quantize_mlp.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from contextlib import nullcontext
 import pytest
 import torch
-from helpers import assert_similar, get_device_memory, random_qactivation
+from helpers import assert_similar, get_device_memory, random_tensor
 
 from optimum.quanto import (
     AbsmaxOptimizer,
@@ -24,12 +25,14 @@
     QBytesTensor,
     QLinear,
     QTensor,
+    absmax_scale,
     freeze,
     qfloat8_e4m3fn,
     qfloat8_e5m2,
     qint4,
     qint8,
     quantize,
+    quantize_activation,
 )
 
 
@@ -56,24 +59,25 @@ def check_mlp(model, frozen):
     assert isinstance(model.output_layer.weight, QTensor)
 
 
-def get_outputs(model, batch_size, input_features, device):
-    qinputs = random_qactivation((batch_size, input_features), dtype=torch.float32).to(device)
-    return model(qinputs)
-
-
-def _test_quantize_mlp(weights, activations, optimizer, frozen, device):
+def _test_quantize_mlp(weights, activations, optimizer, frozen, device, atol=1e-6):
     model = MLP(32, 10, 128).to(device)
-    output = get_outputs(model, 1, 32, device)
+    inputs = random_tensor((1, 32), dtype=torch.float32, device=device)
+    output = model(inputs)
     quantize(model, weights=weights, activations=activations, optimizer=optimizer)
     if frozen:
         freeze(model)
     check_mlp(model, frozen)
-    with Calibration():
-        qoutput = get_outputs(model, 1, 32, device)
+    if activations is not None:
+        inputs = quantize_activation(inputs, qtype=activations, scale=absmax_scale(inputs))
+        context = Calibration
+    else:
+        context = nullcontext
+    with context():
+        qoutput = model(inputs)
     if activations is not None:
         assert isinstance(qoutput, QBytesTensor)
     # Don't expect more than a 0.99 similarity
-    assert_similar(output, qoutput, atol=1e-2)
+    assert_similar(output, qoutput, atol=atol)
 
 
 @pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@@ -86,19 +90,20 @@ def test_quantize_mlp_weights_only(weights, frozen, device):
 @pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
 @pytest.mark.skip_device("mps")
 def test_quantize_mlp_int8_activations(weights, frozen, device):
-    _test_quantize_mlp(weights, qint8, None, frozen, device)
+    _test_quantize_mlp(weights, qint8, None, frozen, device, atol=1e-3)
 
 
 @pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
 @pytest.mark.parametrize(
     "activations",
-    [None, qint8, qfloat8_e5m2, qfloat8_e4m3fn],
-    ids=["a-float", "a-qint8", "a-qfloat8-e5m2", "a-qfloat8-e4m3"],
+    [qfloat8_e5m2, qfloat8_e4m3fn],
+    ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3"],
 )
 @pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
 @pytest.mark.skip_device("mps")
 def test_quantize_mlp_float8_activations(weights, activations, frozen, device):
-    _test_quantize_mlp(weights, activations, None, frozen, device)
+    atol = {qfloat8_e4m3fn: 1e-3, qfloat8_e5m2: 1e-2}[activations]
+    _test_quantize_mlp(weights, activations, None, frozen, device, atol=atol)
 
 
 @pytest.mark.skip_device("cpu")
@@ -126,7 +131,8 @@ def test_quantized_mlp_device_memory(weights, dtype, weights_only, device):
 )
 @pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
 def test_quantize_mlp_weights_only_optimizers(weights, optimizer, frozen, device):
-    _test_quantize_mlp(weights, None, optimizer, frozen, device)
+    atol = {qint4: 1e-4, qint8: 1e-6}[weights]
+    _test_quantize_mlp(weights, None, optimizer, frozen, device, atol=atol)
 
 
 @pytest.mark.parametrize(
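
For reference, below is a minimal usage sketch of the flow this patch moves onto module hooks: `quantize()` attaches the new `quantize_input` forward pre-hook and `quantize_output` forward hook on quantized modules when activations are quantized, and the `Calibration()` context registers the global `calibrate_input` / `calibrate_output` hooks that tune `input_scale` / `output_scale`. The toy `MLP` class, tensor shapes, and sample data are illustrative only; the `quantize`, `Calibration`, `freeze`, `qint8`, and `QBytesTensor` names are the public API already used by the updated tests.

```python
import torch

from optimum.quanto import Calibration, QBytesTensor, freeze, qint8, quantize


# Illustrative toy model (not part of this patch).
class MLP(torch.nn.Module):
    def __init__(self, features: int):
        super().__init__()
        self.input_layer = torch.nn.Linear(features, features)
        self.output_layer = torch.nn.Linear(features, features)

    def forward(self, x):
        return self.output_layer(torch.nn.functional.relu(self.input_layer(x)))


model = MLP(32)

# Replace Linear layers with QLinear; with activations=qint8 each quantized
# module registers the quantize_input pre-hook and quantize_output hook.
quantize(model, weights=qint8, activations=qint8)

# Calibration installs the global calibrate_input/calibrate_output hooks that
# update the per-module input and output scales from sample data; streamline
# mode may call disable_output_quantization() on modules whose outputs only
# feed incompatible functions.
with torch.no_grad(), Calibration():
    model(torch.randn(8, 32))

freeze(model)  # materialize the quantized weights

with torch.no_grad():
    qout = model(torch.randn(8, 32))
# With quantized activations, outputs are quantized by the forward hook.
assert isinstance(qout, QBytesTensor)
```

This mirrors `_test_quantize_mlp` in the updated tests, where plain float inputs are now fed to the model and the hooks handle activation quantization instead of the removed `qforward` indirection.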