Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
gante committed Dec 12, 2023
1 parent 9b51da1 commit 18a4eda
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3139,8 +3139,8 @@ def test_model_kwarg_assisted_decoding_encoder_decoder(self):
# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg that distorts the output
class FakeBart(BartForConditionalGeneration):
def forward(self, input_ids, foo=False, **kwargs):
outs = super().forward(input_ids, **kwargs)
def forward(self, input_ids, past_key_values, foo=False, **kwargs):
outs = super().forward(input_ids, past_key_values=past_key_values, **kwargs)
if foo:
outs["logits"][:, :, :] = 0.0
return outs
Expand Down Expand Up @@ -3170,9 +3170,9 @@ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None,
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())

# Assistant model
assistant = AutoModelForSeq2SeqLM.from_pretrained(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).to(torch_device)
assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
torch_device
)

# If assisted generation passes model_kwargs correctly, should be same as previous
outputs_assisted = model.generate(
Expand Down

0 comments on commit 18a4eda

Please sign in to comment.