From 96a6e8bb3bca2818b629643b4bd6c9ab1706ad7e Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 27 Jun 2024 11:43:04 -0700 Subject: [PATCH] [CI/Build] Fix Args for `_get_logits_warper` in Sampler Test (#5922) --- tests/samplers/test_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 0aabde6aa8c5c..9572588ce6e53 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -587,7 +587,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): generation_config = GenerationConfig(top_k=top_k, top_p=top_p, do_sample=True) - warpers = generation_model._get_logits_warper(generation_config) + warpers = generation_model._get_logits_warper(generation_config, device) assert len(warpers) == 2 # top_p and top_k seq_group_metadata_list: List[SequenceGroupMetadata] = []