From dd05b5285b08c1092c176d7fb4fd14ec412d4b8e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 28 Oct 2024 09:59:08 +0000 Subject: [PATCH] Fix --- src/brevitas/quant_tensor/int_torch_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/quant_tensor/int_torch_handler.py b/src/brevitas/quant_tensor/int_torch_handler.py index 91627633d..fb5db7ca1 100644 --- a/src/brevitas/quant_tensor/int_torch_handler.py +++ b/src/brevitas/quant_tensor/int_torch_handler.py @@ -237,7 +237,7 @@ def quant_output_scale_impl( if len(quant_input_scale.shape) == 0: quant_input_scale = quant_input_scale.view(output_scale_shape) quant_input_scale = quant_input_scale.view(output_scale_shape) - if not is_broadcastable(output_scale_shape, quant_input_scale.shape): + if not is_broadcastable(quant_weight_scale.shape, quant_input_scale.shape): return None output_scale = quant_weight_scale * quant_input_scale