g_idx to fakequantize
horheynm committed Jun 26, 2024
1 parent 2525f69 commit c6b5b28
Showing 1 changed file with 9 additions and 7 deletions.
src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py (16 changes: 9 additions & 7 deletions)
@@ -192,10 +192,10 @@ def fasterprune(
requires_grad=False,
)
else:
g_idx = torch.Tensor(
g_idx = torch.tensor(
[j // group_size for j in range(self.columns)],

device=W.device,
dtype=torch.int32,
device=W.device
)

from compressed_tensors.quantization import QuantizationStrategy
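
For context, `g_idx` maps each weight column to its quantization group; the switch from the legacy `torch.Tensor` constructor to the `torch.tensor` factory is what makes the explicit `dtype=torch.int32` possible. A minimal sketch of the resulting tensor, with assumed toy values for `self.columns` and `group_size` (not taken from this file):

```python
import torch

# Assumed toy values; the real ones come from the layer shape and the
# quantization config in gptq_wrapper.py.
columns, group_size = 8, 4

# One group index per weight column: columns 0-3 -> group 0, 4-7 -> group 1.
g_idx = torch.tensor(
    [j // group_size for j in range(columns)],
    dtype=torch.int32,
)
print(g_idx)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
```
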
@@ -204,14 +204,12 @@ def fasterprune(
)

strategy = quant_scheme.weights.strategy
breakpoint()
if strategy == QuantizationStrategy.TENSOR:
q = fake_quantize(
q,
scale,
zero_point,
self.layer.quantization_scheme.weights,
g_idx,
)
elif strategy == QuantizationStrategy.CHANNEL:
# TODO: for channelwise why isn't this just a 1d tensor?
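
Conceptually, passing `g_idx` into `fake_quantize` lets group-wise scales and zero points be expanded to per-column values by index lookup, as the commented-out `scale[g_idx]` lines below also hint. A hypothetical sketch of that expansion (the function name, signature, and shapes here are assumptions, not the actual `compressed_tensors` API):

```python
import torch

def fake_quantize_grouped(w, scale, zero_point, g_idx, qmin=-8, qmax=7):
    # scale / zero_point hold one column per group: shape [rows, n_groups].
    # Indexing with g_idx (shape [cols]) expands them to one value per
    # weight column: shape [rows, cols].
    col_scale = scale[:, g_idx.long()]
    col_zp = zero_point[:, g_idx.long()]
    q = torch.clamp(torch.round(w / col_scale + col_zp), qmin, qmax)
    return (q - col_zp) * col_scale  # quantize-dequantize round trip

# Toy usage with the g_idx layout from the sketch above.
w = torch.randn(2, 8)
g_idx = torch.tensor([j // 4 for j in range(8)], dtype=torch.int32)
scale = torch.full((2, 2), 0.1)
zero_point = torch.zeros(2, 2)
w_fq = fake_quantize_grouped(w, scale, zero_point, g_idx)
```
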
@@ -228,16 +226,20 @@
input_dim_group = (
column_idx // quant_scheme.weights.group_size
)

# Since we're only applying quantization to a slice, this
# ends up being a channelwise application
altered_qargs = copy(quant_scheme.weights)
altered_qargs.strategy = QuantizationStrategy.CHANNEL

# # apply g_idx
# if g_idx is not None:
# scale = scale[g_idx]
# zero_point = zero_point[g_idx]

q = fake_quantize(
q,
scale[:, input_dim_group],
zero_point[:, input_dim_group],
# g_idx,
altered_qargs,
)
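
The group-quantization path above quantizes `W` one column slice at a time, so the group-wise scale collapses to a single value per output channel for that slice; that is why the qargs are copied and switched to `QuantizationStrategy.CHANNEL`. A toy illustration of the index arithmetic (shapes assumed, not taken from the file):

```python
import torch

rows, cols, group_size = 2, 8, 4
w = torch.randn(rows, cols)
scale = torch.full((rows, cols // group_size), 0.1)  # [rows, n_groups]
zero_point = torch.zeros_like(scale)

column_idx = 5
input_dim_group = column_idx // group_size  # column 5 falls in group 1

# One column of W against one column of scale/zero_point: effectively a
# channelwise (per-output-row) quantization of the slice.
col = w[:, column_idx]
q = torch.round(col / scale[:, input_dim_group] + zero_point[:, input_dim_group])
deq = (q - zero_point[:, input_dim_group]) * scale[:, input_dim_group]
```
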
