
Optimized fused MoE Kernel #2913

Closed
wants to merge 36 commits

Conversation

Collaborator

@pcmoritz pcmoritz commented Feb 19, 2024

This PR is based on @WoosukKwon's excellent work porting the TensorRT MoE kernels in https://github.com/vllm-project/vllm/tree/cutlass-moe.

It is based on the observation that the TensorRT MoE kernels work very well in the small-batch-size regime, whereas the fused MoE kernel works much better in the large-batch-size regime. I have been trying to optimize the Triton kernels in the small-batch-size regime too, but unfortunately Triton doesn't seem to have great support for matrix multiplications involving skinny matrices (e.g., tl.dot only supports dimensions >= 16). Therefore, we use the TensorRT kernel in the small-batch-size regime and the fused MoE kernel in the large-batch-size regime. It would be much preferable to have one unified kernel for all regimes, so if anybody knows how to make that happen, I'd love to know.
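The dispatch described above can be sketched roughly as follows. This is an illustrative outline only: the threshold, function names, and signatures are assumptions for the sketch, not vLLM's actual API.

```python
# Illustrative sketch: route small batches to the TensorRT-style kernel and
# larger batches to the fused Triton kernel. The cutoff is an assumption,
# motivated by tl.dot requiring dimensions >= 16.
SMALL_BATCH_THRESHOLD = 16

def moe_forward(hidden_states, w1, w2, gating_output,
                small_batch_kernel, fused_kernel):
    # hidden_states is assumed to have shape (num_tokens, hidden_size).
    num_tokens = hidden_states.shape[0]
    if num_tokens < SMALL_BATCH_THRESHOLD:
        # Small-batch regime: the TensorRT-derived kernel wins here.
        return small_batch_kernel(hidden_states, w1, w2, gating_output)
    # Large-batch regime: the fused Triton MoE kernel wins here.
    return fused_kernel(hidden_states, w1, w2, gating_output)
```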

This PR also incorporates some of @cadedaniel 's work on autotuning the fused MoE kernel.

The benchmarks are as follows (all on H100 with TP2, using 1000 input and 50 output tokens):

This PR with below tuning configs:

qps = 1 => 16.9 ms ITL (0.85s end-to-end completion time per request)
qps = 2 => 19.0 ms ITL (0.95s end-to-end completion time per request)
qps = 4 => 32.7 ms ITL (1.63s end-to-end completion time per request)
qps = 6 => 43.4 ms ITL (2.16s end-to-end completion time per request)

current main branch (untuned fused MoE kernel):

qps = 1 => 23.3 ms ITL (1.17s end-to-end completion time per request)
qps = 2 => 25.4 ms ITL (1.27s end-to-end completion time per request)
qps = 4 => 43.0 ms ITL (2.15s end-to-end completion time per request)
qps = 6 => 60.8 ms ITL (3.04s end-to-end completion time per request)

only using the TensorRT MoE kernels:

qps = 1 => 18.1 ms ITL (0.90s end-to-end completion time per request)
qps = 2 => 23.8 ms ITL (1.19s end-to-end completion time per request)
qps = 4 => 48.1 ms ITL (2.36s end-to-end completion time per request)
qps = 6 => 90.8 ms ITL (4.54s end-to-end completion time per request)

You can run the autotuned kernel by setting

export VLLM_MIXTRAL_FUSE_MOE_CONFIG=/path/to/fused_moe_h100_tp2_config.json

where fused_moe_h100_tp2_config.json has the following contents:

{
    "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "256": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
    "512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1024": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "2048": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "4096": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}
}
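The config above maps batch sizes (as string keys) to Triton kernel launch parameters. A sketch of how such a file might be loaded and consulted is below; the lookup rule (smallest configured size that covers the current token count, falling back to the largest entry) is an assumption for illustration, not necessarily what this PR implements.

```python
import json

def load_moe_configs(path):
    # Keys in the JSON file are strings like "64"; convert them to ints
    # so they can be compared against a token count.
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}

def pick_config(configs, num_tokens):
    # Assumed lookup rule: use the smallest configured batch size that is
    # >= num_tokens; if num_tokens exceeds every key, fall back to the
    # largest configured entry.
    eligible = [size for size in sorted(configs) if size >= num_tokens]
    key = eligible[0] if eligible else max(configs)
    return configs[key]
```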

@WoosukKwon
Collaborator

Hi @pcmoritz, thanks for the amazing PR! Is this PR ready for review, or do you have any blockers on it?

@pcmoritz
Collaborator Author

pcmoritz commented Feb 20, 2024

I think we should merge your kernel https://github.com/vllm-project/vllm/tree/cutlass-moe as a separate PR first, and then we can merge this one. If you open the PR for the TensorRT kernels, I'm happy to review it! The thing I'm currently unsure about is whether we should have two different kernels for the two different regimes; that seems very unfortunate to me.

I'll look a little more into whether we can get more out of the Triton kernel in the low-batch-size regime and will keep you updated. Let's come to a conclusion before the end of this week and execute on it :)

Also I'm curious about your thoughts on this (stitching together two kernels).

@pcmoritz
Collaborator Author

Closed in favor of #2979
