Support int_scaled_mm on CPU (pytorch#121)
Xia-Weiwen authored Apr 5, 2024
1 parent 17c2024 · commit d4b112f
Showing 2 changed files with 11 additions and 3 deletions.
8 changes: 5 additions & 3 deletions test/kernel/test_autotuner.py
@@ -52,13 +52,15 @@ def test_int_mm(self, device, dtype):
     @parameterized.expand(
         [
             ("cuda", torch.bfloat16),
-            # TODO: ("cpu", torch.bfloat16),
+            ("cpu", torch.bfloat16),
             ("cuda", torch.float16),
-            # TODO: ("cpu", torch.float16),
+            ("cpu", torch.float16),
         ]
     )
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_int_scaled_mm(self, device, dtype):
+        if device == "cuda" and not torch.cuda.is_available():
+            self.skipTest(f"{device} not available")
+
         from torchao.kernel import intmm
 
         dtype = torch.bfloat16
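The key change above: the blanket @unittest.skipIf decorator, which skipped every parameterization on machines without CUDA, is replaced by an in-test check so the newly enabled CPU cases always run. A minimal sketch of exercising the CPU path through the public wrapper, assuming intmm.int_scaled_matmul takes int8 operands with per-row scales of shape (M, 1); the shapes here are illustrative, not from the commit:

import torch
from torchao.kernel import intmm

# Illustrative int8 operands with bfloat16 per-row scales, all on CPU.
a = torch.randint(-128, 127, (64, 32), dtype=torch.int8)   # (M, K)
b = torch.randint(-128, 127, (32, 16), dtype=torch.int8)   # (K, N)
scales1 = torch.rand(64, 1, dtype=torch.bfloat16)          # (M, 1)

out = intmm.int_scaled_matmul(a, b, scales1)  # dispatches to the new CPU impl
print(out.shape, out.dtype)  # expected: torch.Size([64, 16]) torch.bfloat16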
6 changes: 6 additions & 0 deletions torchao/kernel/intmm_triton.py
@@ -356,3 +356,9 @@ def int_scaled_matmul_cuda(a, b, scales1):
         int_scaled_matmul_kernel, [a, b, scales1, c], int8_mm_kernel_configs
     )
     return int_scaled_matmul_kernel(a, b, scales1, c, best_config)
+
+
+@torch.library.impl(lib, "int_scaled_matmul", "CPU")
+def int_scaled_matmul_cpu(a, b, scales1):
+    c = torch._int_mm(a, b)
+    return c.to(scales1.dtype) * scales1
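For reference, a self-contained sketch of what the new CPU overload computes, assuming a PyTorch build whose private torch._int_mm has a CPU kernel (which this commit relies on): an exact int8 x int8 -> int32 matmul, then a rescale to the dtype of scales1. It checks the result against a plain float reference; the tensor shapes are illustrative.

import torch

torch.manual_seed(0)
a = torch.randint(-128, 127, (64, 32), dtype=torch.int8)
b = torch.randint(-128, 127, (32, 16), dtype=torch.int8)
scales1 = torch.rand(64, 1, dtype=torch.bfloat16)

c = torch._int_mm(a, b)                # exact int32 accumulator
out = c.to(scales1.dtype) * scales1    # same math as int_scaled_matmul_cpu above

# With K=32 and int8 inputs, every sum fits exactly in float32, so the
# float reference rounds to bfloat16 identically and the results match.
ref = (a.float() @ b.float()).to(scales1.dtype) * scales1
assert torch.equal(out, ref)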
