Commit 144f81a
Don't use `fast_flush=False` as it seems to be deprecated
Signed-off-by: Anatoly Myachev <[email protected]>
anmyachev committed Sep 24, 2024
1 parent 9ad431a commit 144f81a
Showing 8 changed files with 14 additions and 23 deletions.
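The change is mechanical: every benchmark_suit.do_bench call drops the fast_flush=False keyword, since newer Triton releases deprecate that argument for do_bench. For code that still has to run against both old and new Triton versions, one option is a small compatibility shim that forwards the keyword only when the installed do_bench accepts it. The sketch below is illustrative only and not part of this commit; it assumes triton.testing.do_bench as the underlying entry point and a CUDA device (adjust for XPU builds), and do_bench_compat is a hypothetical helper name.

import inspect

import torch
import triton.testing


def do_bench_compat(fn, **kwargs):
    # Forward only the keyword arguments that the installed
    # triton.testing.do_bench actually accepts, so deprecated options
    # such as fast_flush are dropped on newer Triton versions.
    accepted = inspect.signature(triton.testing.do_bench).parameters
    return triton.testing.do_bench(fn, **{k: v for k, v in kwargs.items() if k in accepted})


if __name__ == '__main__':
    a = torch.randn(1024, 1024, device='cuda')  # device assumed; use 'xpu' on XPU builds
    b = torch.randn(1024, 1024, device='cuda')
    # Passing fast_flush here is harmless: it is forwarded only if supported.
    ms = do_bench_compat(lambda: torch.matmul(a, b), warmup=10, rep=10, fast_flush=False)
    print(f'matmul: {ms:.3f} ms')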
@@ -227,7 +227,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=
False, scale=sm_scale), warmup=10, rep=10,
- quantiles=quantiles, fast_flush=False)
+ quantiles=quantiles)

elif provider == 'triton':
triton_fn = lambda: forward(q, k, v, causal, sm_scale)
@@ -240,8 +240,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
), attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32)
atol = 1e-1 if N_CTX == 16384 else 1e-2
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
- _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-     fast_flush=False)
+ _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)

elif provider == 'xetla':
module_name = f'flash_attn_causal_{causal}'.lower()
@@ -256,8 +255,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)

xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
- _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
-     fast_flush=False)
+ _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles)

else:
raise NotImplementedError(f'Unsupported provider {provider}')
8 changes: 3 additions & 5 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -254,14 +254,13 @@ def benchmark(B, M, N, K, provider):

if provider == 'onednn':
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
- quantiles=quantiles, fast_flush=False)
+ quantiles=quantiles)
elif provider == 'triton':
triton_fn = lambda: matmul(a, b)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
- _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-     fast_flush=False)
+ _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
elif provider == 'xetla':
if B == 1:
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
@@ -276,8 +275,7 @@ def benchmark(B, M, N, K, provider):
xetla_fn = lambda: func(a, b, c, acc, cnt)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
- _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
-     fast_flush=False)
+ _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -271,8 +271,7 @@ def benchmark(B, M, N, K, provider):
torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
- _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-     fast_flush=False)
+ _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -273,8 +273,7 @@ def benchmark(B, M, N, K, provider):
torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
- _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-     fast_flush=False)
+ _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -261,8 +261,7 @@ def benchmark(B, M, N, K, provider):
torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
- _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-     fast_flush=False)
+ _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

5 changes: 2 additions & 3 deletions benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -158,14 +158,13 @@ def benchmark(M, N, K, provider):

if provider == 'onednn':
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
- quantiles=quantiles, fast_flush=False)
+ quantiles=quantiles)
elif provider == 'triton':
triton_fn = lambda: matmul(a, b)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
- _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-     fast_flush=False)
+ _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

5 changes: 2 additions & 3 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -274,13 +274,12 @@ def benchmark(M, N, K, provider):

if provider == 'onednn':
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
- quantiles=quantiles, fast_flush=False)
+ quantiles=quantiles)
elif provider == 'triton':
triton_fn = lambda: matmul(a, b)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
- _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-     fast_flush=False)
+ _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

2 changes: 1 addition & 1 deletion benchmarks/triton_kernels_benchmark/prefix_sums.py
@@ -44,7 +44,7 @@ def benchmark(M, N, AXIS, provider):

if provider == 'triton':
triton_fn = lambda: scan_kernel[(1, )](x, BLOCK_SIZE_M=M, BLOCK_SIZE_N=N, AXIS=AXIS)
- _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, fast_flush=False)
+ _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

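Because the same keyword was removed from call sites across eight files, a quick scan can confirm nothing was missed. The snippet below is a hypothetical check, not part of this commit; it assumes it is run from the repository root and that the benchmarks live under benchmarks/triton_kernels_benchmark as shown in the file headers above.

from pathlib import Path

# Hypothetical post-change check: fail if any benchmark still passes the
# deprecated fast_flush=False keyword. Assumes the current directory is the repo root.
bench_dir = Path('benchmarks/triton_kernels_benchmark')
offenders = sorted(p.name for p in bench_dir.rglob('*.py') if 'fast_flush=False' in p.read_text())
if offenders:
    raise SystemExit(f'fast_flush still referenced in: {offenders}')
print('no fast_flush usages left')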
