flash attention support #17

Closed
fmmoret opened this issue Oct 24, 2023 · 3 comments · Fixed by #20

fmmoret commented Oct 24, 2023

Relevant issues & PRs:

huggingface/transformers#26557
huggingface/transformers#26350
huggingface/transformers#26585

michaelfeil (Owner) commented

FlashAttention and related techniques would be something that is useful on GPU (only/mostly). It could give a 2-4x speedup for BERT.
As this is a free-time project, I only see the option of implementing this in the upstream inference engine, e.g. (sentence-transformers + HF transformers + torch), (fastembed + onnx), or TensorRT (e.g. by adding the newly released TensorRT-LLM engine).

My best guess would be to add TensorRT-LLM, which supports BERT with fp8 and FlashAttention, afaik. I would wait some weeks to let TensorRT-LLM get more mature.

Contributions are of course also welcome.

michaelfeil self-assigned this Oct 28, 2023
michaelfeil (Owner) commented

I'll add FlashAttention via BetterTransformer:
https://huggingface.co/docs/optimum/bettertransformer/tutorials/convert
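
A rough sketch of that conversion, following the optimum tutorial linked above (the model name is just an example):

```python
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

model = AutoModel.from_pretrained("bert-base-uncased")  # example model
model = BetterTransformer.transform(model)  # swaps attention for the fused SDPA path
```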

michaelfeil (Owner) commented

infinity now uses torch.nn.functional.scaled_dot_product_attention via BetterTransformer.
Attention is now anywhere from 1.5-3x faster, making infinity around 20% faster on batch inference. This does not use FlashAttention directly, as we use an attention_mask. See: the PyTorch-native scaled_dot_product_attention operator can only dispatch to Flash Attention if no attention_mask is provided.
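
A minimal sketch of that dispatch behaviour (not infinity's code; assumes a CUDA device and torch>=2.0):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
mask = torch.ones(1, 1, 128, 128, device="cuda", dtype=torch.bool)  # e.g. a padding mask

# Without an attention_mask, the flash kernel is eligible.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)

# With an attention_mask, the flash kernel cannot be used, so other backends must stay enabled.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
    out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```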

Beyond that, you can set the torch backend to .half() precision (fp16), which boosts performance by another 30-40%, but loses numerical precision (roughly 10**-6 -> 10**-3):

```python
if self._target_device.type == "cuda" and os.environ.get(
    "INFINITY_TORCH_ENABLE_HALF", False
):
    logger.info(
        "Switching to half() precision (fp16). "
        "Enabled by setting the env var `INFINITY_TORCH_ENABLE_HALF`."
    )
    self.half()
```
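
Note that the check above only tests the env var for truthiness, so any non-empty value enables it, for example:

```python
import os

# Environment variable values are strings, so any non-empty value is truthy.
os.environ["INFINITY_TORCH_ENABLE_HALF"] = "1"
```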

Perhaps the integration will get better with upcoming versions of torch>=2.0.0 or new releases of optimum.
