Skip to content

Commit

Permalink
Fix upstream profiler for several kernels (#2498)
Browse files Browse the repository at this point in the history
Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev authored Oct 17, 2024
1 parent f4fdd8f commit 700abe3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
10 changes: 6 additions & 4 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,16 +213,18 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no

function_events = prof.events()

functions = []
all_functions = []
if isinstance(kernel_name, str):
kernel_name = [kernel_name]
for ker_name in kernel_name:
functions.extend(list(filter(lambda x: x.name.startswith(ker_name), function_events))) # pylint: disable=cell-var-from-loop
functions = list(filter(lambda x: x.name.startswith(ker_name), function_events)) # pylint: disable=cell-var-from-loop
assert len(functions) == n_repeat, f"the profiling number for kernel: '{ker_name}' not match, {len(functions)}"
all_functions.append(functions)
# profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)

assert len(functions) == n_repeat, f"the profiling number not match, {len(functions)}"
# Make the time to the milliseconds.
times = torch.tensor([f.self_device_time_total * 1e-3 for f in functions], dtype=torch.float)
times = torch.tensor([sum(map(lambda elem: elem.self_device_time_total, f)) * 1e-3 for f in zip(*all_functions)],
dtype=torch.float)
return _summarize_statistics(times, quantiles, return_mode)


Expand Down
5 changes: 2 additions & 3 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,8 @@ def benchmark(M, N, K, provider):
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
kernel_name='gpu::xetla::kernel::gemm_universal_t<dispatch_stream_k')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name='stream_k_gemm_run')
else:
raise NotImplementedError(f'Unsupported provider {provider}')

Expand Down

0 comments on commit 700abe3

Please sign in to comment.