[inference] Add smoothquant for llama #4904

Merged
merged 10 commits on Oct 16, 2023
12 changes: 12 additions & 0 deletions colossalai/inference/quant/smoothquant/models/__init__.py
@@ -0,0 +1,12 @@
try:
    import torch_int

    HAS_TORCH_INT = True
except ImportError:
    HAS_TORCH_INT = False
    raise ImportError(
        "torch_int is not installed. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
    )

if HAS_TORCH_INT:
    from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
492 changes: 492 additions & 0 deletions colossalai/inference/quant/smoothquant/models/base_model.py

Large diffs are not rendered by default.

175 changes: 175 additions & 0 deletions colossalai/inference/quant/smoothquant/models/linear.py
@@ -0,0 +1,175 @@
import torch
from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
from torch_int.functional.quantization import quantize_per_tensor_absmax

try:
    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

    smoothquant_cuda = SmoothquantBuilder().load()
    HAS_SMOOTHQUANT_CUDA = True
except ImportError:
    HAS_SMOOTHQUANT_CUDA = False
    raise ImportError("The CUDA smoothquant linear kernel is not installed")


class W8A8BFP32O32LinearSiLU(torch.nn.Module):
    # int8 activations and weights, fp32 bias, fp32 output, with a fused SiLU epilogue
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer(
            "weight",
            torch.randint(
                -127,
                127,
                (self.out_features, self.in_features),
                dtype=torch.int8,
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "bias",
            torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False),
        )
        self.register_buffer("a", torch.tensor(alpha))

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        self.bias = self.bias.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        x_shape = x.shape
        x = x.view(-1, x_shape[-1])
        y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0)
        y = y.view(*x_shape[:-1], -1)
        return y

    @staticmethod
    def from_float(module: torch.nn.Linear, input_scale):
        int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features)
        int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
        alpha = input_scale * weight_scale
        int8_module.weight = int8_weight
        if module.bias is not None:
            int8_module.bias.data.copy_(module.bias.to(torch.float))
        int8_module.a = alpha
        return int8_module
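
For intuition, here is a rough float reference of what the fused linear_silu_a8_w8_bfp32_ofp32 call is expected to compute, based on the kernel name and on the scale that from_float stores in the buffer a. The helper below is an illustrative emulation, not the CUDA kernel's exact numerics, and its name is made up for this sketch.

import torch

def emulate_linear_silu_a8_w8_bfp32_ofp32(x_q, w_q, bias_fp32, alpha):
    # int8 x int8 GEMM accumulated in int32
    acc = x_q.to(torch.int32) @ w_q.to(torch.int32).t()
    # dequantize with alpha = input_scale * weight_scale, add the fp32 bias,
    # then apply SiLU as the fused epilogue
    y = alpha * acc.to(torch.float32) + bias_fp32
    return torch.nn.functional.silu(y)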


class W8A8B8O8Linear(torch.nn.Module):
    # For qkv_proj
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer(
            "weight",
            torch.randint(
                -127,
                127,
                (self.out_features, self.in_features),
                dtype=torch.int8,
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "bias",
            torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False),
        )
        self.register_buffer("a", torch.tensor(alpha))
        self.register_buffer("b", torch.tensor(beta))

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        self.bias = self.bias.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        x_shape = x.shape
        x = x.view(-1, x_shape[-1])
        y = linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item())
        y = y.view(*x_shape[:-1], -1)
        return y

    @staticmethod
    def from_float(module: torch.nn.Linear, input_scale, output_scale):
        int8_module = W8A8B8O8Linear(module.in_features, module.out_features)
        int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
        alpha = input_scale * weight_scale / output_scale
        int8_module.weight = int8_weight
        int8_module.a = alpha

        if module.bias is not None:
            int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias)
            int8_module.bias = int8_bias
            beta = bias_scale / output_scale
            int8_module.b = beta

        return int8_module
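
The scale folding in from_float above is the heart of the int8-in/int8-out path: the int32 GEMM accumulator is rescaled by alpha = input_scale * weight_scale / output_scale, and the quantized bias by beta = bias_scale / output_scale, so the result lands directly on the output quantization grid. Below is a minimal PyTorch sketch of that arithmetic, offered as an emulation for reasoning about the scales rather than the exact rounding behavior of linear_a8_w8_b8_o8; the helper name is illustrative.

import torch

def emulate_linear_a8_w8_b8_o8(x_q, w_q, b_q, alpha, beta):
    # int8 x int8 GEMM accumulated in int32
    acc = x_q.to(torch.int32) @ w_q.to(torch.int32).t()
    # rescale the accumulator and the quantized bias onto the output grid
    y = alpha * acc.to(torch.float32) + beta * b_q.to(torch.float32)
    # round and saturate back to int8
    return torch.clamp(torch.round(y), -128, 127).to(torch.int8)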


class W8A8BFP32OFP32Linear(torch.nn.Module):
    # For fc2 and out_proj
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer(
            "weight",
            torch.randint(
                -127,
                127,
                (self.out_features, self.in_features),
                dtype=torch.int8,
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "bias",
            torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
        )
        self.register_buffer("a", torch.tensor(alpha))

    def _apply(self, fn):
        # prevent the bias from being converted to half
        super()._apply(fn)
        self.bias = self.bias.to(torch.float32)
        return self

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        self.bias = self.bias.to(*args, **kwargs)
        self.bias = self.bias.to(torch.float32)
        return self

    @torch.no_grad()
    def forward(self, x):
        x_shape = x.shape
        x = x.view(-1, x_shape[-1])
        y = linear_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1)
        y = y.view(*x_shape[:-1], -1)
        return y

    @staticmethod
    def from_float(module: torch.nn.Linear, input_scale):
        int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features)
        int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
        alpha = input_scale * weight_scale
        int8_module.weight = int8_weight
        int8_module.a = alpha
        int8_module.input_scale = input_scale
        int8_module.weight_scale = weight_scale

        if module.bias is not None:
            int8_module.bias = module.bias.to(torch.float32)

        return int8_module
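
For context, a hypothetical conversion sketch: the per-tensor activation scales would come from a separate calibration pass, and the layer shapes and scale values below are placeholders rather than anything defined in this PR.

import torch

gate_proj = torch.nn.Linear(4096, 11008)
down_proj = torch.nn.Linear(11008, 4096)
q_proj = torch.nn.Linear(4096, 4096)

gate_int8 = W8A8BFP32O32LinearSiLU.from_float(gate_proj, input_scale=0.02)        # int8 in, fp32 out, fused SiLU
down_int8 = W8A8BFP32OFP32Linear.from_float(down_proj, input_scale=0.05)          # int8 in, fp32 out
q_int8 = W8A8B8O8Linear.from_float(q_proj, input_scale=0.02, output_scale=0.04)   # int8 in, int8 out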