diff --git a/csrc/gpu/aten/operators/xetla/kernels/SDP/ifmha_forward.h b/csrc/gpu/aten/operators/xetla/kernels/SDP/ifmha_forward.h
index 64a5fec93..940972b52 100644
--- a/csrc/gpu/aten/operators/xetla/kernels/SDP/ifmha_forward.h
+++ b/csrc/gpu/aten/operators/xetla/kernels/SDP/ifmha_forward.h
@@ -1101,19 +1101,26 @@ class ifmha_forward_t {
       uint32_t startT,
       uint32_t endT,
       matSij_t& matAcc) {
+    if (startT + ctx.sg_idx * kSgBc >= endT) {
+      return;
+    }
+
     base_offset += (startT + ctx.sg_idx * kSgBc);
-    constexpr int simd_lanes = kSgBc > 32 ? 32 : 16;
+
+    constexpr int simd_lanes = kSgBc >= 32 ? 32 : 16;
     static_assert(kSgBc % simd_lanes == 0);
-    constexpr int loops = kSgBc / simd_lanes;
 
     xetla_vector offsets = xetla_vector_gen(0, 1);
     offsets *= sizeof(scalar_t);
     offsets += (base_offset * sizeof(scalar_t));
 
 #pragma unroll
-    for (int i = 0; i < loops; ++i) {
+    for (int i = 0; i < kSgBc / simd_lanes; ++i) {
       offsets += i * simd_lanes * sizeof(scalar_t);
+      xetla_vector seq =
+          xetla_vector_gen(i * simd_lanes, 1);
+      xetla_mask mask = seq < (endT - startT - ctx.sg_idx * kSgBc);
       matAcc.reg.xetla_select(i * simd_lanes) =
           xetla_load_global<
               scalar_t,
@@ -1121,7 +1128,7 @@ class ifmha_forward_t {
               data_size::default_size,
               cache_hint::cached,
               cache_hint::cached,
-              simd_lanes>(ptr, offsets);
+              simd_lanes>(ptr, offsets, mask);
     }
   }
 
diff --git a/intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/Attention.py b/intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/Attention.py
index 20b38c38e..523b98df9 100644
--- a/intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/Attention.py
+++ b/intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/Attention.py
@@ -487,6 +487,10 @@ def compute_sdp(
                 is_causal,
             )
         elif self.is_beam_search():
+            if attention_mask is not None:
+                attention_mask = attention_mask.repeat(
+                    1, int(query.shape[1] / attention_mask.shape[1]), 1, 1
+                ).contiguous()
             return self.sdp_2nd2last_beam_search(
                 query,
                 key,
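
For reference, here is a minimal standalone sketch of what the new beam-search branch in `compute_sdp` does to `attention_mask` before dispatching to `sdp_2nd2last_beam_search`. The tensor shapes and shape comments below are assumptions chosen for illustration, not values taken from the model code.

```python
import torch

# Illustrative only: the mask arrives with a singleton dim 1 and is repeated
# so it matches query.shape[1] before the beam-search SDP call.
# Shapes here are assumptions for the example.
batch, num_heads, q_len, kv_len = 2, 16, 1, 32
query = torch.randn(batch, num_heads, q_len, 64)        # [B, H, Tq, D]
attention_mask = torch.zeros(batch, 1, q_len, kv_len)   # [B, 1, Tq, Tkv]

if attention_mask is not None:
    attention_mask = attention_mask.repeat(
        1, int(query.shape[1] / attention_mask.shape[1]), 1, 1
    ).contiguous()

assert attention_mask.shape == (batch, num_heads, q_len, kv_len)
```

The `.contiguous()` call mirrors the patch; presumably it ensures the repeated mask is densely laid out before it reaches the fused SDP path.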