
[Question] How to match Flash Attention 2 performance? #98

Open
vedantroy opened this issue Aug 14, 2024 · 4 comments

@vedantroy
I wrote a helper that lets you use cuDNN attention seamlessly from within PyTorch.

import cudnn
import torch
import math

# export CUDNN_FRONTEND_LOG_FILE=fe.log
# export CUDNN_FRONTEND_LOG_INFO=1

# import os
# os.environ["CUDNN_FRONTEND_LOG_FILE"] = "fe.log"
# os.environ["CUDNN_FRONTEND_LOG_INFO"] = "1"

def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")

def make_cudnn_autograd(*, num_heads, head_dim, dtype):
    assert dtype in [torch.float16, torch.bfloat16], f"Invalid dtype {dtype}"
    dtype = convert_to_cudnn_type(dtype)
    # match CuDNN's docs
    H, D = num_heads, head_dim
    del num_heads, head_dim

    cache = {}

    def init_or_check_tensor_attrs(tensor_name, tensor):
        nonlocal cache
        for attr in ['shape', 'stride', 'dtype', 'device']:
            key = f'{tensor_name}_{attr}'
            if key not in cache:
                cache[key] = getattr(tensor, attr)
                if callable(cache[key]):
                    cache[key] = cache[key]()
            else:
                v = cache[key]() if callable(cache[key]) else cache[key]
                assert cache[key] == v, f"Expected {cache[key]} but got {v}"


    class CuDNNAttention(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, N, L, q, kv, seqlens_kv):
            assert q.shape == (B, N, H, D)
            assert kv.shape == (B, N + L, 2, H, D)
            assert seqlens_kv.shape == (B,)

            # CuDNN plans are compiled for a specific shape, stride, dtype
            # So we need to verify those attributes
            init_or_check_tensor_attrs('q', q)
            init_or_check_tensor_attrs('kv', kv)
            init_or_check_tensor_attrs('seqlens_kv', seqlens_kv)

            q = q.permute(0, 2, 1, 3)  # B N H D -> B H N D
            kv_view = kv.permute(2, 0, 3, 1, 4) # B S KV H D -> KV B H S D
            k_view, v_view = torch.unbind(kv_view, dim=0)

            assert not k_view.is_contiguous() and not v_view.is_contiguous(), f"kv should not be contiguous (unnecessary copy)"
            assert k_view.shape == (B, H, (N + L),  D), f"Got shape {k_view.shape} instead of {(B, num_heads, (N + L),  D)}"
            assert v_view.shape == (B, H, (N + L), D)

            # TODO: Is this safe?
            if 'stats' not in cache:
                cache['stats'] = torch.empty(B, H, N, 1, dtype=torch.float32, device=q.device)
                cache['seqlens_q'] = torch.tensor([N] * B, device=q.device, dtype=torch.int32).view(B, 1, 1, 1)
                cache['o'] = torch.empty_like(q)

            stats = cache['stats']
            seqlens_q = cache['seqlens_q']
            o = cache['o']

            seqlens_kv = seqlens_kv.view(B, 1, 1, 1)

            if 'compiled_graph_fwd' not in cache:
                print("Compiling CuDNN graphs ...")
                g_fwd = cudnn.pygraph(
                    io_data_type=dtype,
                    intermediate_data_type=cudnn.data_type.FLOAT,
                    compute_data_type=cudnn.data_type.FLOAT,
                )
                cache['name_to_cu_tensor'] = {
                    'q_cu': g_fwd.tensor_like(q.detach()),
                    'k_cu': g_fwd.tensor_like(k_view.detach()),
                    'v_cu': g_fwd.tensor_like(v_view.detach()),
                    'seqlens_q_cu': g_fwd.tensor_like(seqlens_q.detach()),
                    'seqlens_kv_cu': g_fwd.tensor_like(seqlens_kv.detach())
                }
                cu_tens = cache['name_to_cu_tensor']

                o_forward, stats_forward = g_fwd.sdpa(
                    name="sdpa",
                    q=cu_tens['q_cu'],
                    k=cu_tens['k_cu'],
                    v=cu_tens['v_cu'],
                    is_inference=False,
                    attn_scale=1.0 / math.sqrt(D),
                    use_causal_mask=False,
                    use_padding_mask=True,
                    seq_len_q=cu_tens['seqlens_q_cu'],
                    seq_len_kv=cu_tens['seqlens_kv_cu']
                )

                o_forward.set_output(True).set_dim(o.shape).set_stride(o.stride()).set_data_type(dtype)
                stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT).set_dim(stats.shape).set_stride(stats.stride())

                cu_tens['o_forward_cu'] = o_forward
                cu_tens['stats_forward_cu'] = stats_forward

                def assert_cudnn_shape(tensor, expected_shape):
                    assert tuple(tensor.get_dim()) == expected_shape, f"Expected shape {expected_shape} but got {tensor.get_dim()}"

                assert_cudnn_shape(cu_tens['q_cu'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['k_cu'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['v_cu'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['o_forward_cu'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['stats_forward_cu'], (B, H, N, 1))
                assert_cudnn_shape(cu_tens['seqlens_q_cu'], (B, 1, 1, 1))
                assert_cudnn_shape(cu_tens['seqlens_kv_cu'], (B, 1, 1, 1))

                g_fwd.validate()
                g_fwd.build_operation_graph()
                g_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
                g_fwd.check_support()
                g_fwd.build_plans()

                cache['compiled_graph_fwd'] = g_fwd

                g_bwd = cudnn.pygraph(
                     io_data_type=dtype,
                     intermediate_data_type=cudnn.data_type.FLOAT,
                     compute_data_type=cudnn.data_type.FLOAT,
                )

                cu_tens['q_cu_bwd'] = g_bwd.tensor_like(q.detach())
                cu_tens['k_cu_bwd'] = g_bwd.tensor_like(k_view.detach())
                cu_tens['v_cu_bwd'] = g_bwd.tensor_like(v_view.detach())
                cu_tens['o_cu_bwd'] = g_bwd.tensor_like(o.detach())
                cu_tens['dO_cu_bwd'] = g_bwd.tensor_like(o.detach())
                cu_tens['stats_cu_bwd'] = g_bwd.tensor_like(stats.detach())
                cu_tens['seqlens_q_cu_bwd'] = g_bwd.tensor_like(seqlens_q.detach())
                cu_tens['seqlens_kv_cu_bwd'] = g_bwd.tensor_like(seqlens_kv.detach())

                dQ_bwd_cu, dK_bwd_cu, dV_bwd_cu = g_bwd.sdpa_backward(
                    name="sdpa_backward",
                    q=cu_tens['q_cu_bwd'],
                    k=cu_tens['k_cu_bwd'],
                    v=cu_tens['v_cu_bwd'],
                    o=cu_tens['o_cu_bwd'],
                    dO=cu_tens['dO_cu_bwd'],
                    stats=cu_tens['stats_cu_bwd'],
                    attn_scale=1.0 / math.sqrt(D),
                    use_causal_mask=False,
                    use_padding_mask=True,
                    seq_len_q=cu_tens['seqlens_q_cu_bwd'],
                    seq_len_kv=cu_tens['seqlens_kv_cu_bwd']
                )

                # TODO: Is this safe?
                # cache['dQ'] = torch.empty_like(q).contiguous()
                # cache['dK'] = torch.empty_like(k_view).contiguous()
                # cache['dV'] = torch.empty_like(v_view).contiguous()

                cache['dQ'] = torch.empty_like(q)
                cache['dK'] = torch.empty_like(k_view)
                cache['dV'] = torch.empty_like(v_view)

                dQ_bwd_cu.set_output(True).set_dim(cache['dQ'].size()).set_stride(cache['dQ'].stride())
                dK_bwd_cu.set_output(True).set_dim(cache['dK'].size()).set_stride(cache['dK'].stride())
                dV_bwd_cu.set_output(True).set_dim(cache['dV'].size()).set_stride(cache['dV'].stride())

                cu_tens['dQ_cu_bwd'] = dQ_bwd_cu
                cu_tens['dK_cu_bwd'] = dK_bwd_cu
                cu_tens['dV_cu_bwd'] = dV_bwd_cu

                assert_cudnn_shape(cu_tens['q_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['k_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['v_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['dQ_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['dK_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['dV_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['o_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['dO_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['stats_cu_bwd'], (B, H, N, 1))
                assert_cudnn_shape(cu_tens['seqlens_q_cu_bwd'], (B, 1, 1, 1))
                assert_cudnn_shape(cu_tens['seqlens_kv_cu_bwd'], (B, 1, 1, 1))

                g_bwd.validate()
                g_bwd.build_operation_graph()
                g_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
                g_bwd.check_support()
                g_bwd.build_plans()

                cache['compiled_graph_bwd'] = g_bwd

                # TODO: Is this safe?
                cache['workspace'] = torch.empty(
                    max(g_fwd.get_workspace_size(), g_bwd.get_workspace_size()),
                    device=q.device, dtype=torch.uint8
                )
            
            name_to_cu_tensor = cache['name_to_cu_tensor']
            variant_pack_forward = {
                name_to_cu_tensor[name]: tensor for name, tensor in [
                    ('q_cu', q),
                    ('k_cu', k_view),
                    ('v_cu', v_view),
                    ('o_forward_cu', o),
                    ('stats_forward_cu', stats),
                    ('seqlens_q_cu', seqlens_q),
                    ('seqlens_kv_cu', seqlens_kv)
                ]
            }
            cache['compiled_graph_fwd'].execute(variant_pack_forward, cache['workspace'])
            ctx.save_for_backward(q, k_view, v_view, o, stats, seqlens_kv)
            ctx.B, ctx.N, ctx.L = B, N, L
            ctx.dtype = dtype
            return o

        @staticmethod
        def backward(ctx, grad_output):
            q, k, v, o, stats, seqlens = ctx.saved_tensors
            B, N, L = ctx.B, ctx.N, ctx.L
            seqlens_q = cache['seqlens_q']

            cu_tens = cache['name_to_cu_tensor']

            assert tuple(grad_output.shape) ==  (B, H, N, D)
            assert tuple(grad_output.shape) == tuple(cu_tens['dO_cu_bwd'].get_dim())
            # For batch size 1, the stride can have two 1s; I think this is a PyTorch bug
            # https://discuss.pytorch.org/t/stride-has-2-1s-in-it/208036
            assert tuple(grad_output.stride())[1:] == tuple(cu_tens['dO_cu_bwd'].get_stride())[1:], f"{tuple(cu_tens['dO_cu_bwd'].get_stride())} (expected) != {tuple(grad_output.stride())} (actual) for shape {tuple(grad_output.shape)}"
            assert convert_to_cudnn_type(grad_output.dtype) == cu_tens['dO_cu_bwd'].get_data_type()

            variant_pack_backward = {
                cu_tens[name]: tensor for name, tensor in [
                    ('dQ_cu_bwd', cache['dQ']),
                    ('dK_cu_bwd', cache['dK']),
                    ('dV_cu_bwd', cache['dV']),
                    ('q_cu_bwd', q),
                    ('k_cu_bwd', k),
                    ('v_cu_bwd', v),
                    ('o_cu_bwd', o),
                    ('dO_cu_bwd', grad_output),
                    ('stats_cu_bwd', stats),
                    ('seqlens_q_cu_bwd', seqlens_q),
                    ('seqlens_kv_cu_bwd', seqlens)
                ]
            }

            cache['compiled_graph_bwd'].execute(variant_pack_backward, cache['workspace'])
            assert cache['dQ'].shape == (B, H, N, D)
            dQ = cache['dQ'].permute(0, 2, 1, 3) # B H N D -> B N H D

            assert cache['dK'].shape == (B, H, N + L, D)
            assert cache['dV'].shape == (B, H, N + L, D)

            dKV = torch.stack([cache['dK'], cache['dV']], dim=2)
            assert dKV.shape == (B, H, 2, N + L, D)

            dKV = dKV.permute(0, 3, 2, 1, 4) # B H 2 N D -> B N 2 H D

            return None, None, None, dQ, dKV, None

    return CuDNNAttention
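
For reference, this is roughly how I invoke the wrapper (a minimal sketch with made-up sizes, not the benchmark itself; the output follows the (B, H, N, D) layout that forward() allocates):

# Hypothetical usage sketch (sizes are made up; the call signature follows the
# asserts in forward() above).
B, N, L, H, D = 2, 1024, 128, 16, 64
dtype = torch.bfloat16

CuDNNAttention = make_cudnn_autograd(num_heads=H, head_dim=D, dtype=dtype)

q = torch.randn(B, N, H, D, device="cuda", dtype=dtype, requires_grad=True)
kv = torch.randn(B, N + L, 2, H, D, device="cuda", dtype=dtype, requires_grad=True)
seqlens_kv = torch.full((B,), N + L, device="cuda", dtype=torch.int32)

out = CuDNNAttention.apply(B, N, L, q, kv, seqlens_kv)  # output is (B, H, N, D)
out.backward(torch.ones_like(out))                      # exercises the backward graph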

However, while this gets better forward-pass performance, it gets far worse backward-pass performance. Any thoughts on why this might be the case? I'm hoping there's some obvious deficiency in my code.

(Unit is ms).

attention-forward-performance:
   batch_size     CuDNN  FlashAttention
0         1.0  0.022976        0.033024
1         2.0  0.021664        0.039456
2         4.0  0.047680        0.058112
3         6.0  0.056800        0.072208

attention-backward-performance:
   batch_size     CuDNN  FlashAttention
0         2.0  0.386144        0.282272
1         4.0  0.741664        0.301184
2         6.0  1.108608        0.464320
@Anerudhan
Collaborator

  • What card is this? The nvidia-smi output would be great.
  • What cuDNN version are you using?
    python -c 'import cudnn; print(cudnn.backend_version_string())'

@vedantroy
Author

cuDNN version: 9.1.0.

nvidia-smi:

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000002:00:01.0 Off |                    0 |
| N/A   31C    P0             116W / 700W |    885MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000002:00:02.0 Off |                    0 |
| N/A   30C    P0              72W / 700W |      3MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000002:00:03.0 Off |                    0 |
| N/A   29C    P0              70W / 700W |      3MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000002:00:04.0 Off |                    0 |
| N/A   27C    P0              68W / 700W |      3MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000003:00:01.0 Off |                    0 |
| N/A   28C    P0              68W / 700W |      3MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000003:00:02.0 Off |                    0 |
| N/A   29C    P0              73W / 700W |      3MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000003:00:03.0 Off |                    0 |
| N/A   30C    P0              70W / 700W |      3MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000003:00:04.0 Off |                    0 |
| N/A   27C    P0              69W / 700W |      3MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A    239227      C   /tmp/env/bin/python                         876MiB |
+---------------------------------------------------------------------------------------+

@vedantroy
Author

Archive.zip

I've attached both the benchmarking and cuDNN wrapper code to this post. I suspect the benchmarking code is off, so I'll switch to something simpler (like the PyTorch profiler) and see what the results are.
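
For what it's worth, by "something simpler" I mean roughly the following CUDA-event timing loop (just a sketch, not the attached benchmark):

import torch

# Rough timing sketch using CUDA events with warmup.
def time_ms(fn, warmup=10, iters=100):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call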

@Anerudhan
Collaborator

You can try improvising on this: install the FAv2 pip package inside the container and go from there. Try out the latest container (24.07, just in case).
