Skip to content

Commit

Permalink
Update tests/test_modeling_common.py
Browse files Browse the repository at this point in the history
Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
fxmarty and amyeroberts authored Jun 24, 2024
1 parent 11ffa68 commit 6956a93
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,10 +1346,7 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
):
model.config.problem_type = "single_label_classification"

if "past_key_values" in input_names_to_trace:
model.config.use_cache = True
else:
model.config.use_cache = False
model.config.use_cache = "past_key_values" in input_names_to_trace

traced_model = symbolic_trace(model, input_names_to_trace)

Expand Down

0 comments on commit 6956a93

Please sign in to comment.