
To apply FlashAttention #203

Open
dyanos opened this issue Jun 19, 2023 · 1 comment

dyanos commented Jun 19, 2023

To apply FlashAttention

dyanos self-assigned this Jun 19, 2023

dyanos commented Jun 19, 2023

To install:

pip install flash-attn

To apply:

import torch
from flash_attn.flash_attention import FlashMHA

# Replace this with your correct GPU device
device = "cuda:0"

# Create attention layer. This is similar to torch.nn.MultiheadAttention,
# and it includes the input and output linear layers
flash_mha = FlashMHA(
    embed_dim=128, # total channels (= num_heads * head_dim)
    num_heads=8, # number of heads
    device=device,
    dtype=torch.float16,
)

# Run forward pass with dummy data
x = torch.randn(
    (64, 256, 128), # (batch, seqlen, embed_dim)
    device=device,
    dtype=torch.float16
)

# forward returns (output, attn_weights), like torch.nn.MultiheadAttention,
# so take the first element; output has shape (batch, seqlen, embed_dim)
output = flash_mha(x)[0]

# The lower-level FlashAttention module (the attention computation only,
# without the input/output linear layers) can also be constructed directly:
from flash_attn.flash_attention import FlashAttention

# Create the nn.Module
flash_attention = FlashAttention()
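
For reference, a minimal sketch of calling the lower-level module, assuming the flash-attn v1 interface where FlashAttention.forward takes a packed qkv tensor of shape (batch, seqlen, 3, num_heads, head_dim) in fp16 on GPU and returns an (output, attention weights) tuple. The packed layout and the head sizes below are assumptions chosen to match the FlashMHA example above, not something stated in this issue.

# Hedged sketch: assumes a packed qkv tensor (batch, seqlen, 3, num_heads, head_dim);
# 8 heads of dim 16 match the embed_dim=128 used above.
qkv = torch.randn(
    (64, 256, 3, 8, 16),
    device=device,
    dtype=torch.float16,
)

# Assumed to return (output, attn_weights); output would be
# (batch, seqlen, num_heads, head_dim)
out, _ = flash_attention(qkv)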
