From fc7ff8e0ffc293594fef39f47a5c55a51c3ef840 Mon Sep 17 00:00:00 2001 From: Alessandro Pappalardo <1934033+volcacius@users.noreply.github.com> Date: Fri, 10 Nov 2023 15:35:33 +0000 Subject: [PATCH] Fix (core): bug in zero-point statistics with positive only values (#670) --- src/brevitas/core/stats/stats_op.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index 2fa5498c8..194631953 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -38,8 +38,7 @@ def forward(self, x: Tensor) -> Tensor: min_val = torch.min(x) else: min_val = torch.min(x, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0] - min_val = torch.where( - min_val <= self.zero().to(min_val.dtype), min_val, self.zero().to(min_val.dtype)) + min_val = torch.clamp(min_val, max=self.zero()) return min_val @@ -107,8 +106,7 @@ def forward(self, x: Tensor) -> Tensor: # k is 1-indexed, so round away from zero k = int(math.ceil(.01 * self.q * dim_slice.numel())) result = x.kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values - result = torch.where( - result <= self.zero().to(result.dtype), result, self.zero().to(result.dtype)) + result = torch.clamp(result, max=self.zero()) return result @@ -120,12 +118,15 @@ def __init__( low_percentile_q, high_percentile_q, stats_reduce_dim: Optional[int] = None, - keepdim: bool = False) -> None: + keepdim: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None) -> None: super(PercentileInterval, self).__init__() self.stats_reduce_dim = stats_reduce_dim self.low_q = low_percentile_q self.high_q = high_percentile_q self.keepdim = keepdim + self.zero = StatelessBuffer(torch.tensor(0.0, dtype=dtype, device=device)) @brevitas.jit.script_method def forward(self, x: Tensor) -> Tensor: @@ -145,6 +146,8 @@ def forward(self, x: Tensor) -> Tensor: high_k = int(math.floor(.01 * self.high_q * dim_slice.numel() + 0.5)) low_result = x.kthvalue(low_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values high_result = x.kthvalue(high_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values + # We need to make sure the lower bound is not positive to align with zero-point statistics + low_result = torch.clamp(low_result, max=self.zero()) interval = high_result - low_result abs_interval = torch.abs(interval) return abs_interval @@ -169,19 +172,28 @@ def forward(self, x: Tensor): class AbsMinMax(brevitas.jit.ScriptModule): __constants__ = ['stats_reduce_dim', 'keepdim'] - def __init__(self, stats_reduce_dim: Optional[int] = None, keepdim: bool = False) -> None: + def __init__( + self, + stats_reduce_dim: Optional[int] = None, + keepdim: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None) -> None: super(AbsMinMax, self).__init__() self.stats_reduce_dim = stats_reduce_dim self.keepdim = keepdim + self.zero = StatelessBuffer(torch.tensor(0.0, dtype=dtype, device=device)) @brevitas.jit.script_method def forward(self, x: Tensor): if self.stats_reduce_dim is None: - return torch.abs(torch.max(x) - torch.min(x)) + max_val = torch.max(x) + min_val = torch.min(x) else: max_val = torch.max(x, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0] min_val = torch.min(x, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0] - return torch.abs(max_val - min_val) + # We need to make sure the lower bound is not positive to align with zero-point statistics + min_val = torch.clamp(min_val, max=self.zero()) + return torch.abs(max_val - min_val) class AbsMaxAve(brevitas.jit.ScriptModule):