Commit

[int8 woq] make the scale type the same as input for bf16 autocast (p…
Valentine233 authored Jul 29, 2024
1 parent 4abe4b8 commit 77c99d1
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -795,17 +795,15 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         w_vals_int8_t = weight_qtensor.layout_tensor.int_data.t()
         scale = weight_qtensor.layout_tensor.scale
         orig_dtype = input_tensor.dtype
-        y = (
-            torch.mm(
+        m = torch.mm(
             input_tensor.reshape(-1, input_tensor.shape[-1]),
             w_vals_int8_t.to(input_tensor.dtype),
         )
-            * scale
-        )
+        y = m * scale.to(m.dtype)
         y = y.reshape(*input_tensor.shape[:-1], y.shape[-1])
         if bias is not None:
-            y += bias
-        return y.to(orig_dtype)
+            y += bias.to(m.dtype)
+        return y
 
         # is_cpu and is_mps only, some issue with is_contiguous() currently
         # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_qtensor.layout_tensor.scale)
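For context (not part of the commit itself): under bf16 autocast the activation is bfloat16 while the weight-only-quantization scale is commonly kept in float32, and multiplying a bfloat16 matmul result by a float32 scale promotes the output to float32. Casting the scale (and bias) to m.dtype keeps the result in the input dtype. A minimal sketch of that promotion behavior, using illustrative shapes and names that are not part of the torchao code:

    import torch

    # Illustrative tensors only (hypothetical shapes, not the torchao kernel).
    x = torch.randn(4, 8, dtype=torch.bfloat16)                    # bf16 activation, as under autocast
    w_int8 = torch.randint(-128, 128, (8, 16), dtype=torch.int8)   # int8 weight-only-quantized weight
    scale = torch.rand(16, dtype=torch.float32)                    # per-channel scale stored in fp32

    m = torch.mm(x, w_int8.to(x.dtype))    # bf16 @ bf16 -> bf16

    y_old = m * scale                      # bf16 * fp32 promotes the result to fp32
    y_new = m * scale.to(m.dtype)          # stays bf16, matching the input dtype

    print(y_old.dtype, y_new.dtype)        # torch.float32 torch.bfloat16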