From f240b103ce58f0f8f0b81c9743c995ba63727428 Mon Sep 17 00:00:00 2001
From: Varun Sundar Rabindranath
Date: Mon, 16 Sep 2024 04:52:24 +0000
Subject: [PATCH] add tests

---
 .../multi_step/test_correctness_async_llm.py | 24 +++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py
index a75a671e57f74..78a6916461de6 100644
--- a/tests/multi_step/test_correctness_async_llm.py
+++ b/tests/multi_step/test_correctness_async_llm.py
@@ -25,6 +25,19 @@
     "16",
 ]
 
+def skip_test(is_chunked_prefill: bool,
+              tp_size: int,
+              pp_size: int,
+              attn_backend: str) -> bool:
+    if not is_chunked_prefill:
+        return False
+
+    if tp_size == 1 and \
+       pp_size == 1 and \
+       attn_backend == "FLASH_ATTN":
+        return False
+
+    return True
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize(("tp_size, pp_size"), [
@@ -37,6 +50,7 @@
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("is_async", [True])
 @pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
+@pytest.mark.parametrize("with_chunked_prefill", [True, False])
 @pytest.mark.asyncio
 async def test_multi_step(
     example_prompts,
@@ -49,6 +63,7 @@ async def test_multi_step(
     is_async: bool,
     num_logprobs: Optional[int],
     attention_backend: str,
+    with_chunked_prefill: bool,
     monkeypatch,
 ) -> None:
     """Test vLLM engine with multi-step scheduling in an OpenAI-protocol
@@ -75,6 +90,12 @@
         completions endpoint; `None` -> no logprobs
     """
 
+    if skip_test(is_chunked_prefill=with_chunked_prefill,
+                 tp_size=tp_size,
+                 pp_size=pp_size,
+                 attn_backend=attention_backend):
+        return
+
     override_backend_env_variable(monkeypatch, attention_backend)
 
     prompts = example_prompts
@@ -93,6 +114,9 @@
     if eager_mode:
         ms_server_args.append("--enforce-eager")
 
+    if with_chunked_prefill:
+        ms_server_args.append("--enable-chunked-prefill")
+
     distributed_args = [
         "--tensor-parallel-size",
         str(tp_size),
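
For reference, below is a minimal, self-contained sketch of the skip pattern this
patch introduces: a predicate that filters the parametrized combinations so that
chunked prefill is only exercised where the patch supports it (tp_size == 1,
pp_size == 1, FLASH_ATTN backend). The test name and the parameter values here
are hypothetical stand-ins, not the real vLLM test, and the sketch calls
pytest.skip() so filtered combinations are reported as skips, a variation on the
patch's bare "return", which would report them as vacuous passes.

import pytest


def skip_test(is_chunked_prefill: bool, tp_size: int, pp_size: int,
              attn_backend: str) -> bool:
    # Mirrors the patch's predicate: chunked prefill is only exercised with
    # tp_size == pp_size == 1 on the FLASH_ATTN backend; every other
    # chunked-prefill combination is filtered out.
    if not is_chunked_prefill:
        return False
    if tp_size == 1 and pp_size == 1 and attn_backend == "FLASH_ATTN":
        return False
    return True


@pytest.mark.parametrize("tp_size, pp_size", [(1, 1), (2, 2)])
@pytest.mark.parametrize("attn_backend", ["FLASHINFER", "FLASH_ATTN"])
@pytest.mark.parametrize("with_chunked_prefill", [True, False])
def test_combination_filter(tp_size: int, pp_size: int, attn_backend: str,
                            with_chunked_prefill: bool) -> None:
    # Hypothetical test body: only the combination filtering is shown; the
    # real test launches async and sync vLLM servers and compares completions.
    if skip_test(with_chunked_prefill, tp_size, pp_size, attn_backend):
        pytest.skip("multi-step + chunked prefill unsupported for this combo")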