rename splitk code to not mention float8, try 2
Differential Revision: D59977582

Pull Request resolved: pytorch#529
vkuzo authored Jul 19, 2024
1 parent 66b5213 commit 460e6db
Showing 3 changed files with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/dtypes/test_fp8.py → test/prototype/test_splitk.py
@@ -10,7 +10,7 @@
 from torchao.utils import TORCH_VERSION_AFTER_2_4
 
 try:
-    from torchao.prototype.fp8 import gemm_split_k, to_float8
+    from torchao.prototype.splitk import gemm_split_k, to_float8
     triton_available = True
 except ImportError:
     triton_available = False
@@ -38,7 +38,7 @@ def test_gemm_split_k(self):
         x_fp8, x_inv_s = to_float8(x, dtype=qdtype)
         w_fp8, w_inv_s = to_float8(w, dtype=qdtype)
 
-        y_torch, _ = torch._scaled_mm(x_fp8, w_fp8.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s)
+        y_torch = torch._scaled_mm(x_fp8, w_fp8.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s)
         y_triton = gemm_split_k(x_fp8, w_fp8.t(), scale_a=x_inv_s.item(), scale_b=w_inv_s.item())
         y_fp16 = torch.nn.functional.linear(x, w)
File renamed without changes.
File renamed without changes.
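
For context, a minimal sketch of how the renamed module is exercised, mirroring the test in the diff above. The import path and the `to_float8`/`gemm_split_k` calls come from the diff itself; the tensor shapes, device, and choice of float8 format here are illustrative assumptions, not taken from the test.

```python
# Minimal sketch mirroring test/prototype/test_splitk.py after the rename.
# Assumes a CUDA device with float8 support; the shapes and the e4m3
# format below are illustrative assumptions.
import torch
from torchao.prototype.splitk import gemm_split_k, to_float8

dtype = torch.float16
qdtype = torch.float8_e4m3fn  # assumed float8 format

x = torch.randn(64, 128, dtype=dtype, device="cuda")
w = torch.randn(256, 128, dtype=dtype, device="cuda")

# Quantize both operands to float8, keeping per-tensor inverse scales.
x_fp8, x_inv_s = to_float8(x, dtype=qdtype)
w_fp8, w_inv_s = to_float8(w, dtype=qdtype)

# Split-K Triton GEMM on the float8 operands, rescaled by the inverse scales.
y_triton = gemm_split_k(x_fp8, w_fp8.t(), scale_a=x_inv_s.item(), scale_b=w_inv_s.item())

# High-precision reference for comparison, as in the test.
y_fp16 = torch.nn.functional.linear(x, w)
```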
