From d33a775d54b282fd45716669cd5a46b403fa9d64 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 19 Jul 2024 13:38:04 +0200 Subject: [PATCH 1/9] test(quantize): fix bogus test --- test/quantize/test_quantize_mlp.py | 38 +++++++++++++++++------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/test/quantize/test_quantize_mlp.py b/test/quantize/test_quantize_mlp.py index 8518ba44..80275da9 100644 --- a/test/quantize/test_quantize_mlp.py +++ b/test/quantize/test_quantize_mlp.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import nullcontext import pytest import torch -from helpers import assert_similar, get_device_memory, random_qactivation +from helpers import assert_similar, get_device_memory, random_tensor from optimum.quanto import ( AbsmaxOptimizer, @@ -24,12 +25,14 @@ QBytesTensor, QLinear, QTensor, + absmax_scale, freeze, qfloat8_e4m3fn, qfloat8_e5m2, qint4, qint8, quantize, + quantize_activation, ) @@ -56,24 +59,25 @@ def check_mlp(model, frozen): assert isinstance(model.output_layer.weight, QTensor) -def get_outputs(model, batch_size, input_features, device): - qinputs = random_qactivation((batch_size, input_features), dtype=torch.float32).to(device) - return model(qinputs) - - -def _test_quantize_mlp(weights, activations, optimizer, frozen, device): +def _test_quantize_mlp(weights, activations, optimizer, frozen, device, atol=1e-6): model = MLP(32, 10, 128).to(device) - output = get_outputs(model, 1, 32, device) + inputs = random_tensor((1, 32), dtype=torch.float32, device=device) + output = model(inputs) quantize(model, weights=weights, activations=activations, optimizer=optimizer) if frozen: freeze(model) check_mlp(model, frozen) - with Calibration(): - qoutput = get_outputs(model, 1, 32, device) + if activations is not None: + inputs = quantize_activation(inputs, qtype=activations, scale=absmax_scale(inputs)) + context = Calibration + else: + context = nullcontext + with context(): + qoutput = model(inputs) if activations is not None: assert isinstance(qoutput, QBytesTensor) # Don't expect more than a 0.99 similarity - assert_similar(output, qoutput, atol=1e-2) + assert_similar(output, qoutput, atol=atol) @pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"]) @@ -86,19 +90,20 @@ def test_quantize_mlp_weights_only(weights, frozen, device): @pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"]) @pytest.mark.skip_device("mps") def test_quantize_mlp_int8_activations(weights, frozen, device): - _test_quantize_mlp(weights, qint8, None, frozen, device) + _test_quantize_mlp(weights, qint8, None, frozen, device, atol=1e-3) @pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"]) @pytest.mark.parametrize( "activations", - [None, qint8, qfloat8_e5m2, qfloat8_e4m3fn], - ids=["a-float", "a-qint8", "a-qfloat8-e5m2", "a-qfloat8-e4m3"], + [qfloat8_e5m2, qfloat8_e4m3fn], + ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3"], ) @pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"]) @pytest.mark.skip_device("mps") def test_quantize_mlp_float8_activations(weights, activations, frozen, device): - _test_quantize_mlp(weights, activations, None, frozen, device) + atol = {qfloat8_e4m3fn: 1e-3, qfloat8_e5m2: 1e-2}[activations] + _test_quantize_mlp(weights, activations, None, frozen, device, atol=atol) @pytest.mark.skip_device("cpu") @@ -126,7 +131,8 @@ def test_quantized_mlp_device_memory(weights, dtype, weights_only, device): ) 
@pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"]) def test_quantize_mlp_weights_only_optimizers(weights, optimizer, frozen, device): - _test_quantize_mlp(weights, None, optimizer, frozen, device) + atol = {qint4: 1e-4, qint8: 1e-6}[weights] + _test_quantize_mlp(weights, None, optimizer, frozen, device, atol=atol) @pytest.mark.parametrize( From 7e42886903b65a919ef37a0e3b9d6edea2cc07cb Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 19 Jul 2024 14:20:39 +0200 Subject: [PATCH 2/9] test(QConv2D): use float inputs --- test/nn/test_qconv2d.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/nn/test_qconv2d.py b/test/nn/test_qconv2d.py index aed56cd8..7d4273b5 100644 --- a/test/nn/test_qconv2d.py +++ b/test/nn/test_qconv2d.py @@ -14,7 +14,7 @@ import pytest import torch -from helpers import assert_similar, random_qactivation +from helpers import assert_similar, random_qactivation, random_tensor from optimum.quanto import Calibration, QBytesTensor, qfloat8_e4m3fn, qfloat8_e5m2, qint4, qint8 from optimum.quanto.nn import QConv2d @@ -24,16 +24,16 @@ def _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights conv2d = torch.nn.Conv2d(img_shape[0], out_channels, kernel_size=3, bias=use_bias).to(dtype).to(device) qconv2d = QConv2d.from_module(conv2d, weights=weights, activations=activations) assert qconv2d.qweight.qtype == weights - qinputs = random_qactivation((batch_size,) + img_shape, dtype=dtype).to(device) + inputs = random_tensor((batch_size,) + img_shape, dtype=dtype, device=device) # Run an inference with Calibration to get the correct output dtype with torch.no_grad(), Calibration(): - qout = qconv2d(qinputs) + qout = qconv2d(inputs) if activations is not None: assert isinstance(qout, QBytesTensor) assert qout.qtype == activations # Align weights with quantized linear weights for comparison conv2d.weight = torch.nn.Parameter(qconv2d.qweight.dequantize()) - out = conv2d(qinputs.dequantize()) + out = conv2d(inputs) # We need to increase atol for float16 dtype dtype_atol = {torch.float32: 1e-4, torch.float16: 1e-3}[dtype] # We also need to increase atol for float8 itypes From 3f47114a87c202267f02ed8570a9f5098aa6700f Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 19 Jul 2024 14:37:40 +0200 Subject: [PATCH 3/9] test(qlinear): avoid passing quantized inputs when activation is none This reveals an overflow in qfloat8_e5m2 activations test, that is removed for now. 
--- test/nn/test_qlinear.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/test/nn/test_qlinear.py b/test/nn/test_qlinear.py index 2ff5bffe..86c82ae5 100644 --- a/test/nn/test_qlinear.py +++ b/test/nn/test_qlinear.py @@ -38,8 +38,12 @@ def _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, act linear = torch.nn.Linear(embeddings, embeddings, bias=use_bias).to(dtype).to(device) qlinear = QLinear.from_module(linear, weights=weights, activations=activations) assert qlinear.qweight.qtype == weights - qinputs = random_qactivation((batch_size,) + (tokens, embeddings), dtype=dtype).to(device) - inputs = qinputs.dequantize() + input_shape = (batch_size, tokens, embeddings) + if activations is not None: + qinputs = random_qactivation(input_shape, qtype=activations, dtype=dtype).to(device) + inputs = qinputs.dequantize() + else: + inputs = random_tensor(input_shape, dtype=dtype, device=device) # Run an inference with Calibration to get the correct output dtype context = nullcontext if activations is None else Calibration with torch.no_grad(), context(): @@ -79,8 +83,8 @@ def test_quantize_linear_float32_activations_int8(batch_size, tokens, embeddings @pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"]) @pytest.mark.parametrize( "activations", - [qfloat8_e5m2, qfloat8_e4m3fn], - ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3"], + [qfloat8_e4m3fn], + ids=["a-qfloat8-e4m3"], ) @pytest.mark.skip_device("mps") def test_quantize_linear_float16_activations_float8( @@ -132,8 +136,8 @@ def test_qlinear_gradient(tokens, embeddings, activations, weights, device): qlinear = QLinear.from_module(linear, weights=weights, activations=activations) assert qlinear.weight.requires_grad is True assert qlinear.bias.requires_grad is True - # Run an inference with quantized inputs - inputs = random_tensor((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device) + # Run an inference with dynamically quantized inputs + inputs = random_tensor((batch_size, tokens, embeddings), dtype=torch.float32, device=device) inputs.requires_grad = True qinputs = quantize_activation(inputs, qtype=qint8, scale=absmax_scale(inputs, qint8)) qout = qlinear(qinputs) From 4696e8124b35be0bb10ebb8ca9cf16c6f92d2156 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 19 Jul 2024 15:47:40 +0200 Subject: [PATCH 4/9] test(calibrate): use correct activation qtype --- test/nn/test_calibrate.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/nn/test_calibrate.py b/test/nn/test_calibrate.py index 25167239..12ca9252 100644 --- a/test/nn/test_calibrate.py +++ b/test/nn/test_calibrate.py @@ -23,7 +23,9 @@ def _test_calibrate_qlinear(batch_size, tokens, embeddings, use_bias, activations, device): linear = torch.nn.Linear(embeddings, embeddings, bias=use_bias).to(device) qlinear = QLinear.from_module(linear, weights=qint8, activations=activations) - qinputs = random_qactivation((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device) + qinputs = random_qactivation( + (batch_size, tokens, embeddings), qtype=activations, dtype=torch.float32, device=device + ) # Run a first inference without Calibration with torch.no_grad(): qout = qlinear(qinputs) @@ -73,7 +75,7 @@ def forward(self, input): model = TwoLinearModel(embeddings).to(device) model.linear1 = QLinear.from_module(model.linear1, weights=qint8, activations=activations) model.linear2 = QLinear.from_module(model.linear2, weights=qint8, activations=activations) - qinputs = 
random_qactivation((1,) + (tokens, embeddings), dtype=torch.float32).to(device) + qinputs = random_qactivation((1, tokens, embeddings), qtype=activations, dtype=torch.float32, device=device) with torch.no_grad(), Calibration(): qout = model(qinputs) assert torch.any(model.linear1.input_scale != 1) From ed3cc7128f60bd18f1d9225cf3072d9f2cf13782 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 19 Jul 2024 12:25:19 +0200 Subject: [PATCH 5/9] refactor(QModuleMixin): quantize inputs if needed Only QLinear might request its inputs to be always quantized, as it is the only layer for which an optimized kernel exists. --- optimum/quanto/nn/qconv2d.py | 6 +----- optimum/quanto/nn/qlinear.py | 7 ++----- optimum/quanto/nn/qmodule.py | 9 +++++++-- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/optimum/quanto/nn/qconv2d.py b/optimum/quanto/nn/qconv2d.py index fc5bcb0d..86a5eb83 100644 --- a/optimum/quanto/nn/qconv2d.py +++ b/optimum/quanto/nn/qconv2d.py @@ -16,7 +16,7 @@ import torch -from ..tensor import Optimizer, QBytesTensor, qtype, quantize_activation +from ..tensor import Optimizer, qtype from .qmodule import QModuleMixin, register_qmodule @@ -47,8 +47,4 @@ def qcreate( ) def qforward(self, input: torch.Tensor) -> torch.Tensor: - if self.activation_qtype is not None and not isinstance(input, QBytesTensor): - # Quantize tensor to be able to take advantage of accelerated conv2d - input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale) - # We always use quantized weights return self._conv_forward(input, self.qweight, self.bias) diff --git a/optimum/quanto/nn/qlinear.py b/optimum/quanto/nn/qlinear.py index 92b74cac..f189f07f 100644 --- a/optimum/quanto/nn/qlinear.py +++ b/optimum/quanto/nn/qlinear.py @@ -16,7 +16,7 @@ import torch -from ..tensor import Optimizer, QBytesTensor, qtype, quantize_activation +from ..tensor import Optimizer, qtype from .qmodule import QModuleMixin, register_qmodule @@ -38,11 +38,8 @@ def qcreate( weights=weights, activations=activations, optimizer=optimizer, + quantize_input=True, ) def qforward(self, input: torch.Tensor) -> torch.Tensor: - if self.activation_qtype is not None and not isinstance(input, QBytesTensor): - # Quantize activations to be able to take advantage of accelerated matmul - input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale) - # We always use quantized weights return torch.nn.functional.linear(input, self.qweight, bias=self.bias) diff --git a/optimum/quanto/nn/qmodule.py b/optimum/quanto/nn/qmodule.py index d754444e..ac86974d 100644 --- a/optimum/quanto/nn/qmodule.py +++ b/optimum/quanto/nn/qmodule.py @@ -94,6 +94,7 @@ def __init__( weights: Optional[Union[qtype, str]] = None, activations: Optional[Union[qtype, str]] = None, optimizer: Optional[Optimizer] = None, + quantize_input: Optional[bool] = False, **kwargs, ): # The tests below are meant to help people writing their own quantized Module class @@ -122,6 +123,7 @@ def __init__( if in_features % group_size == 0: self.weight_group_size = group_size self.activation_qtype = activations + self.quantize_input = quantize_input self.optimizer = optimizer self.register_buffer("input_scale", torch.ones(())) self.register_buffer("output_scale", torch.ones(())) @@ -236,8 +238,11 @@ def maybe_requantize(t, scale): return t return quantize_activation(t.dequantize(), qtype=self.activation_qtype, scale=scale) - if self.activation_qtype is not None and isinstance(input, QBytesTensor): - input = maybe_requantize(input, 
self.input_scale) + if self.activation_qtype is not None: + if isinstance(input, QBytesTensor): + input = maybe_requantize(input, self.input_scale) + elif self.quantize_input: + input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale) output = self.qforward(input) if self.activation_qtype is not None: if isinstance(output, QBytesTensor): From 3a4acd0aafda35b039a011b224d13b4c3941370a Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 19 Jul 2024 15:11:40 +0200 Subject: [PATCH 6/9] feat(QModuleMixin): sanitize inputs --- optimum/quanto/nn/qmodule.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/optimum/quanto/nn/qmodule.py b/optimum/quanto/nn/qmodule.py index ac86974d..0089a6a6 100644 --- a/optimum/quanto/nn/qmodule.py +++ b/optimum/quanto/nn/qmodule.py @@ -123,7 +123,7 @@ def __init__( if in_features % group_size == 0: self.weight_group_size = group_size self.activation_qtype = activations - self.quantize_input = quantize_input + self._quantize_input = quantize_input self.optimizer = optimizer self.register_buffer("input_scale", torch.ones(())) self.register_buffer("output_scale", torch.ones(())) @@ -232,25 +232,31 @@ def qweight(self): def qforward(self, input: torch.Tensor) -> torch.Tensor: raise NotImplementedError - def forward(self, input: torch.Tensor) -> torch.Tensor: - def maybe_requantize(t, scale): - if t.qtype == self.activation_qtype and t.axis is None: - return t - return quantize_activation(t.dequantize(), qtype=self.activation_qtype, scale=scale) - + def quantize_input(self, input: torch.Tensor) -> torch.Tensor: if self.activation_qtype is not None: if isinstance(input, QBytesTensor): - input = maybe_requantize(input, self.input_scale) - elif self.quantize_input: + if input.qtype != self.activation_qtype: + raise ValueError( + "Models with heterogeneous quantized activations are not supported:" + f" expected {self.activation_qtype.name} input but got {input.qtype.name} instead." + ) + elif self._quantize_input: input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale) - output = self.qforward(input) + return input + + def quantize_output( + self, + output: torch.Tensor, + ) -> torch.Tensor: if self.activation_qtype is not None: - if isinstance(output, QBytesTensor): - output = maybe_requantize(output, self.output_scale) - else: - output = quantize_activation(output, qtype=self.activation_qtype, scale=self.output_scale) + output = quantize_activation(output, qtype=self.activation_qtype, scale=self.output_scale) return output + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = self.quantize_input(input) + output = self.qforward(input) + return self.quantize_output(output) + def freeze(self): qweight = self.qweight if qweight is not None: From abf97a9549fffd8b4f1dedebe38956acfbdbe194 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 19 Jul 2024 12:11:00 +0200 Subject: [PATCH 7/9] refactor(QModuleMixin): avoid multiple forward calls in calibration Moving the input/output quantization code into module forward hooks ensures it runs only after the calibration hooks. This greatly simplifies the calibration code, in particular by avoiding multiple forward calls during output calibration.
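For reference, a minimal standalone sketch of the ordering this relies on (plain PyTorch, not quanto code; the hook bodies only record the call order): hooks registered globally, which is how Calibration registers calibrate_input and calibrate_output, run before hooks registered on the module itself, so calibration always sees the raw, not-yet-quantized inputs and outputs.

    import torch

    calls = []
    linear = torch.nn.Linear(8, 8)
    # Module-level hooks, standing in for QModuleMixin.quantize_input / quantize_output.
    linear.register_forward_pre_hook(lambda mod, args: calls.append("module pre-hook"))
    linear.register_forward_hook(lambda mod, args, out: calls.append("module hook"))
    # Global hooks, standing in for Calibration.calibrate_input / calibrate_output.
    pre = torch.nn.modules.module.register_module_forward_pre_hook(
        lambda mod, args: calls.append("global pre-hook")
    )
    post = torch.nn.modules.module.register_module_forward_hook(
        lambda mod, args, out: calls.append("global hook")
    )
    linear(torch.randn(1, 8))
    pre.remove()
    post.remove()
    # Global hooks run first in both the pre-forward and post-forward phases.
    assert calls == ["global pre-hook", "module pre-hook", "global hook", "module hook"]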
--- optimum/quanto/calibrate.py | 50 ++++++++++++++++++++++++++---------- optimum/quanto/nn/qmodule.py | 41 +++++++++++++++++------------ 2 files changed, 61 insertions(+), 30 deletions(-) diff --git a/optimum/quanto/calibrate.py b/optimum/quanto/calibrate.py index 077a5990..7055314a 100644 --- a/optimum/quanto/calibrate.py +++ b/optimum/quanto/calibrate.py @@ -84,6 +84,7 @@ def __init__(self, *args, momentum: float = 0.9, streamline=True, debug=False, * self.streamline = streamline if streamline: self.modules_qactivations = {} + self.streamline_hooks = {} self.debug = debug def __torch_function__(self, func, types, args=(), kwargs=None): @@ -112,8 +113,15 @@ def __exit__(self, exc_type, exc_val, exc_tb): super().__exit__(exc_type, exc_val, exc_tb) self.pre_handle.remove() self.post_handle.remove() + if self.streamline: + for handle in self.streamline_hooks.values(): + handle.remove() def calibrate_input(self, module: torch.nn.Module, input, momentum: float = 0.9): + """Calibrate a module input scale + + This is registered as a global hook that is called before any module forward pre hook. + """ if isinstance(module, QModuleMixin) and module.activation_qtype is not None: input = input[0] if isinstance(input, QBytesTensor): @@ -123,27 +131,28 @@ def calibrate_input(self, module: torch.nn.Module, input, momentum: float = 0.9) # Evaluate the best scale input_scale = absmax_scale(input, module.activation_qtype) module.input_scale = _updated_scale(module.input_scale, input_scale, momentum) + if self.streamline and module not in self.streamline_hooks: + # Add a hook to tag the module outputs (after the module quantization hook in QModuleMixin) + self.streamline_hooks[module] = module.register_forward_hook(self.tag_outputs) return input def calibrate_output( self, module: torch.nn.Module, - input, - output, + input: torch.Tensor, + output: torch.Tensor, ): + """Calibrate a module output scale + + This is registered as a global hook that is called before any module forward hook. + + When the module is a QModuleMixin, its outputs are not quantized yet because they + are only quantized in the QModuleMixin.quantize_output forward hook. + """ if isinstance(module, (QModuleMixin)) and module.activation_qtype is not None: - # Re-evaluate raw module output - qoutput = module.qforward(input[0]) - if isinstance(qoutput, QBytesTensor): - qoutput = qoutput.dequantize() # Evaluate the optimal scale per-tensor and update output scale - output_scale = absmax_scale(qoutput, module.activation_qtype, axis=None) + output_scale = absmax_scale(output, module.activation_qtype, axis=None) module.output_scale = _updated_scale(module.output_scale, output_scale, self.momentum) - # Re-evaluate output with the correct output scale - output = module.forward(input[0]) - if isinstance(output, QBytesTensor): - # Mark the outputs as generated by this module - output.src_module = module return output else: if self.streamline: @@ -152,7 +161,7 @@ def calibrate_output( qactivations_required = self.modules_qactivations.get(child, False) if not qactivations_required: # Disable activations for this child as its outputs are only consumed by incompatible functions. - child.activation_qtype = None + child.disable_activation_quantization() if self.debug: for name, child in module.named_children(): if isinstance(child, QModuleMixin): @@ -163,3 +172,18 @@ def calibrate_output( else: trace += f" quantized to {child.activation_qtype} with scale {child.output_scale}." 
print(trace) + + def tag_outputs( + self, + module: torch.nn.Module, + input: torch.Tensor, + output: torch.Tensor, + ): + """Mark outputs as generated by a module + + This is called as a module forward hook that is called after the QModuleMixin.quantize_output + forward hook. + + This is useful in streamline mode to identify the module that generated a specific QTensor. + """ + output.src_module = module diff --git a/optimum/quanto/nn/qmodule.py b/optimum/quanto/nn/qmodule.py index 0089a6a6..7b2229f0 100644 --- a/optimum/quanto/nn/qmodule.py +++ b/optimum/quanto/nn/qmodule.py @@ -123,11 +123,20 @@ def __init__( if in_features % group_size == 0: self.weight_group_size = group_size self.activation_qtype = activations - self._quantize_input = quantize_input + self._quantize_hooks = [] + if activations is not None: + if quantize_input: + self._quantize_hooks.append(self.register_forward_pre_hook(self.quantize_input)) + self._quantize_hooks.append(self.register_forward_hook(self.quantize_output)) self.optimizer = optimizer self.register_buffer("input_scale", torch.ones(())) self.register_buffer("output_scale", torch.ones(())) + def disable_activation_quantization(self): + for hook in self._quantize_hooks: + hook.remove() + self.activation_qtype = None + def _save_to_state_dict(self, destination, prefix, keep_vars): if self.weight_qtype is None or not self.frozen: # Save standard weight Tensor @@ -232,30 +241,28 @@ def qweight(self): def qforward(self, input: torch.Tensor) -> torch.Tensor: raise NotImplementedError - def quantize_input(self, input: torch.Tensor) -> torch.Tensor: - if self.activation_qtype is not None: - if isinstance(input, QBytesTensor): - if input.qtype != self.activation_qtype: - raise ValueError( - "Models with heterogeneous quantized activations are not supported:" - f" expected {self.activation_qtype.name} input but got {input.qtype.name} instead." - ) - elif self._quantize_input: - input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale) + def quantize_input(self, module: torch.nn.Module, input: torch.Tensor) -> torch.Tensor: + input = input[0] + if isinstance(input, QBytesTensor): + if input.qtype != self.activation_qtype: + raise ValueError( + "Models with heterogeneous quantized activations are not supported:" + f" expected {self.activation_qtype.name} input but got {input.qtype.name} instead." 
+ ) + else: + input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale) return input def quantize_output( self, + module: torch.nn.Module, + input: torch.Tensor, output: torch.Tensor, ) -> torch.Tensor: - if self.activation_qtype is not None: - output = quantize_activation(output, qtype=self.activation_qtype, scale=self.output_scale) - return output + return quantize_activation(output, qtype=self.activation_qtype, scale=self.output_scale) def forward(self, input: torch.Tensor) -> torch.Tensor: - input = self.quantize_input(input) - output = self.qforward(input) - return self.quantize_output(output) + return self.qforward(input) def freeze(self): qweight = self.qweight From d783a3bd39b78eb6068e84e4407be43794c2a4ca Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 19 Jul 2024 14:53:56 +0000 Subject: [PATCH 8/9] fix(calibration): only disable output quantization when streamlining --- optimum/quanto/calibrate.py | 4 ++-- optimum/quanto/nn/qmodule.py | 13 ++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/optimum/quanto/calibrate.py b/optimum/quanto/calibrate.py index 7055314a..2d5bfc49 100644 --- a/optimum/quanto/calibrate.py +++ b/optimum/quanto/calibrate.py @@ -160,8 +160,8 @@ def calibrate_output( if isinstance(child, QModuleMixin) and child.activation_qtype is not None: qactivations_required = self.modules_qactivations.get(child, False) if not qactivations_required: - # Disable activations for this child as its outputs are only consumed by incompatible functions. - child.disable_activation_quantization() + # Disable output quantization for this child as its outputs are only consumed by incompatible functions. + child.disable_output_quantization() if self.debug: for name, child in module.named_children(): if isinstance(child, QModuleMixin): diff --git a/optimum/quanto/nn/qmodule.py b/optimum/quanto/nn/qmodule.py index 7b2229f0..fbf1ebf9 100644 --- a/optimum/quanto/nn/qmodule.py +++ b/optimum/quanto/nn/qmodule.py @@ -123,19 +123,18 @@ def __init__( if in_features % group_size == 0: self.weight_group_size = group_size self.activation_qtype = activations - self._quantize_hooks = [] + self._quantize_hooks = {} if activations is not None: if quantize_input: - self._quantize_hooks.append(self.register_forward_pre_hook(self.quantize_input)) - self._quantize_hooks.append(self.register_forward_hook(self.quantize_output)) + self._quantize_hooks["input"] = self.register_forward_pre_hook(self.quantize_input) + self._quantize_hooks["output"] = self.register_forward_hook(self.quantize_output) self.optimizer = optimizer self.register_buffer("input_scale", torch.ones(())) self.register_buffer("output_scale", torch.ones(())) - def disable_activation_quantization(self): - for hook in self._quantize_hooks: - hook.remove() - self.activation_qtype = None + def disable_output_quantization(self): + if "output" in self._quantize_hooks: + self._quantize_hooks["output"].remove() def _save_to_state_dict(self, destination, prefix, keep_vars): if self.weight_qtype is None or not self.frozen: From 59b9983741ae92ab1884e9eb221bfe7af42644b9 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 19 Jul 2024 15:10:56 +0000 Subject: [PATCH 9/9] refactor(QModuleMixin): remove forward indirection --- optimum/quanto/nn/qconv2d.py | 2 +- optimum/quanto/nn/qlayernorm.py | 2 +- optimum/quanto/nn/qlinear.py | 2 +- optimum/quanto/nn/qmodule.py | 7 ++----- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/optimum/quanto/nn/qconv2d.py 
b/optimum/quanto/nn/qconv2d.py index 86a5eb83..f2ddd892 100644 --- a/optimum/quanto/nn/qconv2d.py +++ b/optimum/quanto/nn/qconv2d.py @@ -46,5 +46,5 @@ def qcreate( optimizer=optimizer, ) - def qforward(self, input: torch.Tensor) -> torch.Tensor: + def forward(self, input: torch.Tensor) -> torch.Tensor: return self._conv_forward(input, self.qweight, self.bias) diff --git a/optimum/quanto/nn/qlayernorm.py b/optimum/quanto/nn/qlayernorm.py index cfa844b0..48bd6c0a 100644 --- a/optimum/quanto/nn/qlayernorm.py +++ b/optimum/quanto/nn/qlayernorm.py @@ -47,5 +47,5 @@ def qcreate( optimizer=None, # We never quantize QLayerNorm weights ) - def qforward(self, input: torch.Tensor) -> torch.Tensor: + def forward(self, input: torch.Tensor) -> torch.Tensor: return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) diff --git a/optimum/quanto/nn/qlinear.py b/optimum/quanto/nn/qlinear.py index f189f07f..fc290ba4 100644 --- a/optimum/quanto/nn/qlinear.py +++ b/optimum/quanto/nn/qlinear.py @@ -41,5 +41,5 @@ def qcreate( quantize_input=True, ) - def qforward(self, input: torch.Tensor) -> torch.Tensor: + def forward(self, input: torch.Tensor) -> torch.Tensor: return torch.nn.functional.linear(input, self.qweight, bias=self.bias) diff --git a/optimum/quanto/nn/qmodule.py b/optimum/quanto/nn/qmodule.py index fbf1ebf9..df84e299 100644 --- a/optimum/quanto/nn/qmodule.py +++ b/optimum/quanto/nn/qmodule.py @@ -44,7 +44,7 @@ def register_qmodule(module_cls): The QModule must implement two abstract methods: - qcreate: class method to instantiate a new QModule from an nn.Module, without copying its weights, - - qforward: instance method for quantized inference. + - forward: instance method for quantized inference. The code to register a new module looks like: @@ -61,7 +61,7 @@ def qcreate(cls, optimizer: Optional[Optimizer] = None): ... - def qforward(self, input: torch.Tensor) -> torch.Tensor: + def forward(self, input: torch.Tensor) -> torch.Tensor: ... ``` @@ -260,9 +260,6 @@ def quantize_output( ) -> torch.Tensor: return quantize_activation(output, qtype=self.activation_qtype, scale=self.output_scale) - def forward(self, input: torch.Tensor) -> torch.Tensor: - return self.qforward(input) - def freeze(self): qweight = self.qweight if qweight is not None:
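End to end, nothing in this series changes the public workflow; a minimal usage sketch under that assumption (the toy model, shapes and qint8 choice below are illustrative, not taken from the patches):

    import torch
    from optimum.quanto import Calibration, freeze, qint8, quantize

    model = torch.nn.Sequential(
        torch.nn.Linear(32, 128), torch.nn.ReLU(), torch.nn.Linear(128, 10)
    )
    quantize(model, weights=qint8, activations=qint8)
    with torch.no_grad(), Calibration():
        # Calibration pass: the global hooks update the input/output scales.
        model(torch.randn(16, 32))
    freeze(model)  # materialize the quantized weights
    # Inference: the module hooks quantize activations using the calibrated scales.
    output = model(torch.randn(1, 32))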