diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index cf76db256..cbd5f7683 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -34,9 +34,9 @@ def test_int8_stochastic_rounding(self, device): # due to the statistical nature, this assertion may still fail, though very rarely. torch.testing.assert_close(x_dequant_mean, x, atol=1e-4, rtol=1e-4) - @parametrize("device", _DEVICES) @parametrize("leading_dims", [(), (2,), (2, 4)]) @parametrize("bias", [False, True]) + @parametrize("device", _DEVICES) def test_int8_linear_forward(self, leading_dims, bias, device): embed_dim = 32 @@ -72,6 +72,31 @@ def test_int8_linear_backward(self, device): for p_fp32, p_int8 in zip(model_fp32.parameters(), model_int8.parameters()): torch.testing.assert_close(p_fp32.grad, p_int8.grad, atol=1e-3, rtol=1e-2) + @parametrize("bias", [False, True]) + @parametrize("device", _DEVICES) + def test_int8_linear_compile(self, bias, device): + bsize = 4 + embed_dim = 32 + n_classes = 10 + + linear = nn.Linear(embed_dim, n_classes, bias=bias, device=device) + quantize_(linear, int8_weight_only_quantized_training()) + linear_compiled = copy.deepcopy(linear) + linear_compiled.compile() + + inputs = torch.randn((bsize, embed_dim,), device=device) + labels = torch.randint(n_classes, size=(bsize,), device=device) + + out = linear(inputs) + out_compiled = linear_compiled(inputs) + torch.testing.assert_close(out, out_compiled, atol=1e-2, rtol=1e-2) + + F.cross_entropy(out, labels).backward() + F.cross_entropy(out_compiled, labels).backward() + + for p, p_compiled in zip(linear.parameters(), linear_compiled.parameters()): + torch.testing.assert_close(p.grad, p_compiled.grad) + @parametrize("device", _DEVICES) def test_int8_linear_training(self, device): bsize = 4 diff --git a/torchao/prototype/quantized_training/subclass.py b/torchao/prototype/quantized_training/subclass.py index 06d1f3bc5..264fc78aa 100644 --- a/torchao/prototype/quantized_training/subclass.py +++ b/torchao/prototype/quantized_training/subclass.py @@ -113,7 +113,7 @@ def backward(ctx, grad_output): input, weight = ctx.saved_tensors dinput = (grad_output * weight.scale) @ weight.int_data.to(grad_output.dtype) - dweight = grad_output.flatten(0, -2).T @ input.flatten(0, -2) + dweight = grad_output.view(-1, weight.shape[0]).T @ input.view(-1, weight.shape[1]) dbias = grad_output.sum(0) if ctx.bias else None return dinput, dweight, dbias