diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 4dd17c24e..20f324b35 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -795,17 +795,15 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         w_vals_int8_t = weight_qtensor.layout_tensor.int_data.t()
         scale = weight_qtensor.layout_tensor.scale
         orig_dtype = input_tensor.dtype
-        y = (
-            torch.mm(
+        m = torch.mm(
                 input_tensor.reshape(-1, input_tensor.shape[-1]),
                 w_vals_int8_t.to(input_tensor.dtype),
             )
-            * scale
-        )
+        y = m * scale.to(m.dtype)
         y = y.reshape(*input_tensor.shape[:-1], y.shape[-1])
         if bias is not None:
-            y += bias
-        return y.to(orig_dtype)
+            y += bias.to(m.dtype)
+        return y

         # is_cpu and is_mps only, some issue with is_contiguous() currently
         # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_qtensor.layout_tensor.scale)
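
The effect of the change is that the int8 weight-only matmul path stays in the activation's dtype end to end: the per-channel scale and the bias are cast to the matmul output's dtype instead of the multiply being performed in the scale's (typically float32) dtype and the result cast back to orig_dtype at the end. Below is a minimal standalone sketch of that dtype flow; the function name and tensor shapes are illustrative, not the torchao API.

```python
import torch

# Sketch of the dtype handling after this change (illustrative names):
# the int8 weight is cast to the activation dtype before the matmul, and
# the per-channel scale and bias are cast to the matmul output dtype, so
# a bfloat16 activation is not upcast by a float32 scale or bias.
def int8_weight_only_linear_sketch(x, w_int8, scale, bias=None):
    w_t = w_int8.t().to(x.dtype)                   # int8 weight -> activation dtype
    m = torch.mm(x.reshape(-1, x.shape[-1]), w_t)  # matmul in activation dtype
    y = m * scale.to(m.dtype)                      # scale cast to matmul dtype
    y = y.reshape(*x.shape[:-1], y.shape[-1])
    if bias is not None:
        y += bias.to(m.dtype)                      # bias cast as well
    return y                                       # already in activation dtype

# Example: bfloat16 activation with float32 scale/bias keeps a bfloat16 output.
x = torch.randn(2, 4, dtype=torch.bfloat16)
w = torch.randint(-128, 127, (3, 4), dtype=torch.int8)
scale = torch.rand(3, dtype=torch.float32)
bias = torch.zeros(3, dtype=torch.float32)
print(int8_weight_only_linear_sketch(x, w, scale, bias).dtype)  # torch.bfloat16
```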