Skip to content

Commit

Permalink
add compile test (things are crashing)
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Aug 10, 2024
1 parent ff69121 commit 45342ba
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
27 changes: 26 additions & 1 deletion test/prototype/test_quantized_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def test_int8_stochastic_rounding(self, device):
# due to the statistical nature, this assertion may still fail, though very rarely.
torch.testing.assert_close(x_dequant_mean, x, atol=1e-4, rtol=1e-4)

@parametrize("device", _DEVICES)
@parametrize("leading_dims", [(), (2,), (2, 4)])
@parametrize("bias", [False, True])
@parametrize("device", _DEVICES)
def test_int8_linear_forward(self, leading_dims, bias, device):
embed_dim = 32

Expand Down Expand Up @@ -72,6 +72,31 @@ def test_int8_linear_backward(self, device):
for p_fp32, p_int8 in zip(model_fp32.parameters(), model_int8.parameters()):
torch.testing.assert_close(p_fp32.grad, p_int8.grad, atol=1e-3, rtol=1e-2)

@parametrize("bias", [False, True])
@parametrize("device", _DEVICES)
def test_int8_linear_compile(self, bias, device):
    """Eager vs. torch.compile parity for an INT8-quantized linear layer.

    Builds one quantized nn.Linear, deep-copies it, compiles the copy, and
    checks that forward outputs and parameter gradients agree within tolerance.
    """
    batch_size, in_features, num_classes = 4, 32, 10

    # Quantize first, then deep-copy, so both modules start from identical
    # int8 weight state; only the copy goes through torch.compile.
    eager_linear = nn.Linear(in_features, num_classes, bias=bias, device=device)
    quantize_(eager_linear, int8_weight_only_quantized_training())
    compiled_linear = copy.deepcopy(eager_linear)
    compiled_linear.compile()

    x = torch.randn((batch_size, in_features,), device=device)
    targets = torch.randint(num_classes, size=(batch_size,), device=device)

    # Forward parity.
    eager_out = eager_linear(x)
    compiled_out = compiled_linear(x)
    torch.testing.assert_close(eager_out, compiled_out, atol=1e-2, rtol=1e-2)

    # Backward through both graphs, then compare gradients parameter-by-parameter.
    F.cross_entropy(eager_out, targets).backward()
    F.cross_entropy(compiled_out, targets).backward()

    for eager_p, compiled_p in zip(eager_linear.parameters(), compiled_linear.parameters()):
        torch.testing.assert_close(eager_p.grad, compiled_p.grad)

@parametrize("device", _DEVICES)
def test_int8_linear_training(self, device):
bsize = 4
Expand Down
2 changes: 1 addition & 1 deletion torchao/prototype/quantized_training/subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def backward(ctx, grad_output):
input, weight = ctx.saved_tensors

dinput = (grad_output * weight.scale) @ weight.int_data.to(grad_output.dtype)
dweight = grad_output.flatten(0, -2).T @ input.flatten(0, -2)
dweight = grad_output.view(-1, weight.shape[0]).T @ input.view(-1, weight.shape[1])
dbias = grad_output.sum(0) if ctx.bias else None
return dinput, dweight, dbias

Expand Down

0 comments on commit 45342ba

Please sign in to comment.