Skip to content

Commit

Permalink
Use cpu version of torch sdpa until xpu version is fixed
Browse files Browse the repository at this point in the history
Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev committed Sep 20, 2024
1 parent d76446a commit 8a380fd
Showing 1 changed file with 5 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,11 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
# FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32)
atol = 1e-1 if N_CTX == 16384 else 1e-2
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
else:
torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
), attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32)
atol = 1e-1 if N_CTX == 16384 else 1e-2
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)

Expand Down

0 comments on commit 8a380fd

Please sign in to comment.