Skip to content

Commit

Permalink
[2.1.40]Fix attn_mask for sdpa beam_search (#4557)
Browse files Browse the repository at this point in the history
* fix page fault issue, add mask in load to avoid out of bound access

* repeat attn_mask

* cover attn_mask is None

---------

Signed-off-by: zhuyuhua-v <[email protected]>
Co-authored-by: zhuwei <[email protected]>
  • Loading branch information
zhuyuhua-v and Wanzizhu authored Jul 26, 2024
1 parent eeb92d2 commit 80ed476
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
15 changes: 11 additions & 4 deletions csrc/gpu/aten/operators/xetla/kernels/SDP/ifmha_forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -1101,27 +1101,34 @@ class ifmha_forward_t {
uint32_t startT,
uint32_t endT,
matSij_t& matAcc) {
if (startT + ctx.sg_idx * kSgBc >= endT) {
return;
}

base_offset += (startT + ctx.sg_idx * kSgBc);
constexpr int simd_lanes = kSgBc > 32 ? 32 : 16;

constexpr int simd_lanes = kSgBc >= 32 ? 32 : 16;

static_assert(kSgBc % simd_lanes == 0);

constexpr int loops = kSgBc / simd_lanes;
xetla_vector<uint32_t, simd_lanes> offsets =
xetla_vector_gen<uint32_t, simd_lanes>(0, 1);
offsets *= sizeof(scalar_t);
offsets += (base_offset * sizeof(scalar_t));
#pragma unroll
for (int i = 0; i < loops; ++i) {
for (int i = 0; i < kSgBc / simd_lanes; ++i) {
offsets += i * simd_lanes * sizeof(scalar_t);
xetla_vector<uint32_t, simd_lanes> seq =
xetla_vector_gen<uint32_t, simd_lanes>(i * simd_lanes, 1);
xetla_mask<simd_lanes> mask = seq < (endT - startT - ctx.sg_idx * kSgBc);
matAcc.reg.xetla_select<simd_lanes, 1>(i * simd_lanes) =
xetla_load_global<
scalar_t,
1,
data_size::default_size,
cache_hint::cached,
cache_hint::cached,
simd_lanes>(ptr, offsets);
simd_lanes>(ptr, offsets, mask);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,10 @@ def compute_sdp(
is_causal,
)
elif self.is_beam_search():
if attention_mask is not None:
attention_mask = attention_mask.repeat(
1, int(query.shape[1] / attention_mask.shape[1]), 1, 1
).contiguous()
return self.sdp_2nd2last_beam_search(
query,
key,
Expand Down

0 comments on commit 80ed476

Please sign in to comment.