From a3c900ffe1ff4ac7f210481abf27c731e26fa365 Mon Sep 17 00:00:00 2001
From: Kyunggeun Lee
Date: Tue, 27 Aug 2024 21:51:35 -0700
Subject: [PATCH] Add more argument sanity checks

Signed-off-by: Kyunggeun Lee
---
 .../v2/quantization/affine/encoding.py        | 18 ++++++++++----
 .../v2/quantization/affine/quantizer.py       | 24 ++++++++++++-------
 .../affine/test_affine_quantizer.py           | 16 +++++++++++++
 3 files changed, 45 insertions(+), 13 deletions(-)

diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/encoding.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/encoding.py
index 08405d2e7a..37a931cdc6 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/encoding.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/encoding.py
@@ -69,19 +69,19 @@ def __init__(self, scale: torch.Tensor, offset: torch.Tensor, bitwidth: int, sig
                  block_size: Optional[Tuple[int, ...]] = None):
         ...
 
-    def __init__(self, scale: torch.Tensor, offset: torch.Tensor, *args, **kwargs):
+    def __init__(self, scale: torch.Tensor, offset: torch.Tensor, *args, **kwargs): # pylint: disable=too-many-locals
         self._scale = scale
         self._offset = offset
         full_args = (scale, offset, *args)
 
         # Pad positional args with None's such that len(args) == 4
         args = tuple(chain(args, repeat(None, 4 - len(args))))
-        arg0 = kwargs.get('qmin', kwargs.get('bitwidth', args[0]))
-        arg1 = kwargs.get('qmax', kwargs.get('signed', args[1]))
-        symmetry = kwargs.get('symmetry', args[2])
+        arg0 = kwargs.pop('qmin', kwargs.pop('bitwidth', args[0]))
+        arg1 = kwargs.pop('qmax', kwargs.pop('signed', args[1]))
+        symmetry = kwargs.pop('symmetry', args[2])
         if symmetry is None:
             symmetry = False
-        block_size = kwargs.get('block_size', args[3])
+        block_size = kwargs.pop('block_size', args[3])
 
         if arg1 is None or isinstance(arg1, bool):
             # (arg0, arg1) == (bitwidth, signed)
@@ -98,6 +98,14 @@ def __init__(self, scale: torch.Tensor, offset: torch.Tensor, *args, **kwargs):
         assert qmin is not None
         assert qmax is not None
 
+        if kwargs:
+            cls = type(self).__qualname__
+            unexpected_keys = ', '.join(kwargs.keys())
+            raise TypeError(f"{cls}.__init__ got unexpected keyword argument: {unexpected_keys}")
+
+        if qmin >= qmax:
+            raise ValueError(f"qmax should be strictly larger than qmin. Got qmax={qmax}, qmin={qmin}")
+
         self.qmin = qmin
         self.qmax = qmax
         self._symmetry = symmetry
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/quantizer.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/quantizer.py
index 59153c03ae..c84f3df57c 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/quantizer.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/quantizer.py
@@ -97,35 +97,43 @@ def __init__(self, shape, *args, **kwargs):
 
         # Pad positional args with None's such that len(args) == 5
         args = tuple(chain(args, repeat(None, 5 - len(args))))
-        arg0 = kwargs.get('qmin', kwargs.get('bitwidth', args[0]))
-        arg1 = kwargs.get('qmax', args[1])
+        arg0 = kwargs.pop('qmin', kwargs.pop('bitwidth', args[0]))
+        arg1 = kwargs.pop('qmax', args[1])
 
         if arg1 is not None and not isinstance(arg1, bool):
             # (arg0, arg1, arg2) == (qmin, qmax, symmetric)
             qmin, qmax = arg0, arg1
-            symmetric = kwargs.get('symmetric', args[2])
+            symmetric = kwargs.pop('symmetric', args[2])
 
             if (qmin is None) or (qmax is None) or (symmetric is None):
                 raise self._arg_parsing_error(full_args, kwargs)
 
-            encoding_analyzer = kwargs.get('encoding_analyzer', args[3])
-            block_size = kwargs.get('block_size', args[4])
+            encoding_analyzer = kwargs.pop('encoding_analyzer', args[3])
+            block_size = kwargs.pop('block_size', args[4])
         else:
             # (arg0, arg1) == (bitwidth, symmetric)
             bitwidth = arg0
-            symmetric = kwargs.get('symmetric', args[1])
+            symmetric = kwargs.pop('symmetric', args[1])
 
             if (bitwidth is None) or (symmetric is None):
                 raise self._arg_parsing_error(full_args, kwargs)
 
             # We support two quantization modes: (unsigned) asymmetric and signed-symmetric
             qmin, qmax = _derive_qmin_qmax(bitwidth=bitwidth, signed=symmetric)
-            encoding_analyzer = kwargs.get('encoding_analyzer', args[2])
-            block_size = kwargs.get('block_size', args[3])
+            encoding_analyzer = kwargs.pop('encoding_analyzer', args[2])
+            block_size = kwargs.pop('block_size', args[3])
 
         assert qmin is not None
         assert qmax is not None
 
+        if kwargs:
+            cls = type(self).__qualname__
+            unexpected_keys = ', '.join(kwargs.keys())
+            raise TypeError(f"{cls}.__init__ got unexpected keyword argument: {unexpected_keys}")
+
+        if qmin >= qmax:
+            raise ValueError(f"qmax should be strictly larger than qmin. Got qmax={qmax}, qmin={qmin}")
+
         self.qmin = qmin
         self.qmax = qmax
         self._symmetric = symmetric
diff --git a/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py b/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py
index e52c533661..27d9b9b528 100644
--- a/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py
+++ b/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py
@@ -1506,6 +1506,22 @@ def test_parse_args_error():
     """
     with pytest.raises(TypeError):
         Quantize((1, 10), -128, 127)
 
+    """
+    When: Instantiate with (tuple, int, int, bool, 'foo'=any)
+    Then: Throw TypeError
+    """
+    with pytest.raises(TypeError):
+        Quantize((1, 10), -128, 127, True, foo=None)
+
+    """
+    When: Instantiate with qmin >= qmax
+    Then: Throw ValueError
+    """
+    with pytest.raises(ValueError):
+        Quantize((1, 10), 127, -128, True)
+
+    with pytest.raises(ValueError):
+        Quantize((1, 10), 127, 127, True)
+
     """
     When: Instantiate with (tuple, int, bool)
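
A minimal usage sketch of the behavior this patch enforces, assuming Quantize
is importable from aimet_torch.v2.quantization.affine as the test file above
does:

    import pytest
    from aimet_torch.v2.quantization.affine import Quantize

    # Recognized keys are now consumed with kwargs.pop(), so any leftover
    # keyword raises TypeError instead of being silently ignored.
    with pytest.raises(TypeError):
        Quantize((1, 10), -128, 127, True, foo=None)

    # An inverted or empty range (qmin >= qmax) is rejected at construction
    # time with ValueError rather than producing a degenerate quantizer.
    with pytest.raises(ValueError):
        Quantize((1, 10), 127, -128, True)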