
Commit

Relax make_range lowering
jopperm committed Oct 4, 2024
1 parent f63d13b · commit 5b7ffa7
Showing 2 changed files with 14 additions and 18 deletions.
@@ -269,23 +269,18 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
                                                            quantiles=quantiles)
 
     elif provider == 'triton':
-        # FIXME: remove below if condition when extend attention support for Causal = True done
-        # https://github.com/intel/intel-xpu-backend-for-triton/issues/1102
-        if os.environ.get('TRITON_INTEL_ADVANCED_PATH', '0') == '1' and CAUSAL and D_HEAD == 128:
-            min_ms, max_ms, mean, cv = (float('inf'), ) * 4
+        triton_fn = lambda: forward(q, k, v, CAUSAL, sm_scale)
+        if benchmark_suit.USE_IPEX_OPTION:
+            torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
         else:
-            triton_fn = lambda: forward(q, k, v, CAUSAL, sm_scale)
-            if benchmark_suit.USE_IPEX_OPTION:
-                torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
-            else:
-                # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
-                torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
-                ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
-            atol = 1e-1 if N_CTX == 16384 else 1e-2
-            benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
-            _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                  kernel_name='_attn_fwd')
+            # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
+            torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
+            ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
+        atol = 1e-1 if N_CTX == 16384 else 1e-2
+        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                              kernel_name='_attn_fwd')
 
     elif provider == 'xetla':
         module_name = f'flash_attn_causal_{CAUSAL}'.lower()
@@ -1228,9 +1228,10 @@ void MatchTargetSizePass::transformMakeRangeOp(tt::MakeRangeOp op) {
 
   unsigned start = op.getStart();
   unsigned end = op.getEnd();
-  assert(start == 0 && end % subgroupSize == 0 && "Unsupported range");
+  assert(start == 0 && end <= subgroupSize ||
+         end % subgroupSize == 0 && "Unsupported range");
 
-  if (end == subgroupSize)
+  if (end <= subgroupSize)
     // nothing to do
     return;
 
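For context, below is a minimal stand-alone sketch, not part of the commit, of the range check this change relaxes in transformMakeRangeOp. The subgroup size of 16 and the explicit parenthesization of the condition are assumptions for illustration: before the change, only ranges starting at 0 whose end is a multiple of the subgroup size were accepted; the relaxed check also accepts ranges that fit within a single subgroup and leaves them untouched.

#include <cassert>
#include <cstdio>

// Hypothetical predicate capturing the intent of the relaxed assertion in
// transformMakeRangeOp; the parentheses around the end checks are an assumption.
static bool isSupportedRange(unsigned start, unsigned end, unsigned subgroupSize) {
  return start == 0 && (end <= subgroupSize || end % subgroupSize == 0);
}

int main() {
  const unsigned subgroupSize = 16; // assumed value for illustration

  // Accepted before and after the change: end is a multiple of the subgroup size.
  assert(isSupportedRange(0, 32, subgroupSize));

  // Newly accepted: the range fits in one subgroup, so the pass can return
  // early ("nothing to do") instead of asserting.
  assert(isSupportedRange(0, 8, subgroupSize));

  // Still unsupported: ranges that do not start at 0.
  assert(!isSupportedRange(4, 16, subgroupSize));

  std::puts("range checks behave as expected");
  return 0;
}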
