diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 2356b9ec18b0d..22b6769ac3f23 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -29,7 +29,7 @@ def torch_moe(a, w1, w2, score, topk): topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) -@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("e", [8, 64]) diff --git a/vllm/envs.py b/vllm/envs.py index e8257535f1bf5..c624510c7ea1a 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -32,6 +32,7 @@ VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/" + VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_IMAGE_FETCH_TIMEOUT: int = 5 @@ -248,6 +249,8 @@ # Only used for XLA devices such as TPUs. "VLLM_XLA_CACHE_PATH": lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"), + "VLLM_FUSED_MOE_CHUNK_SIZE": + lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ecab77a8b6dfb..99a5c7d78a67e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -8,6 +8,7 @@ import triton import triton.language as tl +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -420,13 +421,12 @@ def fused_experts(hidden_states: torch.Tensor, torch.float32, torch.float16, torch.bfloat16 ] - M, _ = hidden_states.shape + num_tokens, _ = hidden_states.shape E, N, _ = w1.shape - - if M > 65536: - # https://github.com/vllm-project/vllm/issues/5938 - raise ValueError("MoE kernel does not support more than 65536 tokens, " - f"but got {M}") + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + M = min(num_tokens, CHUNK_SIZE) if override_config: config = override_config @@ -455,51 +455,74 @@ def fused_experts(hidden_states: torch.Tensor, device=hidden_states.device, dtype=hidden_states.dtype) - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config['BLOCK_SIZE_M'], E) compute_type = (tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16) - invoke_fused_moe_kernel(hidden_states, - w1, - intermediate_cache1, - a1_scale, - w1_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - topk_ids.shape[1], - config, - compute_type=compute_type, - use_fp8=use_fp8) - - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - - invoke_fused_moe_kernel(intermediate_cache2, - w2, - intermediate_cache3, - a2_scale, - w2_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=compute_type, - use_fp8=use_fp8) - if inplace: - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=hidden_states) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, + num_tokens)) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.shape + + if tokens_in_chunk == 0: + break + + if tokens_in_chunk < CHUNK_SIZE: + # will only happen in the last chunk + intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk] + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E)) + + invoke_fused_moe_kernel(curr_hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel(intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8) + + torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=out_hidden_states[begin_chunk_idx:end_chunk_idx]) + return out_hidden_states def fused_moe(