diff --git a/test/dtypes/test_fp8.py b/test/prototype/test_splitk.py similarity index 92% rename from test/dtypes/test_fp8.py rename to test/prototype/test_splitk.py index ae008fc91..a37dce91b 100644 --- a/test/dtypes/test_fp8.py +++ b/test/prototype/test_splitk.py @@ -10,7 +10,7 @@ from torchao.utils import TORCH_VERSION_AFTER_2_4 try: - from torchao.prototype.fp8 import gemm_split_k, to_float8 + from torchao.prototype.splitk import gemm_split_k, to_float8 triton_available = True except ImportError: triton_available = False @@ -38,7 +38,7 @@ def test_gemm_split_k(self): x_fp8, x_inv_s = to_float8(x, dtype=qdtype) w_fp8, w_inv_s = to_float8(w, dtype=qdtype) - y_torch, _ = torch._scaled_mm(x_fp8, w_fp8.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s) + y_torch = torch._scaled_mm(x_fp8, w_fp8.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s) y_triton = gemm_split_k(x_fp8, w_fp8.t(), scale_a=x_inv_s.item(), scale_b=w_inv_s.item()) y_fp16 = torch.nn.functional.linear(x, w) diff --git a/torchao/prototype/fp8/__init__.py b/torchao/prototype/splitk/__init__.py similarity index 100% rename from torchao/prototype/fp8/__init__.py rename to torchao/prototype/splitk/__init__.py diff --git a/torchao/prototype/fp8/splitk_gemm.py b/torchao/prototype/splitk/splitk_gemm.py similarity index 100% rename from torchao/prototype/fp8/splitk_gemm.py rename to torchao/prototype/splitk/splitk_gemm.py