Feat (gptq): support for groupwise conv (#690)
Giuseppe5 authored Jul 30, 2023
1 parent ae65a79 commit 3c5ca54
Showing 1 changed file with 38 additions and 80 deletions.
src/brevitas/graph/gptq.py (118 changes: 38 additions & 80 deletions)
@@ -85,10 +85,7 @@ def __init__(

     def _is_module_supported(self, module):
         if isinstance(module, SUPPORTED_CONV_OP):
-            if (module.groups == 1 or (module.groups == module.out_channels)):
-                return True
-            else:
-                return False
+            return True
         elif isinstance(module, qnn.QuantLinear):
             return True
         else:
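
For context on the groupwise handling introduced below (a minimal sketch, not part of the diff): once the groups == 1 / depthwise restriction is dropped, the weight of any grouped convolution can be processed one group at a time by viewing it as [groups, OC/groups, columns], which is the reshape the updated single_layer_update performs. An illustration with a made-up layer, using only standard PyTorch:

```python
import torch
import torch.nn as nn

# Illustrative grouped convolution: 8 output channels, 4 input channels, 2 groups
conv = nn.Conv2d(4, 8, kernel_size=3, groups=2)

# Conv weights are stored as [OC, IC/groups, kH, kW]
weight = conv.weight.detach().flatten(1)  # [OC, (IC/groups) * kH * kW]
weight = weight.view(conv.groups, -1, weight.shape[-1])  # [groups, OC/groups, columns]

print(weight.shape)  # torch.Size([2, 4, 18])
```
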
@@ -324,38 +321,21 @@ def single_layer_update(self, percdamp=.01):
         # thus len(permutation_list) is always equal to self.groups.
         # We do not explicity permute the weight matrix, only the Hessian.
         permutation_list = []
-        if self.groups > 1:
-            # For groupwise convolution, these operations are groupwise so we iterate
-            for i in range(self.groups):
-                # If a diagonal element on the Hessian is zero, we can set to 0 the corresponding
-                # column in the weight matrix.
-                # The diagonal element is set to 1 to avoid division-by-zero
-                dead = torch.diag(self.H[i, :, :]) == 0
-                self.H[i, dead, dead] = 1
-                # If the diagonal of activations is zero, we set the weight to zero
-                weight[i, dead] = 0
-                if self.act_order:
-                    # Re-order Hessian so that weights associated to
-                    # higher magnitude activations are quantized first
-                    perm = torch.argsort(torch.diag(self.H[i, :, :]), descending=True)
-                    self.H[i, :, :] = self.H[i, perm, :][:, perm]
-                else:
-                    # No permutation, permutation tensor is a ordered index
-                    perm = torch.tensor(range(self.H.shape[-1]), device=dev)
-                permutation_list.append(perm)
-        else:
+        weight = weight.view(self.groups, -1, weight.shape[-1])
+        # For groupwise convolution, these operations are groupwise so we iterate
+        for i in range(self.groups):
             # If a diagonal element on the Hessian is zero, we can set to 0 the corresponding
             # column in the weight matrix.
             # The diagonal element is set to 1 to avoid division-by-zero
-            dead = torch.diag(self.H[0, :, :]) == 0
-            self.H[0, dead, dead] = 1
+            dead = torch.diag(self.H[i, :, :]) == 0
+            self.H[i, dead, dead] = 1
             # If the diagonal of activations is zero, we set the weight to zero
-            weight[:, dead] = 0
+            weight[i, :, dead] = 0
             if self.act_order:
                 # Re-order Hessian so that weights associated to
                 # higher magnitude activations are quantized first
-                perm = torch.argsort(torch.diag(self.H[0, :, :]), descending=True)
-                self.H = self.H[:, perm, :][:, :, perm]
+                perm = torch.argsort(torch.diag(self.H[i, :, :]), descending=True)
+                self.H[i, :, :] = self.H[i, perm, :][:, perm]
             else:
                 # No permutation, permutation tensor is a ordered index
                 perm = torch.tensor(range(self.H.shape[-1]), device=dev)
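
A simplified, self-contained sketch of the per-group Hessian preprocessing that the new loop above performs (shapes and names are illustrative, not taken from the repository): dead diagonal entries are neutralized and, when act_order is enabled, each group's Hessian is permuted so that columns with larger diagonal entries are quantized first.

```python
import torch

groups, columns = 2, 6
H = torch.rand(groups, columns, columns)  # one Hessian approximation per group
W = torch.rand(groups, 4, columns)        # weight viewed as [groups, OC/groups, columns]
act_order = True

permutation_list = []
for g in range(groups):
    # A zero diagonal entry means the column never saw activation energy:
    # zero the corresponding weights and set the diagonal to 1 to avoid division by zero
    dead = torch.diag(H[g]) == 0
    H[g, dead, dead] = 1
    W[g, :, dead] = 0
    if act_order:
        # Quantize columns associated with higher-magnitude activations first
        perm = torch.argsort(torch.diag(H[g]), descending=True)
        H[g] = H[g][perm, :][:, perm]
    else:
        perm = torch.arange(columns)
    permutation_list.append(perm)  # one permutation per group, as in the diff
```
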
@@ -383,64 +363,41 @@ def single_layer_update(self, percdamp=.01):
         for i1 in range(0, self.columns, self.blocksize):
             i2 = min(i1 + self.blocksize, self.columns)
             count = i2 - i1
+            error_block = torch.zeros_like(
+                weight[:, :, perm[i1:i2]], dtype=torch.float32)  # [groups, OC/groups, i2-i1]

-            # len(permutation_list) == self.groups
-            if self.groups == 1:
-                perm = permutation_list[0]
-                weight_block = weight[:, perm[i1:i2]].to(torch.float32)  # This creates a copy
-            else:
-                # For groups > 1, we permute each row independently
-                weight_block = torch.empty(
-                    weight.shape[0], count, device=dev, dtype=torch.float32)  # [OC, i2-i1]
-                for ii, perm in enumerate(permutation_list):
-                    weight_block[ii, :] = weight[ii, perm[i1:i2]].to(
-                        torch.float32)  # This creates a copy
-
-            error_block = torch.zeros_like(weight_block)  # [OC, i2-i1]
             h_inv_block = h_inv[:, i1:i2, i1:i2]
             for i in range(count):
-                w = weight_block[:, i]  # [OC]
-                d = h_inv_block[:, i, i]  # [groups]
-                q = self.get_quant_weights(i, i1, i2, permutation_list)  # [OC]
-
-                error = (w - q) / d  # [OC]
-                if self.groups > 1:
-                    # In case of depthwise convs, each weight matrix interacts with only
-                    # part of the input values, thus with only one of the hessian matrix
-                    for ii, perm in enumerate(permutation_list):
-                        weight_block[ii, i:] -= error[ii] * h_inv_block[ii, i, i:]
-                        # We need to update the original weights
-                        weight[ii, perm[i1:i2][i:]] = weight_block[ii, i:].to(dtype)
-                else:
-                    perm = permutation_list[0]
-                    weight_block[:, i:] -= error.unsqueeze(1).matmul(
-                        h_inv_block[0, i, i:].unsqueeze(0))
+                q_groups = self.get_quant_weights(i, i1, permutation_list)  # [groups, OC/groups]
+                for group_index in range(self.groups):
+                    perm = permutation_list[group_index]
+                    q = q_groups[group_index]  # [OC/groups]
+                    w = weight[group_index, :, perm[i1:i2][i]].to(torch.float32)  # [OC/groups]
+                    d = h_inv[group_index, i, i]  # [1]
+                    error = (w - q) / d  # [OC/groups]
+                    error_block[group_index, :, i] = error
                     # We need to update the original weights
-                    weight[:, perm[i1:i2][i:]] = weight_block[:, i:].to(dtype)
-                error_block[:, i] = error
-
-            if self.groups > 1:
-                # In case of depthwise convs, each weight matrix interacts with only
-                # part of the input values, thus with only one of the hessian matrix
-                for ii, perm in enumerate(permutation_list):
-                    weight[ii:ii + 1,
-                           perm[i2:]] -= (error_block[ii:ii + 1, :].matmul(h_inv[ii, i1:i2,
-                                                                                 i2:])).to(dtype)
-            else:
-                perm = permutation_list[0]
-                weight[:, perm[i2:]] -= (error_block.matmul(h_inv[0, i1:i2, i2:])).to(dtype)
+                    weight[group_index, :, perm[i1:i2][i:]] -= (
+                        error.unsqueeze(1).matmul(h_inv_block[group_index, i,
+                                                              i:].unsqueeze(0))).to(dtype)

-    def get_quant_weights(self, i, i1, i2, permutation_list):
-        # We need to recompute quant weights at runtime since our float weights are being updated
+            for group_index in range(self.groups):
+                perm = permutation_list[group_index]
+                weight[group_index, :, perm[i2:]] -= (
+                    error_block[group_index].matmul(h_inv[group_index, i1:i2, i2:])).to(dtype)
+
+    def get_quant_weights(self, i, i1, permutation_list):
+        # We need to recompute quant weights at runtime since our float weights are being updated
+        # Add offset in case of blockwise computation (e.g., GPTQ)
+        i = i1 + i
         # For QuantLinear and for some QuantConvolutional layers, we exploit the possibility
         # of quantizing only a subset of the entire matrix speeding up the computation of GPTQ
         if isinstance(self.layer, qnn.QuantLinear):
-            index = permutation_list[0][i1:i2][i]
+            index = permutation_list[0][i]
             subtensor_slice_list = [None, (index, index + 1)]
             q = self.layer.quant_weight(
                 subtensor_slice_list=subtensor_slice_list,
-                quant_input=self.quant_input).value  # [OC, 1]
+                quant_input=self.quant_input).value.unsqueeze(0)  # [1, OC, 1]
         elif isinstance(self.layer, SUPPORTED_CONV_OP):
             # For depthwise and ConvTranspose we fall back to quantizing the entire martix.
             # For all other cases, we create a mask that represent the slicing we will perform on the weight matrix
@@ -454,15 +411,15 @@ def get_quant_weights(self, i, i1, i2, permutation_list):
                 if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
                     quant_weight = quant_weight.transpose(1, 0)  # This performs a view
                 quant_weight = quant_weight.flatten(1)
+                quant_weight = quant_weight.view(self.groups, -1, quant_weight.shape[-1])

                 if self.act_order:
                     for ii, perm in enumerate(permutation_list):
-                        quant_weight[ii, :] = quant_weight[ii, perm]
+                        quant_weight[ii, :, :] = quant_weight[ii, :, perm]

-                quant_weight_block = quant_weight[:, i1:i2]
-                q = quant_weight_block[:, i:i + 1]  # [OC, 1]
+                q = quant_weight[:, :, i:i + 1]  # [groups, OC/groups, 1]
             else:
-                index = permutation_list[0][i1:i2][i]
+                index = permutation_list[0][i]
                 shapes = self.layer.weight.shape[1:]
                 index_2d_to_nd = []
                 residual_index = index.item()
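
The lines just above build index_2d_to_nd from the flat column index; the folded code essentially unravels that flat index into one index per trailing weight dimension so that only the selected slice has to be quantized. A generic sketch of that unraveling (not the repository's exact implementation):

```python
def unravel(index, shapes):
    # Convert a flat column index over a flattened weight into one index per
    # trailing dimension, e.g. (IC/groups, kH, kW) for a convolution weight
    coords = []
    for s in reversed(list(shapes)):
        coords.append(index % s)
        index //= s
    return list(reversed(coords))

# A weight with trailing dims (2, 3, 3): flat column 7 maps to [0, 2, 1]
print(unravel(7, (2, 3, 3)))
```
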
@@ -474,6 +431,7 @@ def get_quant_weights(self, i, i1, i2, permutation_list):
                 q = self.layer.quant_weight(
                     subtensor_slice_list=index_2d_to_nd,
                     quant_input=self.quant_input).value.flatten(1)  # [OC, 1]
+                q = q.unsqueeze(0)  # [1, OC, 1]
         # We need to remove the last dim
-        q = q.squeeze(1)  # [OC]
+        q = q.squeeze(2)  # [groups, OC/groups] or [1, OC]
         return q
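
For readers less familiar with GPTQ itself, the loop rewritten above applies, per group, the usual block update: quantize one column at a time, divide the quantization error by the inverse-Hessian diagonal, and propagate it to the columns that have not been quantized yet. A simplified, self-contained sketch under illustrative names (not the Brevitas API; Brevitas recomputes the quantized columns through the layer's quantizer rather than a callable):

```python
import torch

def gptq_block_update(w, h_inv, quantize, blocksize=128):
    """Per-group GPTQ update.

    w:        [rows, columns] float weights of a single group
    h_inv:    [columns, columns] inverse Hessian for that group
    quantize: callable mapping a column [rows] to its quantized values
    """
    columns = w.shape[1]
    for i1 in range(0, columns, blocksize):
        i2 = min(i1 + blocksize, columns)
        error_block = torch.zeros(w.shape[0], i2 - i1)
        for i in range(i1, i2):
            q = quantize(w[:, i])
            error = (w[:, i] - q) / h_inv[i, i]
            error_block[:, i - i1] = error
            # Propagate the error of column i to the rest of the current block
            w[:, i:i2] -= error.unsqueeze(1) * h_inv[i, i:i2].unsqueeze(0)
        # Propagate the accumulated block error to the columns after the block
        w[:, i2:] -= error_block.matmul(h_inv[i1:i2, i2:])
    return w

# Example: rounding stub on a random [8, 32] weight group with identity inverse Hessian
w = torch.randn(8, 32)
h_inv = torch.eye(32)
gptq_block_update(w, h_inv, quantize=lambda col: torch.round(col * 4) / 4, blocksize=8)
```

In the diff above this logic runs once per group, with permutation_list[group] reordering the columns when act_order is set.
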
