Commit
Merge pull request #10266 from nero-dv/dev
Update sub_quadratic_attention.py
AUTOMATIC1111 authored May 11, 2023
2 parents 8aa87c5 + c8732df commit c9e5b92
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions modules/sub_quadratic_attention.py
@@ -202,13 +202,22 @@ def get_query_chunk(chunk_idx: int) -> Tensor:
             value=value,
         )
 
-    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
-    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
-    res = torch.cat([
-        compute_query_chunk_attn(
+    # slices of res tensor are mutable, modifications made
+    # to the slices will affect the original tensor.
+    # if output of compute_query_chunk_attn function has same number of
+    # dimensions as input query tensor, we initialize tensor like this:
+    num_query_chunks = int(np.ceil(q_tokens / query_chunk_size))
+    query_shape = get_query_chunk(0).shape
+    res_shape = (query_shape[0], query_shape[1] * num_query_chunks, *query_shape[2:])
+    res_dtype = get_query_chunk(0).dtype
+    res = torch.zeros(res_shape, dtype=res_dtype)
+
+    for i in range(num_query_chunks):
+        attn_scores = compute_query_chunk_attn(
             query=get_query_chunk(i * query_chunk_size),
             key=key,
             value=value,
-        ) for i in range(math.ceil(q_tokens / query_chunk_size))
-    ], dim=1)
+        )
+        res[:, i * query_chunk_size:(i + 1) * query_chunk_size, :] = attn_scores
+
     return res
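
For readers skimming the change: the pattern adopted here pre-allocates the output tensor once and writes each chunk's attention result into a slice of it, instead of collecting per-chunk tensors and torch.cat()ing them, which briefly holds both the chunk list and the concatenated copy in memory. Below is a minimal self-contained sketch of that pattern; the toy attend_chunk function, shapes, and chunk size are illustrative assumptions, not the repository's actual code.

    import math

    import torch

    def attend_chunk(q_chunk, k, v):
        # toy stand-in for compute_query_chunk_attn:
        # plain softmax attention over one query chunk
        scores = q_chunk @ k.transpose(-1, -2) / math.sqrt(q_chunk.shape[-1])
        return scores.softmax(dim=-1) @ v

    batch_x_heads, q_tokens, channels = 2, 10, 8
    query_chunk_size = 4  # the last chunk is allowed to be shorter
    q = torch.randn(batch_x_heads, q_tokens, channels)
    k = torch.randn(batch_x_heads, q_tokens, channels)
    v = torch.randn(batch_x_heads, q_tokens, channels)

    # before: concatenate per-chunk results (holds the list, then a copy)
    res_cat = torch.cat([
        attend_chunk(q[:, i:i + query_chunk_size], k, v)
        for i in range(0, q_tokens, query_chunk_size)
    ], dim=1)

    # after: allocate once, then write each chunk into a slice; tensor
    # slices are views, so the assignments mutate the original storage
    res = torch.zeros_like(q)
    for i in range(0, q_tokens, query_chunk_size):
        res[:, i:i + query_chunk_size] = attend_chunk(q[:, i:i + query_chunk_size], k, v)

    assert torch.allclose(res_cat, res, atol=1e-6)

Both variants produce the same tensor; the second avoids the transient peak of holding every chunk plus the concatenated result at the same time.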
