Fix flash attention for ROCm #7011

Draft: jdecourval wants to merge 2 commits into base: master
Conversation

@jdecourval (Contributor) commented Apr 30, 2024

llama-bench

| model                           |      size |  params | backend | ngl | fa | test    |            t/s |
| ------------------------------- | --------: | ------: | ------- | --: | -: | ------- | -------------: |
| command-r 35B IQ3_XS - 3.3 bpw  | 15.65 GiB | 37.08 B | ROCm    |  99 |  1 | pp 4096 |  605.22 ± 0.75 |
| command-r 35B IQ3_XS - 3.3 bpw  | 15.65 GiB | 37.08 B | ROCm    |  99 |  1 | tg 128  |   26.82 ± 0.01 |
| command-r 35B IQ3_XS - 3.3 bpw  | 15.65 GiB | 37.08 B | ROCm    |  99 |  0 | pp 4096 |  604.84 ± 0.23 |
| command-r 35B IQ3_XS - 3.3 bpw  | 15.65 GiB | 37.08 B | ROCm    |  99 |  0 | tg 128  |   26.80 ± 0.01 |
| llama 8B Q5_K - Medium          |  4.78 GiB |  7.24 B | ROCm    |  99 |  1 | pp 4096 | 2448.01 ± 2.25 |
| llama 8B Q5_K - Medium          |  4.78 GiB |  7.24 B | ROCm    |  99 |  1 | tg 128  |   86.25 ± 0.03 |
| llama 8B Q5_K - Medium          |  4.78 GiB |  7.24 B | ROCm    |  99 |  0 | pp 4096 | 2446.30 ± 1.53 |
| llama 8B Q5_K - Medium          |  4.78 GiB |  7.24 B | ROCm    |  99 |  0 | tg 128  |   86.31 ± 0.01 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm    |  99 |  1 | pp 4096 | 1033.32 ± 1.28 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm    |  99 |  1 | tg 128  |   53.41 ± 0.02 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm    |  99 |  0 | pp 4096 | 1033.59 ± 2.31 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm    |  99 |  0 | tg 128  |   53.37 ± 0.01 |
| llama 8B Q6_K                   |  5.53 GiB |  7.24 B | ROCm    |  99 |  1 | pp 4096 | 2486.02 ± 1.37 |
| llama 8B Q6_K                   |  5.53 GiB |  7.24 B | ROCm    |  99 |  1 | tg 128  |   84.43 ± 0.02 |
| llama 8B Q6_K                   |  5.53 GiB |  7.24 B | ROCm    |  99 |  0 | pp 4096 | 2481.60 ± 1.73 |
| llama 8B Q6_K                   |  5.53 GiB |  7.24 B | ROCm    |  99 |  0 | tg 128  |   84.41 ± 0.01 |
| llama ?B Q4_K - Small           | 17.59 GiB | 33.34 B | ROCm    |  99 |  1 | pp 4096 |  610.69 ± 0.36 |
| llama ?B Q4_K - Small           | 17.59 GiB | 33.34 B | ROCm    |  99 |  1 | tg 128  |   26.62 ± 0.00 |
| llama ?B Q4_K - Small           | 17.59 GiB | 33.34 B | ROCm    |  99 |  0 | pp 4096 |  610.17 ± 0.20 |
| llama ?B Q4_K - Small           | 17.59 GiB | 33.34 B | ROCm    |  99 |  0 | tg 128  |   26.60 ± 0.00 |
./batched-bench $model 10000 2048 512 $fa 1 99 8192 256 1
| T_PP s | S_PP t/s | T_TG s | S_TG t/s |    T s |  S t/s | fa | model      | buffer MiB |
| -----: | -------: | -----: | -------: | -----: | -----: | -: | ---------- | ---------: |
|  3.891 |  2105.51 |  5.662 |    45.22 |  9.553 | 884.37 |  0 | 7B.Q6_K    |        692 |
|  3.820 |  2144.46 |  5.614 |    45.60 |  9.434 | 895.48 |  1 | 7B.Q6_K    |         94 |
|    OOM |      OOM |    OOM |      OOM |    OOM |    OOM |  0 | 33B.Q4_K_S |       1196 |
| 14.452 |   566.86 | 14.465 |    17.70 | 28.916 | 292.15 |  1 | 33B.Q4_K_S |        123 |
|  3.862 |  2121.45 |  5.773 |    44.34 |  9.635 | 876.83 |  0 | 8B.Q6_K    |        692 |
|  3.822 |  2143.62 |  5.648 |    45.33 |  9.469 | 892.14 |  1 | 8B.Q6_K    |        267 |
|  7.315 |  1119.94 |  8.936 |    28.65 | 16.251 | 519.85 |  0 | 8x7B.IQ3_S |        692 |
|  6.860 |  1194.17 |  8.754 |    29.24 | 15.614 | 541.05 |  1 | 8x7B.IQ3_S |        150 |
| 13.400 |   611.34 | 15.499 |    16.52 | 28.899 | 292.33 |  0 | 32B.Q4_K_M |        860 |
| 13.189 |   621.10 | 14.526 |    17.62 | 27.715 | 304.81 |  1 | 32B.Q4_K_M |        307 |

buffer = ROCm0 compute buffer size

@jdecourval mentioned this pull request Apr 30, 2024
📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 555 iterations 🚀

Details (performance-related PRs only):
  • Concurrent users: 8, duration: 10m
  • HTTP request: avg=8469.29ms p(95)=20510.1ms fails=, finish reason: stop=490 truncated=65
  • Prompt processing (pp): avg=95.98tk/s p(95)=401.72tk/s
  • Token generation (tg): avg=32.95tk/s p(95)=48.94tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=fixflashattn2 commit=3e560c8665d4ea627be920a26da6d83811fde3b4

Charts: llamacpp:prompt_tokens_seconds, llamacpp:predicted_tokens_seconds, llamacpp:kv_cache_usage_ratio, llamacpp:requests_processing (llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 555 iterations)

@JohannesGaessler (Collaborator)

I didn't close that other PR by accident. As I said before, I don't think we should be adding a dependency on rocWMMA when the performance is no better than master and we have no dev to test and support it. And I will do an implementation of FlashAttention without any tensor cores at all, which may end up being faster anyway.

@sorasoras

I don't know how to get this to compile on Windows :(

@jdecourval (Contributor, Author)

> I didn't close that other PR by accident. As I said before, I don't think we should be adding a dependency on rocWMMA when the performance is no better than master and we have no dev to test and support it. And I will do an implementation of FlashAttention without any tensor cores at all, which may end up being faster anyway.

Sorry, I didn't realize it had been closed on purpose. Is the dependency that bad, though? rocWMMA is header-only, so there is no link-time requirement, and it enables sharing the existing CUDA code. The performance is not better, but the VRAM saving can be very significant, 1 GB in one case. The PR is not ready to merge as is anyway; I need to disable flash-attn in CMake by default for AMD GPUs, or enable it only if rocWMMA is detected as installed. I may not be a ROCm expert, but I am a C++ dev and I own a 7900 XTX, so if this isn't merged I might maintain this fork anyway.

Of course, if you already plan to work on that other implementation soon, this whole comment is irrelevant, but having access to a rocWMMA-based version as a comparison could still be useful. Please let me know what you think.
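
As a rough sketch of the "enable only if rocWMMA is detected" idea mentioned above: the kernel selection could be gated at compile time on the header being present. The macro names below (GGML_USE_HIPBLAS, FP16_MMA_AVAILABLE) are used illustratively and are not taken from this PR's actual build logic.

```cpp
// Hypothetical compile-time gate, not code from this PR: take the rocWMMA
// FlashAttention path only when the header-only dependency is available,
// otherwise fall back to the kernels that do not use matrix cores.
#if defined(GGML_USE_HIPBLAS) && __has_include(<rocwmma/rocwmma.hpp>)
#    include <rocwmma/rocwmma.hpp>
#    define FP16_MMA_AVAILABLE 1
#else
#    define FP16_MMA_AVAILABLE 0
#endif
```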

@sorasoras

> [quotes the exchange between @JohannesGaessler and @jdecourval above]

I wasn't able to test flash attention on Windows with the 7900 XTX yet.
I wonder if there is any difference in power consumption with and without fa.

@mofosyne added the labels enhancement (New feature or request) and Review Complexity: High (Generally requires in-depth knowledge of LLMs or GPUs) on May 9, 2024
@IMbackK commented Jun 1, 2024

So I can say that for CDNA this makes a big difference:
./bin/llama-bench -m $(MODEL) -fa 1 -p 4096

This PR:


  Device 0: AMD Instinct MI100, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl |         fa | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------: | ---------- | ---------------: |
| llama 8B Q4_K - Medium         |   4.07 GiB |     7.24 B | ROCm       |  99 |          1 | pp 512     |    675.25 ± 0.53 |
| llama 8B Q4_K - Medium         |   4.07 GiB |     7.24 B | ROCm       |  99 |          1 | tg 128     |     77.39 ± 0.21 |

Latest master:

  Device 0: AMD Instinct MI100, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl |         fa |          test |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------: | ------------: | ---------------: |
| llama 7B Q4_K - Medium         |   4.07 GiB |     7.24 B | ROCm       |  99 |          1 |        pp4096 |   450.68 ± 15.59 |
| llama 7B Q4_K - Medium         |   4.07 GiB |     7.24 B | ROCm       |  99 |          1 |         tg128 |     73.86 ± 0.39 |

Both of those are still terrible compared to exllama, but this PR does make a big difference in the right direction. The fact that this makes a bigger difference for CDNA is not a huge surprise: CDNA's MFMA is significantly faster than RDNA3's WMMA, and CDNA has low vector FLOPS compared to RDNA3.

Note that I had to make some trivial changes to this PR to make it choose the WMMA path for gfx908.
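
For context, the "trivial changes" described above presumably amount to widening whichever compile-time architecture check selects the WMMA path. A hypothetical sketch (the macro name and exact target list are illustrative, not taken from this PR):

```cpp
// Hypothetical sketch: let CDNA targets (gfx908/gfx90a) choose the rocWMMA
// FlashAttention kernels in addition to RDNA3 (gfx11xx). The __gfx*__ macros
// are the per-architecture defines set when compiling HIP device code.
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
    defined(__gfx908__)  || defined(__gfx90a__)
#    define FP16_MMA_AVAILABLE 1
#endif
```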

@IMbackK mentioned this pull request Jun 23, 2024
@IMbackK commented Jun 24, 2024

I'd like to mention here too that, after some optimization work on the GEMM kernels (#8082), this PR now improves pp performance on CDNA by almost 2x, and I really think the stance towards this PR needs to be revised. A tiny, optional, header-only dependency is surely worth a 2x, or even a 10%, increase in speed, and the fact that the CUDA equivalent dependency is fine while the ROCm equivalent is not speaks volumes, as does the comment on ROCm performance here: #7716.

@JohannesGaessler (Collaborator)

My original plan was to buy an AMD GPU with tensor cores so that I can test and maintain these changes myself (currently I only have an RX 6800). But I currently have issues finding a slot for it in one of my machines. However, if I can get a pledge from you that you will help with maintenance I would be fine with merging a PR like this.

Keep in mind, though, that the WMMA FlashAttention kernels that I wrote for CUDA are bad in the first place. They rely on the "high-level" WMMA functionality to use tensor cores, but after talking to an NVIDIA engineer and doing some related work myself, the better way to utilize tensor cores is via PTX instructions (the CUDA equivalent of assembly). So at some point I want to rewrite the code accordingly. Instead of rocWMMA, it would be much better to implement the equivalent AMD functionality in ggml-cuda/mma.cuh (though I suspect the memory access patterns I implemented are bad for AMD GPUs).
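
For readers unfamiliar with the distinction: the sketch below shows the "high-level" WMMA fragment API in its simplest form, a single warp computing one 16x16x16 half-precision tile product on tensor cores (illustrative only, not a kernel from llama.cpp; it assumes an sm_70+ GPU). The alternative discussed above is emitting the underlying mma PTX instructions directly, which exposes the register layout and data movement that the fragment API hides. rocWMMA provides a compatible fragment API for AMD GPUs, which is what allows this PR to share the existing CUDA code.

```cpp
// Minimal WMMA example (illustrative, not llama.cpp code): one warp computes
// C = A * B for a single 16x16x16 tile, A/B in half precision, C in float.
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

__global__ void wmma_tile(const half *A, const half *B, float *C) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);
    wmma::load_matrix_sync(a_frag, A, 16);           // leading dimension = 16
    wmma::load_matrix_sync(b_frag, B, 16);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);  // c += a * b on tensor cores
    wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
}
```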

@IMbackK commented Jun 28, 2024

I can't accept maintainership of llama.cpp/HIP. I can promise to run regular testing (even automated, if desired) on CDNA.

Since the attention code is essentially the same between CUDA and HIP for now, I think accepting this would be fine nevertheless. In the future, if the CUDA attention code is to change, it may make more sense to completely separate the code paths and simply freeze the HIP attention code until someone steps up; I think that would be much better than the current state of affairs, where the HIP backend continuously regresses as changes to the CUDA backend are made with no regard for HIP performance or even testing for breakage.

The current state of affairs also strongly discourages any optimization effort on my part and others', as even if you do some work to optimize the HIP backend, and even if you manage to get that merged, the NVIDIA-centric churn in the common code base inevitably breaks performance again, usually shortly after.

Also note that gfx11's WMMA and gfx908/a/4x's MFMA are very different, with totally different hardware implementations and performance characteristics.

@JohannesGaessler (Collaborator)

> I can't accept maintainership of llama.cpp/HIP. I can promise to run regular testing (even automated, if desired) on CDNA.

> Since the attention code is essentially the same between CUDA and HIP for now, I think accepting this would be fine nevertheless.

> changes to the CUDA backend are made with no regard for HIP performance or even testing for breakage.

When I make changes to the CUDA code, I test it for correctness and performance using my RX 6800. My standard for numerical software is that correct results are the first and foremost priority. I very much do not want to have any broken code in my repositories. So if I cannot test or assert that the code produces correct results myself, and if I also cannot delegate this to anyone else, then I am simply not willing to merge the corresponding PR. The simple rocWMMA prototype that I did still required fixes from other people to work at all.

My current stance towards HIP performance is that I am willing to invest some effort in AMD support "within reason". When it comes to MMQ in particular, the performance depends very heavily on the exact data layout, and for good AMD performance you would have to completely rewrite the code anyway.

@jammm (Contributor) commented Oct 16, 2024

> My original plan was to buy an AMD GPU with tensor cores so that I can test and maintain these changes myself (currently I only have an RX 6800).

@JohannesGaessler RDNA3 doesn't have dedicated tensor cores like CDNA does, so you will not see the same 2x perf boost @IMbackK sees. This is happening because rocWMMA translates to MFMA instructions on CDNA archs, which run directly on their matrix (tensor) cores. On RDNA3, rocWMMA translates to WMMA instructions that run on the shader cores, which don't give as much of a perf boost as dedicated matrix cores. This is why you don't see much of a perf boost from rocWMMA on RDNA3.

I would highly recommend this PR gets merged for the 2x perf boost on CDNA alone; otherwise you are not using those matrix cores at all. You might remember me recommending rocWMMA a while ago in #4801 (comment).

Now, I realize that you are going to deprecate this soon. But let's not leave this 2x perf on the table; keep this code even if it breaks later. I wish I had an AMD GPU locally to maintain this, but alas, I don't have one at the moment since I left AMD...

@jammm (Contributor) commented Oct 16, 2024

@JohannesGaessler what's your email address? I reached out to AMD to see if someone can lend a hand with maintenance. If possible, please share your email so they can reach out to you to see if they can support this.

@JohannesGaessler (Collaborator)

My email address can be found on my GitHub profile. But as I said, as of right now my plan is to remove the WMMA-based implementation, and I don't want to invest the effort to maintain it long-term.

@Said-Akbar

@JohannesGaessler,

Is my understanding correct that AMD GPUs with no matrix cores (for example, I have 2x AMD MI60 - gfx906 with no matrix cores) can see good improvements in text generation (say >=1.5x) if you implement generic flash attention with no matrix core dependency?

On a similar note, is it possible to use a custom-compiled flash attention library that works with ROCm (e.g. AMD MI60) in llama.cpp? Someone on Reddit shared that they have successfully compiled a flash attention library for MI60.

Some benchmarks they shared for that compiled flash attention:


### causal=False, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 49.30 TFLOPs/s, bwd: 30.33 TFLOPs/s, fwd + bwd: 34.08 TFLOPs/s
Pytorch fwd: 5.30 TFLOPs/s, bwd: 7.77 TFLOPs/s, fwd + bwd: 6.86 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
 
 
### causal=False, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 64.35 TFLOPs/s, bwd: 36.21 TFLOPs/s, fwd + bwd: 41.38 TFLOPs/s
Pytorch fwd: 5.60 TFLOPs/s, bwd: 8.48 TFLOPs/s, fwd + bwd: 7.39 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
 
 
### causal=False, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 51.53 TFLOPs/s, bwd: 32.75 TFLOPs/s, fwd + bwd: 36.55 TFLOPs/s
Pytorch fwd: 4.71 TFLOPs/s, bwd: 4.76 TFLOPs/s, fwd + bwd: 4.74 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
...

### causal=False, headdim=128, batch_size=16, seqlen=1024 ###
Flash2 fwd: 70.61 TFLOPs/s, bwd: 17.20 TFLOPs/s, fwd + bwd: 21.95 TFLOPs/s
Pytorch fwd: 5.07 TFLOPs/s, bwd: 6.51 TFLOPs/s, fwd + bwd: 6.02 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

It seems there is almost a 10x improvement in the fwd pass and around a 4x speed-up in the bwd pass compared to PyTorch.

@JohannesGaessler (Collaborator)

> Is my understanding correct that AMD GPUs with no matrix cores (for example, I have 2x AMD MI60 - gfx906 with no matrix cores) can see good improvements in text generation (say >=1.5x) if you implement generic flash attention with no matrix core dependency?

There already are implementations that do this; the performance on AMD is just bad, especially for large batch sizes (i.e. prompt processing). CUDA has seen performance improvements comparable to what you're asking for, so my expectation would be that a proper ROCm implementation, instead of a HIP port of the CUDA code, would see a similar speedup.
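
For background (not part of the reply above): with or without matrix cores, every FlashAttention-style kernel computes the same streaming-softmax recurrence over key/value blocks; matrix cores only accelerate the QKᵀ and PV tile products inside it. With running maximum $m$, running normalizer $\ell$, and running output $o$, processing block $j$ for a query $q$ looks like:

$$
s_i^{(j)} = \frac{q \cdot k_i^{(j)}}{\sqrt{d}}, \qquad
m^{(j)} = \max\!\left(m^{(j-1)},\ \max_i s_i^{(j)}\right)
$$

$$
\ell^{(j)} = e^{\,m^{(j-1)} - m^{(j)}}\,\ell^{(j-1)} + \sum_i e^{\,s_i^{(j)} - m^{(j)}}, \qquad
o^{(j)} = e^{\,m^{(j-1)} - m^{(j)}}\,o^{(j-1)} + \sum_i e^{\,s_i^{(j)} - m^{(j)}}\, v_i^{(j)}
$$

with the final attention output $o^{(J)} / \ell^{(J)}$ after the last block $J$.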

> On a similar note, is it possible to use a custom-compiled flash attention

llama.cpp/GGML has no support whatsoever for external FlashAttention implementations.

@Headcrabed

Hello, any update on it? Is it possible for us to see it merged?

Labels: enhancement (New feature or request), Review Complexity: High (Generally requires in-depth knowledge of LLMs or GPUs)
Projects: None yet
8 participants