Skip to content

Commit

Permalink
Fix (core/stats): add threshold to MSE local loss (#1047)
Browse files Browse the repository at this point in the history
* Fix (core/stats): MSE local loss threshold

* Test (brevitas_examples): test for mse weights

* Update test_quantize_model.py
  • Loading branch information
Giuseppe5 authored Oct 9, 2024
1 parent 424ce6f commit 5bdd1a2
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/brevitas/core/scaling/standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor
# workaround to avoid find_ununsed_parameter=True in DDP
stats = stats + 0. * self.value
if self.local_loss_mode:
return self.stats_scaling_impl(stats)
return self.stats_scaling_impl(stats, threshold)
stats = self.restrict_inplace_preprocess(stats)
threshold = self.restrict_inplace_preprocess(threshold)
inplace_tensor_mul(self.value.detach(), stats)
Expand Down
63 changes: 63 additions & 0 deletions tests/brevitas_examples/test_quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,69 @@ def get_qmse(
assert torch.isclose(diff_mse, orig_mse) or (diff_mse > orig_mse)


@pytest.mark.parametrize("quant_granularity", ["per_tensor", "per_channel"])
@jit_disabled_for_local_loss()
def test_layerwise_stats_vs_mse(simple_model, quant_granularity):
    """
    We test layerwise quantization, with the weight and activation quantization `mse` parameter
    methods.
    We test:
    - Reconstruction error of MSE should be smaller or equal to stats
    """

    def _quantized_copy(weight_param_method):
        # The two models under comparison differ only in the weight
        # parameter method ('mse' vs 'stats'); everything else is shared.
        return quantize_model(
            model=deepcopy(simple_model),
            backend='layerwise',
            weight_bit_width=8,
            act_bit_width=8,
            bias_bit_width=32,
            weight_quant_granularity=quant_granularity,
            act_quant_type='asym',
            act_quant_percentile=99.9,  # Unused
            scale_factor_type='float_scale',
            quant_format='int',
            weight_param_method=weight_param_method,
            act_param_method='mse')

    model_mse = _quantized_copy('mse')
    model_stats = _quantized_copy('stats')

    # Calibration input: values linearly spaced between 0 and 1, reshaped
    # to the model's expected (1, 10, IMAGE_DIM, IMAGE_DIM) layout.
    calib_input = torch.arange(0, 1, step=1 / (10 * IMAGE_DIM ** 2))
    calib_input = calib_input.view(1, 10, IMAGE_DIM, IMAGE_DIM).float()

    # Calibrate both models on the same input, then freeze for evaluation.
    for model in (model_mse, model_stats):
        with torch.no_grad(), calibration_mode(model):
            model(calib_input)
        model.eval()

    float_weight = simple_model.layers.get_submodule('0').weight
    conv_mse = model_mse.layers.get_submodule('0')
    conv_stats = model_stats.layers.get_submodule('0')

    err_stats = ((float_weight - conv_stats.quant_weight().value) ** 2).sum()
    err_mse = ((float_weight - conv_mse.quant_weight().value) ** 2).sum()

    # Reconstruction error of MSE should be smaller or equal to stats
    assert err_mse - err_stats <= torch.tensor(1e-5)


@pytest.mark.parametrize("weight_bit_width", [2, 5, 8, 16])
@pytest.mark.parametrize("act_bit_width", [2, 5, 8])
@pytest.mark.parametrize("bias_bit_width", [16, 32])
Expand Down

0 comments on commit 5bdd1a2

Please sign in to comment.