[Benchmark] Run xetla streamk gemm in benchmark (#2438)
ESI-SYD authored Oct 15, 2024
1 parent 35130dc commit 6018c7b
Showing 3 changed files with 22 additions and 7 deletions.
19 changes: 17 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -10,6 +10,7 @@
 import triton.language as tl
 
 import triton_kernels_benchmark as benchmark_suit
+import xetla_kernel
 
 if benchmark_suit.USE_IPEX_OPTION:
     import intel_extension_for_pytorch  # type: ignore  # noqa: F401
@@ -253,9 +254,9 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
-        line_vals=['triton'],
+        line_vals=['triton', 'xetla'],
         # label name for the lines
-        line_names=['Triton'],
+        line_names=['Triton', 'XeTLA'],
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
@@ -281,6 +282,20 @@ def benchmark(M, N, K, provider):
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles,
                                                                  kernel_name=['first_wave', 'full_tiles'])
+    elif provider == 'xetla':
+        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
+
+        name = f'gemm_streamk_shape_{M}_{K}_{N}'
+        func = getattr(xetla_kernel, name)
+        xetla_fn = lambda: func(a, b, c, acc, cnt)
+        torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
+
+        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
+            xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
+            kernel_name='gpu::xetla::kernel::gemm_universal_t<dispatch_stream_k')
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

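For reference, a minimal sketch of driving the new 'xetla' provider path outside the benchmark harness. The shape, dtypes, and the getattr lookup follow the diff above; the bf16 inputs a and b, and the availability of a built xetla_kernel extension on an XPU device, are assumptions and not part of the commit.

import torch
import xetla_kernel

M, K, N = 3072, 4096, 3072  # the only stream-k shape registered in python_main.cpp
a = torch.randn((M, K), device='xpu', dtype=torch.bfloat16)
b = torch.randn((K, N), device='xpu', dtype=torch.bfloat16)
c = torch.empty((M, N), device='xpu', dtype=torch.float32)    # output buffer
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)  # float32 accumulator buffer expected by the kernel
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)    # int32 counter buffer expected by the kernel

func = getattr(xetla_kernel, f'gemm_streamk_shape_{M}_{K}_{N}')
func(a, b, c, acc, cnt)

# The commit leaves the correctness check commented out; a hedged equivalent would be:
# torch.testing.assert_close(c, torch.matmul(a, b).to(torch.float32), atol=1e-4, rtol=1.0)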
5 changes: 4 additions & 1 deletion benchmarks/xetla_kernel/python_main.cpp
@@ -280,7 +280,10 @@ PYBIND11_MODULE(xetla_kernel, m) {
         &bf16_gemm<Test_4096x8x128x16384_row_row>, "bf16_gemm (XeTLA)");
   m.def("gemm_shape_4096_8_16384_128",
         &bf16_gemm<Test_4096x8x16384x128_row_row>, "bf16_gemm (XeTLA)");
-  // flash_attn_fwd
+  // gemm stream k
+  m.def("gemm_streamk_shape_3072_4096_3072", &bf16_stream_k_gemm,
+        "bf16_gemm_streamk (XeTLA)");
+  // flash_attn
   m.def("flash_attn_causal_false", &flash_attn<false, false, false>,
         "flash attn fwd (XeTLA)");
   m.def("flash_attn_causal_true", &flash_attn<false, true, false>,
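The binding name encodes the GEMM shape as gemm_streamk_shape_{M}_{K}_{N}, which is what the benchmark's getattr lookup above relies on. A small sketch of checking that the symbol is exported (assuming the extension has been built; none of this is part of the commit):

import xetla_kernel

name = 'gemm_streamk_shape_3072_4096_3072'
assert hasattr(xetla_kernel, name), f'{name} not exported by xetla_kernel'
print(getattr(xetla_kernel, name).__doc__)  # pybind11 signature plus the "bf16_gemm_streamk (XeTLA)" docstring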
5 changes: 1 addition & 4 deletions benchmarks/xetla_kernel/stream_k_gemm/stream_k_gemm.h
@@ -36,9 +36,6 @@ sycl::event stream_k_gemm_run(void *_A, void *_B, void *_C, void *_Acc,
   using data_type_c = float;
   using data_type_acc = float;
 
-  auto context = queue.get_info<sycl::info::queue::context>();
-  auto device = queue.get_info<sycl::info::queue::device>();
-
   data_type_a *A = static_cast<data_type_a *>(_A);
   data_type_b *B = static_cast<data_type_b *>(_B);
   data_type_c *C = static_cast<data_type_c *>(_C);
@@ -52,7 +49,7 @@
   constexpr uint32_t sg_tile_k = 32;
 
   // StreamK parameters - xecores available for stream_k dispatch
-  uint32_t avail_xecores = 32;
+  uint32_t avail_xecores = 64;
 
   // Org the compute shape for sub-matrix
   using tile_shape =
