Skip to content

Commit

Permalink
fix test?
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Jun 24, 2024
1 parent 92f6d3b commit 11ffa68
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/transformers/utils/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,7 @@ def _generate_dummy_input(
)
elif "inputs_embeds" in input_name:
batch_size = shape[0]
sequence_length = shape[1]
sequence_length = shape[-1]

inputs_dict[input_name] = torch.zeros(
batch_size, sequence_length, model.config.hidden_size, dtype=torch.float, device=device
Expand Down
2 changes: 1 addition & 1 deletion tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
inputs_to_test.append(
{
"inputs_embeds": torch.rand(
3, 5, model.config.hidden_size, dtype=torch.float, device=torch_device
2, 2, model.config.hidden_size, dtype=torch.float, device=torch_device
)
}
)
Expand Down

0 comments on commit 11ffa68

Please sign in to comment.