diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml
index 3d55c5974..d238e3dc7 100644
--- a/.github/workflows/triton-benchmarks.yml
+++ b/.github/workflows/triton-benchmarks.yml
@@ -167,6 +167,7 @@ jobs:
           source ../../scripts/capture-hw-details.sh
           python ../../scripts/build_report.py $REPORTS/matmul-performance-bt.csv $REPORTS/gemm-bt-triton-report.csv --benchmark gemm-bt --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          python ../../scripts/build_report.py $REPORTS/matmul-performance-bt.csv $REPORTS/gemm-bt-triton-report.csv --benchmark gemm-bt --compiler onednn --param_cols "B,M,K,N" --tflops_col onednn-TFlops --hbm_col "onednn-GB/s" --tag $TAG

       - name: Run Triton GEMM (A^t@B) kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
@@ -177,6 +178,7 @@ jobs:
           source ../../scripts/capture-hw-details.sh
           python ../../scripts/build_report.py $REPORTS/matmul-performance-at.csv $REPORTS/gemm-at-triton-report.csv --benchmark gemm-at --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          python ../../scripts/build_report.py $REPORTS/matmul-performance-at.csv $REPORTS/gemm-at-triton-report.csv --benchmark gemm-at --compiler onednn --param_cols "B,M,K,N" --tflops_col onednn-TFlops --hbm_col "onednn-GB/s" --tag $TAG

       - name: Run Triton GEMM (stream-k) kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
index 6bb87b871..f54ef2abd 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -13,6 +13,8 @@ import triton.language as tl

 import triton_kernels_benchmark as benchmark_suit
+from triton_kernels_benchmark.benchmark_testing import do_bench_elapsed_time, BENCHMARKING_METHOD
+
 import xetla_kernel

 if benchmark_suit.USE_IPEX_OPTION:
@@ -250,9 +252,9 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
         line_arg='provider',  # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
-        line_vals=['triton'] + (['xetla'] if use_xetla else []),
+        line_vals=['triton'] + (['xetla'] if use_xetla else ['onednn']),
         # label name for the lines
-        line_names=['Triton'] + (['XeTLA'] if use_xetla else []),
+        line_names=['Triton'] + (['XeTLA'] if use_xetla else ['onednn']),
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
@@ -277,8 +279,12 @@ def benchmark(B, M, N, K, provider):
         torch_b = torch.transpose(torch_b, -2, -1)

     if provider == 'onednn':
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(torch_a, torch_b), warmup=10,
-                                                                 rep=10, quantiles=quantiles)
+        do_bench = benchmark_suit.do_bench
+        if BENCHMARKING_METHOD == 'PYTORCH_LEGACY_PROFILER_USING_IPEX':
+            # Legacy profiler shows ~6000TFLOPS GeoMean for onednn measurements, so use more reliable method
+            do_bench = do_bench_elapsed_time
+        _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), warmup=10, rep=10,
+                                                  quantiles=quantiles, kernel_name='gemm_kernel')
     elif provider == 'triton':
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
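Note: the core of the gemm_benchmark.py change is the timer dispatch for the onednn provider: the patch comments that the legacy IPEX profiler reports implausible (~6000 TFLOPS) numbers for oneDNN, so an elapsed-time timer is substituted in that case. The sketch below is a minimal, standalone approximation of that dispatch, not the patch itself; it assumes BENCHMARKING_METHOD is taken from an environment variable and that both timers share the (fn, warmup, rep, quantiles, kernel_name) keyword signature used in the patch. The real helpers live in triton_kernels_benchmark.benchmark_testing, and the profiler-based timer is stubbed out here.

    # Illustrative sketch (not part of the patch): timer selection for the onednn provider.
    import os
    import time

    # Assumption: the benchmarking method is configured via an environment variable.
    BENCHMARKING_METHOD = os.getenv('BENCHMARKING_METHOD', 'ELAPSED_TIME')


    def do_bench_elapsed_time(fn, warmup=10, rep=10, quantiles=None, kernel_name=None):
        # Wall-clock timer: warm up, then time `rep` invocations and report min/max/mean in ms.
        for _ in range(warmup):
            fn()
        times_ms = []
        for _ in range(rep):
            start = time.perf_counter()
            fn()
            times_ms.append((time.perf_counter() - start) * 1000)
        mean_ms = sum(times_ms) / len(times_ms)
        return None, min(times_ms), max(times_ms), mean_ms, 0.0


    def profiler_do_bench(fn, warmup=10, rep=10, quantiles=None, kernel_name=None):
        # Stand-in for the profiler-based benchmark_suit.do_bench (hypothetical stub).
        return do_bench_elapsed_time(fn, warmup, rep, quantiles, kernel_name)


    def bench_onednn(fn):
        # Mirror of the dispatch added in the patch: prefer the profiler-based timer,
        # but fall back to elapsed time under the legacy IPEX profiler, whose onednn
        # measurements are unreliable.
        do_bench = profiler_do_bench
        if BENCHMARKING_METHOD == 'PYTORCH_LEGACY_PROFILER_USING_IPEX':
            do_bench = do_bench_elapsed_time
        return do_bench(fn, warmup=10, rep=10, quantiles=[0.5, 0.0, 1.0], kernel_name='gemm_kernel')


    if __name__ == '__main__':
        # Example usage with a trivial workload standing in for torch.matmul(torch_a, torch_b).
        _, min_ms, max_ms, mean_ms, cv = bench_onednn(lambda: sum(range(100000)))
        print(f'min={min_ms:.3f} ms, max={max_ms:.3f} ms, mean={mean_ms:.3f} ms')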