Skip to content

Commit

Permalink
[Bugfix] Added test for sampling repetition penalty bug. (vllm-project#5659)
Browse files Browse the repository at this point in the history

Signed-off-by: Thomas Parnell <[email protected]>
  • Loading branch information
tdoublep authored and jimpang committed Jul 24, 2024
1 parent 03b6a8b commit cdffcbb
Showing 1 changed file with 69 additions and 0 deletions.
69 changes: 69 additions & 0 deletions tests/samplers/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,3 +631,72 @@ def mock_sample(probs, *args, **kwargs):
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):
    """Check that per-request results do not depend on batch position.

    Two requests share one batch: one greedy request with a repetition
    penalty and one seeded sampling request without it. The batch is run
    twice with the two requests swapped, and each configuration must
    produce the same token in both runs.
    """
    vocab_size = 8
    batch_size = 2

    def _sample_batch(sampling_params: List[SamplingParams]):
        """Run one sampler pass and return the first sampled token per request."""
        # Every request uses the same three-token prompt; only the sampling
        # parameters differ between batch slots.
        seq_group_metadata_list: List[SequenceGroupMetadata] = [
            SequenceGroupMetadata(
                request_id=f"test_{idx}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params[idx],
                block_tables={0: [1]},
            ) for idx in range(batch_size)
        ]
        seq_lens: List[int] = [
            group.seq_data[0].get_len() for group in seq_group_metadata_list
        ]

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available())

        # Near-uniform logits: token 1 is the largest and token 5 the second
        # largest. Token 1 is also part of the prompt [1, 2, 3], so
        # presumably it is the token the repetition penalty acts on.
        fake_logits = torch.full((batch_size, vocab_size),
                                 1e-2,
                                 device=device,
                                 dtype=torch.float16)
        fake_logits[:, 5] = 1.1e-2
        fake_logits[:, 1] = 1.2e-2

        sampler = MockLogitsSampler(fake_logits)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        return [group.samples[0].output_token for group in sampler_output]

    # Greedy decoding with a repetition penalty.
    greedy_with_penalty = SamplingParams(
        temperature=0.0,
        repetition_penalty=2.0,
    )

    # Seeded sampling without a repetition penalty.
    seeded_without_penalty = SamplingParams(
        temperature=1.0,
        top_k=1,
        seed=42,
    )

    penalty_first = _sample_batch(
        [greedy_with_penalty, seeded_without_penalty])
    penalty_second = _sample_batch(
        [seeded_without_penalty, greedy_with_penalty])

    # Swapping the batch order must swap the outputs and nothing else.
    assert penalty_first[0] == penalty_second[1]
    assert penalty_first[1] == penalty_second[0]

0 comments on commit cdffcbb

Please sign in to comment.