Skip to content

Commit

Permalink
Test (brevitas_examples): test for mse weights
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 9, 2024
1 parent f06d99c commit eef36e2
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions tests/brevitas_examples/test_quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,73 @@ def get_qmse(
assert torch.isclose(diff_mse, orig_mse) or (diff_mse > orig_mse)


@pytest.mark.parametrize("quant_granularity", ["per_tensor", "per_channel"])
@jit_disabled_for_local_loss()
def test_layerwise_stats_vs_mse(simple_model, quant_granularity):
"""
We test layerwise quantization, with the weight and activation quantization `mse` parameter
methods.
We test:
- We can feed data through the model.
- That the stat observer is explictly MSE.
- That the view on the quantization granularity is as desired.
- That during calibration, the qparams are derived by finding values that minimize the MSE
between the floating point and quantized tensor.
"""
weight_bit_width = 8
act_bit_width = 8
bias_bit_width = 32
quant_model_mse = quantize_model(
model=deepcopy(simple_model),
backend='layerwise',
weight_bit_width=weight_bit_width,
act_bit_width=act_bit_width,
bias_bit_width=bias_bit_width if bias_bit_width > 0 else None,
weight_quant_granularity=quant_granularity,
act_quant_type='asym',
act_quant_percentile=99.9, # Unused
scale_factor_type='float_scale',
quant_format='int',
weight_param_method='mse',
act_param_method='mse')

quant_model_stats = quantize_model(
model=deepcopy(simple_model),
backend='layerwise',
weight_bit_width=weight_bit_width,
act_bit_width=act_bit_width,
bias_bit_width=bias_bit_width if bias_bit_width > 0 else None,
weight_quant_granularity=quant_granularity,
act_quant_type='asym',
act_quant_percentile=99.9, # Unused
scale_factor_type='float_scale',
quant_format='int',
weight_param_method='stats',
act_param_method='mse')

# We create an input with values linearly scaled between 0 and 1.
input = torch.arange(0, 1, step=1 / (10 * IMAGE_DIM ** 2))
input = input.view(1, 10, IMAGE_DIM, IMAGE_DIM).float()
with torch.no_grad():
with calibration_mode(quant_model_mse):
quant_model_mse(input)
quant_model_mse.eval()
with torch.no_grad():
with calibration_mode(quant_model_stats):
quant_model_stats(input)
quant_model_stats.eval()
weight = simple_model.layers.get_submodule('0').weight
first_conv_layer_mse = quant_model_mse.layers.get_submodule('0')
first_conv_layer_stats = quant_model_stats.layers.get_submodule('0')

l2_stats = ((weight - first_conv_layer_stats.quant_weight().value) ** 2).sum()
l2_mse = ((weight - first_conv_layer_mse.quant_weight().value) ** 2).sum()

# Recostruction error of MSE should be smaller or equal to stats
assert l2_mse - l2_stats <= torch.tensor(1e-5)


@pytest.mark.parametrize("weight_bit_width", [2, 5, 8, 16])
@pytest.mark.parametrize("act_bit_width", [2, 5, 8])
@pytest.mark.parametrize("bias_bit_width", [16, 32])
Expand Down

0 comments on commit eef36e2

Please sign in to comment.