From 1efcc81e93b6f6627cb841f0163e10946b976a93 Mon Sep 17 00:00:00 2001 From: ESI-SYD Date: Thu, 10 Oct 2024 05:24:48 +0000 Subject: [PATCH 1/2] Add xetla streamk gemm into benchmark --- .../gemm_streamk_benchmark.py | 19 +++++++++++++++++-- benchmarks/xetla_kernel/python_main.cpp | 5 ++++- .../stream_k_gemm/stream_k_gemm.h | 5 +---- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py index e8179cd45a..97b4af15f0 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py @@ -10,6 +10,7 @@ import triton.language as tl import triton_kernels_benchmark as benchmark_suit +import xetla_kernel if benchmark_suit.USE_IPEX_OPTION: import intel_extension_for_pytorch # type: ignore # noqa: F401 @@ -253,9 +254,9 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor): line_arg='provider', # argument name whose value corresponds to a different line in the plot # possible values for `line_arg`` - line_vals=['triton'], + line_vals=['triton', 'xetla'], # label name for the lines - line_names=['Triton'], + line_names=['Triton', 'XeTLA'], # line styles styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], ylabel=['GB/s', 'TFlops'], # label name for the y-axis @@ -280,6 +281,20 @@ def benchmark(M, N, K, provider): benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, kernel_name=['first_wave', 'full_tiles']) + elif provider == 'xetla': + c = torch.empty((M, N), device='xpu', dtype=torch.float32) + acc = torch.empty((M, N), device='xpu', dtype=torch.float32) + cnt = torch.empty((M, N), device='xpu', dtype=torch.int32) + + name = f'gemm_streamk_shape_{M}_{K}_{N}' + func = getattr(xetla_kernel, name) + xetla_fn = lambda: func(a, b, c, acc, cnt) + torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + + # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch') + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench( + xetla_fn, warmup=10, rep=10, quantiles=quantiles, + kernel_name='gpu::xetla::kernel::gemm_universal_t, "bf16_gemm (XeTLA)"); m.def("gemm_shape_4096_8_16384_128", &bf16_gemm, "bf16_gemm (XeTLA)"); - // flash_attn_fwd + // gemm stream k + m.def("gemm_streamk_shape_3072_4096_3072", &bf16_stream_k_gemm, + "bf16_gemm_streamk (XeTLA)"); + // flash_attn m.def("flash_attn_causal_false", &flash_attn, "flash attn fwd (XeTLA)"); m.def("flash_attn_causal_true", &flash_attn, diff --git a/benchmarks/xetla_kernel/stream_k_gemm/stream_k_gemm.h b/benchmarks/xetla_kernel/stream_k_gemm/stream_k_gemm.h index 58b7937448..ac91124323 100644 --- a/benchmarks/xetla_kernel/stream_k_gemm/stream_k_gemm.h +++ b/benchmarks/xetla_kernel/stream_k_gemm/stream_k_gemm.h @@ -36,9 +36,6 @@ sycl::event stream_k_gemm_run(void *_A, void *_B, void *_C, void *_Acc, using data_type_c = float; using data_type_acc = float; - auto context = queue.get_info(); - auto device = queue.get_info(); - data_type_a *A = static_cast(_A); data_type_b *B = static_cast(_B); data_type_c *C = static_cast(_C); @@ -52,7 +49,7 @@ sycl::event stream_k_gemm_run(void *_A, void *_B, void *_C, void *_Acc, constexpr uint32_t sg_tile_k = 32; // StreamK parameters - xecores available for stream_k dispatch - uint32_t avail_xecores = 32; + uint32_t avail_xecores = 64; // Org the compute shape for sub-matrix using tile_shape = From 27572e689a4fa30db986c34c950c740f30e1b889 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Tue, 15 Oct 2024 19:01:27 +0000 Subject: [PATCH 2/2] Fix merge --- benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py index a46212f973..46f8bce543 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py @@ -294,7 +294,7 @@ def benchmark(M, N, K, provider): # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench( - xetla_fn, warmup=10, rep=10, quantiles=quantiles, + xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles, kernel_name='gpu::xetla::kernel::gemm_universal_t