diff --git a/test/kernel/test_autotuner.py b/test/kernel/test_autotuner.py
index 33678995a..82fb11736 100644
--- a/test/kernel/test_autotuner.py
+++ b/test/kernel/test_autotuner.py
@@ -52,13 +52,15 @@ def test_int_mm(self, device, dtype):
     @parameterized.expand(
         [
            ("cuda", torch.bfloat16),
-            # TODO: ("cpu", torch.bfloat16),
+            ("cpu", torch.bfloat16),
            ("cuda", torch.float16),
-            # TODO: ("cpu", torch.float16),
+            ("cpu", torch.float16),
        ]
    )
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_int_scaled_mm(self, device, dtype):
+        if device == "cuda" and not torch.cuda.is_available():
+            self.skipTest(f"{device} not available")
+
         from torchao.kernel import intmm
 
         dtype = torch.bfloat16
diff --git a/torchao/kernel/intmm_triton.py b/torchao/kernel/intmm_triton.py
index 4e84d9cd3..d10dac0ab 100644
--- a/torchao/kernel/intmm_triton.py
+++ b/torchao/kernel/intmm_triton.py
@@ -356,3 +356,9 @@ def int_scaled_matmul_cuda(a, b, scales1):
         int_scaled_matmul_kernel, [a, b, scales1, c], int8_mm_kernel_configs
     )
     return int_scaled_matmul_kernel(a, b, scales1, c, best_config)
+
+
+@torch.library.impl(lib, "int_scaled_matmul", "CPU")
+def int_scaled_matmul_cpu(a, b, scales1):
+    c = torch._int_mm(a, b)
+    return c.to(scales1.dtype) * scales1
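
A minimal sketch (not part of the diff above) of what the new CPU fallback computes: an int8 matmul via torch._int_mm, followed by a cast to the scale dtype and an elementwise multiply by the per-row scales. The shapes and the row-broadcast scale layout are assumptions based on how the CUDA path is called; torch._int_mm on CPU also assumes a sufficiently recent PyTorch build.

    import torch

    # (M, K) and (K, N) int8 operands; dimensions chosen to satisfy
    # the known torch._int_mm size constraints on CUDA as well.
    a = torch.randint(-128, 127, (32, 64), dtype=torch.int8)
    b = torch.randint(-128, 127, (64, 16), dtype=torch.int8)
    scales1 = torch.rand(32, 1, dtype=torch.bfloat16)  # assumed per-row scales

    c = torch._int_mm(a, b)                 # int32 accumulator, as in the CPU impl
    out = c.to(scales1.dtype) * scales1     # cast, then apply scales
    print(out.shape, out.dtype)             # torch.Size([32, 16]) torch.bfloat16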