Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up minifloat #1030

Draft
wants to merge 12 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/brevitas/core/function_wrapper/clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from brevitas.core.utils import StatelessBuffer
from brevitas.function import tensor_clamp
from brevitas.function.ops import max_float
from brevitas.utils.torch_utils import MAX_MANTISSA_DICT


class TensorClamp(brevitas.jit.ScriptModule):
Expand Down Expand Up @@ -106,6 +107,7 @@ def __init__(
self.inf_values = inf_values
self.nan_values = nan_values
self.signed = signed
self.max_mantissa_dict = MAX_MANTISSA_DICT

if max_available_float:
max_available_float = torch.tensor(max_available_float, device=device, dtype=dtype)
Expand Down Expand Up @@ -144,15 +146,17 @@ def forward(
mantissa_bit_width: Tensor,
exponent_bias: Tensor):

max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias)
max_value = max_float(
exponent_bit_width, self.max_mantissa_dict[mantissa_bit_width.item()], exponent_bias)
max_value = max_value if self.max_available_float is None else torch.min(
max_value, self.max_available_float())
min_value = torch.tensor(0.) if not self.signed else -max_value

# Compute masks
inf_mask = x.isinf()
p_max_val_mask = x > max_value
n_max_val_mask = -x > max_value
if not self.saturating:
inf_mask = x.isinf()
p_max_val_mask = x > max_value
n_max_val_mask = -x > max_value

# first clamp everything to +- max_value, basically the saturating case
x = self.saturating_clamp(x, max_value, min_value)
Expand Down
21 changes: 13 additions & 8 deletions src/brevitas/core/quant/float.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import time
from typing import Optional, Tuple

import torch
Expand Down Expand Up @@ -64,11 +65,10 @@ def __init__(
if dtype is None:
dtype = torch.get_default_dtype()
self.eps = torch.finfo(dtype).tiny
self.observer_only = brevitas.jit.Attribute(False, bool)

@brevitas.jit.script_method
def quantize(self, x: torch.Tensor):
scale = self.scaling_impl(x)

def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.float_scaling_impl is not None:
float_scaling_impl_value = self.float_scaling_impl(
self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
Expand All @@ -86,10 +86,15 @@ def dequantize(self, y, scale):

@brevitas.jit.script_method
def forward(self, x):
y, scale = self.quantize(x)
# after quantizing, clamp to special cases like NaN/inf if they are set
y, saturating, inf_values, nan_values = self.float_clamp_impl(
y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
y = self.dequantize(y, scale)
scale = self.scaling_impl(x)
if self.observer_only:
y = x
saturating, inf_values, nan_values = self.float_clamp_impl.saturating, self.float_clamp_impl.inf_values, self.float_clamp_impl.nan_values
else:
y, scale = self.quantize(x, scale)
# after quantizing, clamp to special cases like NaN/inf if they are set
y, saturating, inf_values, nan_values = self.float_clamp_impl(
y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
y = self.dequantize(y, scale)
# This is to respect the current interface of proxies
return y, scale, self.zero_point_impl(), self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias(), saturating, inf_values, nan_values
17 changes: 14 additions & 3 deletions src/brevitas/core/quant/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def __init__(
self.int_scaling_impl = int_scaling_impl
self.zero_point_impl = zero_point_impl
self.msb_clamp_bit_width_impl = bit_width_impl
self.observer_only = brevitas.jit.Attribute(False, bool)

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
Expand All @@ -153,7 +154,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
int_threshold = self.int_scaling_impl(bit_width)
scale = threshold / int_threshold
zero_point = self.zero_point_impl(x, scale, bit_width)
y = self.int_quant(scale, zero_point, bit_width, x)
if self.observer_only:
y = x
else:
y = self.int_quant(scale, zero_point, bit_width, x)
return y, scale, zero_point, bit_width


Expand All @@ -176,6 +180,7 @@ def __init__(
self.pre_zero_point_impl = pre_zero_point_impl
self.zero_point_impl = zero_point_impl
self.msb_clamp_bit_width_impl = bit_width_impl
self.observer_only = brevitas.jit.Attribute(False, bool)

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
Expand All @@ -187,7 +192,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Te
threshold = self.scaling_impl(x)
scale = threshold / int_threshold
zero_point = self.zero_point_impl(x, scale, bit_width)
y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
if self.observer_only:
y = x
else:
y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
return y, scale, zero_point, bit_width, pre_scale, pre_zero_point


Expand Down Expand Up @@ -253,5 +261,8 @@ def forward(self, x: Tensor, input_bit_width: Tensor,
threshold = self.scaling_impl(x)
scale = threshold / int_threshold
zero_point = self.zero_point_impl(x, scale, bit_width)
y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
if self.observer_only:
y = x
else:
y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
5 changes: 4 additions & 1 deletion src/brevitas/core/scaling/float_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import brevitas
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_float
from brevitas.utils.torch_utils import MAX_MANTISSA_DICT


class FloatScaling(brevitas.jit.ScriptModule):
Expand All @@ -25,6 +26,7 @@ def __init__(
self.inf_values = inf_values
self.nan_values = nan_values
self.saturating = saturating
self.max_mantissa_dict = MAX_MANTISSA_DICT

if max_available_float:
max_available_float = torch.tensor(max_available_float, device=device, dtype=dtype)
Expand All @@ -36,7 +38,8 @@ def __init__(
def forward(
self, exponent_bit_width: Tensor, mantissa_bit_width: Tensor,
exponent_bias: Tensor) -> Tensor:
max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias)
max_value = max_float(
exponent_bit_width, self.max_mantissa_dict[mantissa_bit_width.item()], exponent_bias)
max_value = max_value if self.max_available_float is None else torch.min(
max_value, self.max_available_float())
return max_value
20 changes: 20 additions & 0 deletions src/brevitas/core/stats/stats_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,19 @@ def _set_local_loss_mode(module, enabled):
m.local_loss_mode = enabled


def _set_observer_mode(module, enabled, previous_observer_mode):
for m in module.modules():
if hasattr(m, 'observer_only'):
previous_observer_mode[m] = m.observer_only
m.observer_only = enabled


def _restore_observer_mode(module, previous_observer_mode):
for m in module.modules():
if hasattr(m, 'observer_only'):
m.observer_only = previous_observer_mode[m]


class MSE(torch.nn.Module):
# References:
# https://github.com/cornell-zhang/dnn-quant-ocs/blob/master/distiller/quantization/clip.py
Expand All @@ -459,7 +472,12 @@ def __init__(
self.mse_init_op = mse_init_op
self.input_view_shape_impl = inner_stats_input_view_shape_impl
self.proxy_forward = proxy_module.forward
self.previous_observer_mode = dict()
self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
self.set_observer_mode = lambda enabled: _set_observer_mode(
proxy_module, enabled, self.previous_observer_mode)
self.restore_observer_mode = lambda: _restore_observer_mode(
proxy_module, self.previous_observer_mode)
self.internal_candidate = None
self.num = mse_iters
self.search_method = mse_search_method
Expand All @@ -480,10 +498,12 @@ def evaluate_loss(self, x, candidate):
self.internal_candidate = candidate
# Set to local_loss_mode before calling the proxy
self.set_local_loss_mode(True)
self.set_observer_mode(False)
quant_value = self.proxy_forward(x)
quant_value = _unpack_quant_tensor(quant_value)
loss = self.mse_loss_fn(x, quant_value)
self.set_local_loss_mode(False)
self.restore_observer_mode()
return loss

def mse_grid_search(self, xl, x):
Expand Down
8 changes: 4 additions & 4 deletions src/brevitas/export/inference/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.utils.torch_utils import float_internal_scale
from brevitas.utils.torch_utils import MAX_MANTISSA_DICT


class InferenceHandler(torch.nn.Module, ABC):
Expand Down Expand Up @@ -101,12 +102,11 @@ def prepare_for_export(self, module):
self.float_to_int_impl = module.fused_activation_quant_proxy.tensor_quant.float_to_int_impl
self.float_clamp_impl = module.fused_activation_quant_proxy.tensor_quant.float_clamp_impl

self.max_clamp = max_float(
self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
self.min_clamp = -self.max_clamp
self.fp_internal_scale_min = 1. - self.exponent_bias - self.mantissa_bit_width
self.max_value = max_float(
self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
self.exponent_bit_width,
MAX_MANTISSA_DICT[self.mantissa_bit_width.item()],
self.exponent_bias)
self.min_value = torch.tensor(0.) if not module.is_signed else -self.max_value

def quantize(self, x):
Expand Down
10 changes: 1 addition & 9 deletions src/brevitas/function/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,16 +189,8 @@ def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor:
return value


@brevitas.jit.ignore
def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor):
def max_float(exponent_bit_width: Tensor, max_mantissa: Tensor, exponent_bias: Tensor):
max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias
max_mantissa = torch.sum((
2. ** torch.arange(
0,
-1. * mantissa_bit_width - 1.,
-1.,
dtype=mantissa_bit_width.dtype,
device=mantissa_bit_width.device)))
max_val = max_mantissa * (2 ** max_exponent)
return max_val

Expand Down
11 changes: 6 additions & 5 deletions src/brevitas/graph/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,9 @@ def disable_act_quantization(self, model, is_training):
if isinstance(module, ActQuantProxyFromInjectorBase):
module.train(is_training)
if self.call_act_quantizer_impl:
hook = module.register_forward_hook(self.disable_act_quant_hook)
self.disable_act_quant_hooks.append(hook)
for m in module.modules():
if hasattr(m, 'observer_only'):
m.observer_only = True
else:
module.disable_quant = True
elif isinstance(module, _ACC_PROXIES):
Expand All @@ -229,9 +230,9 @@ def enable_act_quantization(self, model, is_training):
elif isinstance(module, ActQuantProxyFromInjectorBase):
module.disable_quant = False
module.train(is_training)
for hook in self.disable_act_quant_hooks:
hook.remove()
self.disable_act_quant_hooks = []
for m in module.modules():
if hasattr(m, 'observer_only'):
m.observer_only = False

def enable_param_quantization(self, model, is_training):
for module in model.modules():
Expand Down
8 changes: 8 additions & 0 deletions src/brevitas/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,11 @@ def padding(x: torch.Tensor, group_size: int, group_dim: int) -> List[int]:
padding[2 * group_dim] = group_size - size[group_dim] % group_size
padding = list(reversed(padding))
return padding


def max_mantissa_func(val):
import torch
return torch.sum((2. ** torch.arange(0, -1. * val - 1., -1.)))


MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)}
5 changes: 3 additions & 2 deletions tests/brevitas/core/test_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight
from brevitas.utils.float_quant_utils import get_max_available_float
from brevitas.utils.float_quant_utils import get_min_available_float
from brevitas.utils.torch_utils import MAX_MANTISSA_DICT
from tests.brevitas.hyp_helper import float_tensor_random_shape_st

from .minifloat_fixtures import *
Expand Down Expand Up @@ -51,7 +52,7 @@
def test_max_value(minifloat, expected_max_val):
max_val = max_float(
torch.tensor(minifloat.exponent_bit_width, dtype=torch.float32),
torch.tensor(minifloat.mantissa_bit_width, dtype=torch.float32),
MAX_MANTISSA_DICT[minifloat.mantissa_bit_width],
torch.tensor(minifloat.exponent_bias, dtype=torch.float32))
max_available_float = get_max_available_float(
minifloat.exponent_bit_width,
Expand Down Expand Up @@ -84,7 +85,7 @@ def test_float_clamp(inp, fp8_clamp):

max_val = max_float(
torch.tensor(fp8_clamp.exponent_bit_width, dtype=torch.float32),
torch.tensor(fp8_clamp.mantissa_bit_width, dtype=torch.float32),
MAX_MANTISSA_DICT[fp8_clamp.mantissa_bit_width],
torch.tensor(fp8_clamp.exponent_bias, dtype=torch.float32))
max_available_float = get_max_available_float(
fp8_clamp.exponent_bit_width,
Expand Down
10 changes: 6 additions & 4 deletions tests/brevitas/core/test_float_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from brevitas.core.scaling import FloatScaling
from brevitas.function.ops import max_float
from brevitas.utils.torch_utils import float_internal_scale
from brevitas.utils.torch_utils import MAX_MANTISSA_DICT
from tests.brevitas.hyp_helper import float_st
from tests.brevitas.hyp_helper import float_tensor_random_shape_st
from tests.brevitas.hyp_helper import random_minifloat_format
Expand Down Expand Up @@ -98,8 +99,8 @@ def test_float_to_quant_float(inp, minifloat_format):
signed=signed,
float_clamp_impl=float_clamp)
expected_out, *_ = float_quant(inp)

out_quant, scale = float_quant.quantize(inp)
scale = float_quant.scaling_impl(inp)
out_quant, scale = float_quant.quantize(inp, scale)
exponent_bit_width, mantissa_bit_width, exponent_bias = torch.tensor(exponent_bit_width, dtype=torch.float), torch.tensor(mantissa_bit_width, dtype=torch.float), torch.tensor(exponent_bias, dtype=torch.float)
out_quant, *_ = float_quant.float_clamp_impl(
out_quant, exponent_bit_width, mantissa_bit_width, exponent_bias)
Expand Down Expand Up @@ -142,7 +143,8 @@ def test_scaling_impls_called_once(inp, minifloat_format):
scaling_impl=scaling_impl,
float_scaling_impl=float_scaling_impl,
float_clamp_impl=float_clamp)
_ = float_quant.quantize(inp)
scale = float_quant.scaling_impl(inp)
_ = float_quant.quantize(inp, scale)
# scaling implementations should be called exaclty once on the input
float_scaling_impl.assert_called_once_with(
torch.tensor(exponent_bit_width),
Expand Down Expand Up @@ -196,7 +198,7 @@ def test_inner_scale(inp, minifloat_format, scale):
scaled_inp = inp / scale
max_val = max_float(
torch.tensor(exponent_bit_width),
torch.tensor(mantissa_bit_width),
MAX_MANTISSA_DICT[mantissa_bit_width],
torch.tensor(exponent_bias))
max_available_float = float_clamp.max_available_float
max_value = max_val if max_available_float is None else torch.min(
Expand Down
Loading
Loading