Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Benchmark] Run xetla streamk gemm in benchmark #2438

Merged
merged 3 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401
Expand Down Expand Up @@ -253,9 +254,9 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
line_arg='provider',
# argument name whose value corresponds to a different line in the plot
# possible values for `line_arg``
line_vals=['triton'],
line_vals=['triton', 'xetla'],
# label name for the lines
line_names=['Triton'],
line_names=['Triton', 'XeTLA'],
# line styles
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
Expand All @@ -281,6 +282,20 @@ def benchmark(M, N, K, provider):
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles,
kernel_name=['first_wave', 'full_tiles'])
elif provider == 'xetla':
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)

name = f'gemm_streamk_shape_{M}_{K}_{N}'
func = getattr(xetla_kernel, name)
xetla_fn = lambda: func(a, b, c, acc, cnt)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
kernel_name='gpu::xetla::kernel::gemm_universal_t<dispatch_stream_k')
else:
raise NotImplementedError(f'Unsupported provider {provider}')

Expand Down
5 changes: 4 additions & 1 deletion benchmarks/xetla_kernel/python_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,10 @@ PYBIND11_MODULE(xetla_kernel, m) {
&bf16_gemm<Test_4096x8x128x16384_row_row>, "bf16_gemm (XeTLA)");
m.def("gemm_shape_4096_8_16384_128",
&bf16_gemm<Test_4096x8x16384x128_row_row>, "bf16_gemm (XeTLA)");
// flash_attn_fwd
// gemm stream k
m.def("gemm_streamk_shape_3072_4096_3072", &bf16_stream_k_gemm,
"bf16_gemm_streamk (XeTLA)");
// flash_attn
m.def("flash_attn_causal_false", &flash_attn<false, false, false>,
"flash attn fwd (XeTLA)");
m.def("flash_attn_causal_true", &flash_attn<false, true, false>,
Expand Down
5 changes: 1 addition & 4 deletions benchmarks/xetla_kernel/stream_k_gemm/stream_k_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ sycl::event stream_k_gemm_run(void *_A, void *_B, void *_C, void *_Acc,
using data_type_c = float;
using data_type_acc = float;

auto context = queue.get_info<sycl::info::queue::context>();
auto device = queue.get_info<sycl::info::queue::device>();

data_type_a *A = static_cast<data_type_a *>(_A);
data_type_b *B = static_cast<data_type_b *>(_B);
data_type_c *C = static_cast<data_type_c *>(_C);
Expand All @@ -52,7 +49,7 @@ sycl::event stream_k_gemm_run(void *_A, void *_B, void *_C, void *_Acc,
constexpr uint32_t sg_tile_k = 32;

// StreamK parameters - xecores available for stream_k dispatch
uint32_t avail_xecores = 32;
uint32_t avail_xecores = 64;

// Org the compute shape for sub-matrix
using tile_shape =
Expand Down