
CUDA: fix FA out-of-bounds reads #7479

Merged

Conversation

JohannesGaessler (Collaborator)
After #7465 I noticed that the global reads in the FlashAttention kernels also have the problem of not checking for out-of-bounds access; this PR adds the necessary checks to avoid potential memory errors.

I also added a small performance optimization that skips the checks for ncols <= 2 because those particular kernels are only used for their exact batch sizes.

@github-actions github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels May 22, 2024
@JohannesGaessler JohannesGaessler merged commit cd93a28 into ggerganov:master May 22, 2024
58 of 69 checks passed
Contributor

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 537 iterations 🚀

Details (performance-related PRs only):
  • Concurrent users: 8, duration: 10m
  • HTTP request: avg=8678.36ms p(95)=21626.31ms fails=, finish reason: stop=479 truncated=58
  • Prompt processing (pp): avg=100.49tk/s p(95)=396.16tk/s
  • Token generation (tg): avg=35.67tk/s p(95)=46.23tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=cuda-fix-fa-oob2 commit=d76d1465e905988cbf5a41d0595e391bb8b804db

prompt_tokens_seconds

[Chart: llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 537 iterations; y-axis: llamacpp:prompt_tokens_seconds]
predicted_tokens_seconds

[Chart: llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 537 iterations; y-axis: llamacpp:predicted_tokens_seconds]

Details

kv_cache_usage_ratio

[Chart: llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 537 iterations; y-axis: llamacpp:kv_cache_usage_ratio]
requests_processing

[Chart: llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 537 iterations; y-axis: llamacpp:requests_processing]

@mofosyne mofosyne added the Review Complexity : Medium (Generally require more time to grok but manageable by beginner to medium expertise level) label May 23, 2024
teleprint-me pushed a commit to teleprint-me/llama.cpp that referenced this pull request May 23, 2024