
Native support of torch.nn.functional.scaled_dot_product_attention #26557

Closed
younesbelkada opened this issue Oct 3, 2023 · 9 comments

@younesbelkada
Contributor
Feature request

PyTorch has shipped torch.nn.functional.scaled_dot_product_attention since version 2.0, which supports more memory-efficient attention computation.

Official documentation here. Three implementations are currently available behind that method, making it possible to dispatch the SDPA call to one of:

  • C++ math implementation
  • Flash Attention 1
  • xformers memory efficient attention

In addition, upcoming PyTorch versions will add support for Flash Attention 2 (pytorch/pytorch#105602), which is already available in the PyTorch nightlies.
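For reference, a minimal sketch of calling SDPA directly and restricting which backend PyTorch may dispatch to (the tensor shapes here are arbitrary; torch.backends.cuda.sdp_kernel is the PyTorch 2.0/2.1 context manager for toggling backends):

import torch
import torch.nn.functional as F

# Arbitrary example shapes: (batch, num_heads, seq_len, head_dim)
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Let PyTorch pick the best available kernel (math, Flash Attention, memory-efficient).
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Or restrict dispatch, e.g. force the memory-efficient (xformers-style) kernel.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=False, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)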

SDPA makes model inference faster and more memory efficient, and supports multiple hardware backends (CPU, NVIDIA CUDA, AMD, ...).

Users can already benefit from SDPA through the BetterTransformer API of optimum

# pip install optimum
# After loading any supported transformers model:
model = model.to_bettertransformer()

As SDPA is already quite stable and performant, we should migrate the BetterTransformer API into the native transformers codebase to provide out-of-the-box model acceleration and memory efficiency.

cc @LysandreJik @fxmarty

Motivation

Make LLMs faster out of the box, simply by updating the PyTorch version.

Your contribution

Help implement this in upcoming versions.

@SimJeg

SimJeg commented Oct 3, 2023

@younesbelkada, is there a reason why torch.nn.functional.scaled_dot_product_attention is not integrated everywhere? For instance, the LlamaAttention class does not use it (see here).

@younesbelkada
Contributor Author

Hi @SimJeg !
You can benefit from it already through the BetterTransformer API

pip install transformers optimum

Then once you load the model call:

model = model.to_bettertransformer()
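A minimal end-to-end sketch, assuming a causal LM checkpoint such as gpt2 (any architecture supported by BetterTransformer works):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load any supported checkpoint; gpt2 is just an example.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16).to("cuda")

# Swap the attention layers for the SDPA-backed BetterTransformer path (requires optimum).
model = model.to_bettertransformer()

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))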

The goal, as mentioned in the issue, is to add native support for SDPA in the future.

@fxmarty
Contributor

fxmarty commented Oct 3, 2023

Here is a WIP PR #26572

@SimJeg I think it is mostly about Transformers handling padding with a padding mask, which PyTorch SDPA did not support (until recently) on the optimized paths. Having the code live in optimum at first was probably a way to showcase that SDPA indeed works well and that a native integration is worth it!
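To illustrate the padding point, here is a hedged sketch of passing a padding mask directly to SDPA via attn_mask (a boolean mask where True marks positions that may be attended to); the shapes and mask values are made up for illustration, and whether this hits an optimized kernel depends on the PyTorch version:

import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 8, 6, 64
q = k = v = torch.randn(batch, heads, seq_len, head_dim)

# Per-sequence padding mask: True = real token, False = padding.
padding_mask = torch.tensor([
    [True, True, True, True, False, False],  # sequence 1: 4 real tokens
    [True, True, True, True, True,  True],   # sequence 2: no padding
])

# Broadcast to (batch, 1, 1, seq_len) so padded key positions are never attended to.
attn_mask = padding_mask[:, None, None, :]

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)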

@schopra8

schopra8 commented Oct 3, 2023

@younesbelkada @patrickvonplaten - Hi team, I was looking at the attention implementation in transformers for the various LLMs vs. the attention implementation in diffusers, and I am a bit confused by the use (or lack of use) of PyTorch SDPA.

Is it correct that transformers is not using PyTorch SDPA because it cannot handle padded inputs? If so, how are we able to use PyTorch SDPA in diffusers without running into the same issues?

My understanding is that padding isn't necessary for the self-attention layers of common text-to-image models like Stable Diffusion, but is likely being used in the cross-attention layers, since text prompts are of differing lengths.

@xzuyn

xzuyn commented Oct 4, 2023

SDPA makes model inference faster and more memory efficient, and supports multiple hardware backends (CPU, NVIDIA CUDA, AMD, ...).

Is SDPA inference-only, or could it be used during training as an alternative to something like Flash Attention or xformers for folks who use ROCm? FA2 on ROCm is still a WIP and CDNA2-only.

@fxmarty
Contributor

fxmarty commented Oct 4, 2023

@xzuyn SDPA is a wrapper around xformers and Flash Attention kernels, so yes, it can be used for training as well (and is probably even more interesting there). Unfortunately, as far as my knowledge goes, FA is not upstreamed in PyTorch on ROCm systems as of PyTorch 2.1. I believe AMD folks are working towards that, though; feel free to open an issue in the PyTorch repo to track the progress.
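As a hedged illustration that SDPA is usable in training (it is differentiable like any other PyTorch op), here is a minimal toy attention module built on it; the module name and dimensions are made up:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPASelfAttention(nn.Module):
    """Toy single-head self-attention that trains through SDPA."""
    def __init__(self, dim: int):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)

    def forward(self, x):  # x: (batch, seq_len, dim)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Add a singleton head dimension: (batch, 1, seq_len, dim)
        q, k, v = (t.unsqueeze(1) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return out.squeeze(1)

layer = SDPASelfAttention(dim=64)
x = torch.randn(2, 16, 64, requires_grad=True)
layer(x).sum().backward()  # gradients flow through the SDPA kernel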


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@fxmarty
Contributor

fxmarty commented Nov 28, 2023

not stale

@fxmarty fxmarty closed this as completed Dec 13, 2023