Skip to content

Commit

Permalink
Add more argument sanity checks (#3307)
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <[email protected]>
  • Loading branch information
quic-kyunggeu authored Sep 3, 2024
1 parent 1cef72f commit 66babf3
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,19 @@ def __init__(self, scale: torch.Tensor, offset: torch.Tensor, bitwidth: int, sig
block_size: Optional[Tuple[int, ...]] = None):
...

def __init__(self, scale: torch.Tensor, offset: torch.Tensor, *args, **kwargs):
def __init__(self, scale: torch.Tensor, offset: torch.Tensor, *args, **kwargs): # pylint: disable=too-many-locals
self._scale = scale
self._offset = offset
full_args = (scale, offset, *args)

# Pad positional args with None's such that len(args) == 4
args = tuple(chain(args, repeat(None, 4 - len(args))))
arg0 = kwargs.get('qmin', kwargs.get('bitwidth', args[0]))
arg1 = kwargs.get('qmax', kwargs.get('signed', args[1]))
symmetry = kwargs.get('symmetry', args[2])
arg0 = kwargs.pop('qmin', kwargs.pop('bitwidth', args[0]))
arg1 = kwargs.pop('qmax', kwargs.pop('signed', args[1]))
symmetry = kwargs.pop('symmetry', args[2])
if symmetry is None:
symmetry = False
block_size = kwargs.get('block_size', args[3])
block_size = kwargs.pop('block_size', args[3])

if arg1 is None or isinstance(arg1, bool):
# (arg0, arg1) == (bitwidth, signed)
Expand All @@ -98,6 +98,14 @@ def __init__(self, scale: torch.Tensor, offset: torch.Tensor, *args, **kwargs):
assert qmin is not None
assert qmax is not None

if kwargs:
cls = type(self).__qualname__
unexpected_keys = ', '.join(kwargs.keys())
raise TypeError(f"{cls}.__init__ got unexpected keyword argument: {unexpected_keys}")

if qmin >= qmax:
raise ValueError(f"qmax should be strictly larger than qmin. Got qmax={qmax}, qmin={qmin}")

self.qmin = qmin
self.qmax = qmax
self._symmetry = symmetry
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,35 +97,43 @@ def __init__(self, shape, *args, **kwargs):

# Pad positional args with None's such that len(args) == 5
args = tuple(chain(args, repeat(None, 5 - len(args))))
arg0 = kwargs.get('qmin', kwargs.get('bitwidth', args[0]))
arg1 = kwargs.get('qmax', args[1])
arg0 = kwargs.pop('qmin', kwargs.pop('bitwidth', args[0]))
arg1 = kwargs.pop('qmax', args[1])

if arg1 is not None and not isinstance(arg1, bool):
# (arg0, arg1, arg2) == (qmin, qmax, symmetric)
qmin, qmax = arg0, arg1
symmetric = kwargs.get('symmetric', args[2])
symmetric = kwargs.pop('symmetric', args[2])

if (qmin is None) or (qmax is None) or (symmetric is None):
raise self._arg_parsing_error(full_args, kwargs)

encoding_analyzer = kwargs.get('encoding_analyzer', args[3])
block_size = kwargs.get('block_size', args[4])
encoding_analyzer = kwargs.pop('encoding_analyzer', args[3])
block_size = kwargs.pop('block_size', args[4])
else:
# (arg0, arg1) == (bitwidth, symmetric)
bitwidth = arg0
symmetric = kwargs.get('symmetric', args[1])
symmetric = kwargs.pop('symmetric', args[1])

if (bitwidth is None) or (symmetric is None):
raise self._arg_parsing_error(full_args, kwargs)

# We support two quantization modes: (unsigned) asymmetric and signed-symmetric
qmin, qmax = _derive_qmin_qmax(bitwidth=bitwidth, signed=symmetric)
encoding_analyzer = kwargs.get('encoding_analyzer', args[2])
block_size = kwargs.get('block_size', args[3])
encoding_analyzer = kwargs.pop('encoding_analyzer', args[2])
block_size = kwargs.pop('block_size', args[3])

assert qmin is not None
assert qmax is not None

if kwargs:
cls = type(self).__qualname__
unexpected_keys = ', '.join(kwargs.keys())
raise TypeError(f"{cls}.__init__ got unexpected keyword argument: {unexpected_keys}")

if qmin >= qmax:
raise ValueError(f"qmax should be strictly larger than qmin. Got qmax={qmax}, qmin={qmin}")

self.qmin = qmin
self.qmax = qmax
self._symmetric = symmetric
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1506,6 +1506,22 @@ def test_parse_args_error():
"""
with pytest.raises(TypeError):
Quantize((1, 10), -128, 127)
"""
When: Instantiate with (tuple, int, int, bool, 'foo'=any)
Then: Throw TypeError
"""
with pytest.raises(TypeError):
Quantize((1, 10), -128, 127, True, foo=None)

"""
When: Instantiate with qmin >= qmax
Then: Throw ValueError
"""
with pytest.raises(ValueError):
Quantize((1, 10), 127, -128, True)

with pytest.raises(ValueError):
Quantize((1, 10), 127, 127, True)

"""
When: Instantiate with (tuple, int, bool)
Expand Down

0 comments on commit 66babf3

Please sign in to comment.