diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
index 5cc9998b1..a45875dde 100644
--- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
@@ -227,7 +227,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
             lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=
                                                                      False, scale=sm_scale), warmup=10, rep=10,
-            quantiles=quantiles, fast_flush=False)
+            quantiles=quantiles)

     elif provider == 'triton':
         triton_fn = lambda: forward(q, k, v, causal, sm_scale)
@@ -240,8 +240,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
         ), attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32)
         atol = 1e-1 if N_CTX == 16384 else 1e-2
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                              fast_flush=False)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)

     elif provider == 'xetla':
         module_name = f'flash_attn_causal_{causal}'.lower()
@@ -256,8 +255,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
         l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)

         xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                              fast_flush=False)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles)

     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
index bfc00eb01..159d05a69 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -254,14 +254,13 @@ def benchmark(B, M, N, K, provider):

     if provider == 'onednn':
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
-                                                                 quantiles=quantiles, fast_flush=False)
+                                                                 quantiles=quantiles)
     elif provider == 'triton':
         triton_fn = lambda: matmul(a, b)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
     elif provider == 'xetla':
         if B == 1:
             c = torch.empty((M, N), device='xpu', dtype=torch.float32)
@@ -276,8 +275,7 @@ def benchmark(B, M, N, K, provider):
         xetla_fn = lambda: func(a, b, c, acc, cnt)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles)

     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
index 6add18b32..b431405ee 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
@@ -271,8 +271,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)

     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
index 9953cf535..2904ed679 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
@@ -273,8 +273,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)

     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
index 1ed4f8472..bff1b4399 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
@@ -261,8 +261,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)

     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
index 5fa957f69..daba442eb 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -158,14 +158,13 @@ def benchmark(M, N, K, provider):

     if provider == 'onednn':
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
-                                                              quantiles=quantiles, fast_flush=False)
+                                                              quantiles=quantiles)
     elif provider == 'triton':
         triton_fn = lambda: matmul(a, b)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                              fast_flush=False)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)

     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
index 7eb0b651f..603659ded 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -274,13 +274,12 @@ def benchmark(M, N, K, provider):

     if provider == 'onednn':
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
-                                                              quantiles=quantiles, fast_flush=False)
+                                                              quantiles=quantiles)
     elif provider == 'triton':
         triton_fn = lambda: matmul(a, b)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                              fast_flush=False)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)

     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/prefix_sums.py b/benchmarks/triton_kernels_benchmark/prefix_sums.py
index f3beb0707..bb3d2069f 100644
--- a/benchmarks/triton_kernels_benchmark/prefix_sums.py
+++ b/benchmarks/triton_kernels_benchmark/prefix_sums.py
@@ -44,7 +44,7 @@ def benchmark(M, N, AXIS, provider):

     if provider == 'triton':
         triton_fn = lambda: scan_kernel[(1, )](x, BLOCK_SIZE_M=M, BLOCK_SIZE_N=N, AXIS=AXIS)
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles)

     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
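
Every hunk in this patch is the same mechanical change: the fast_flush keyword is dropped from benchmark_suit.do_bench. For reference, a minimal sketch of the resulting call pattern, assuming the benchmark_suit alias and the five-value return seen in the files above (the import path, tensor shapes, dtype, and quantiles value here are illustrative assumptions, not taken from the patch):

    import torch
    import triton_kernels_benchmark as benchmark_suit  # assumed import path

    # Illustrative inputs on the XPU device used throughout these benchmarks.
    a = torch.randn((1024, 1024), device='xpu', dtype=torch.bfloat16)
    b = torch.randn((1024, 1024), device='xpu', dtype=torch.bfloat16)
    quantiles = [0.5, 0.0, 1.0]

    # Before: do_bench(fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False)
    # After this patch the fast_flush keyword is simply no longer passed:
    _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
                                                             quantiles=quantiles)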