address review comments
fxmarty committed Jul 11, 2024
1 parent da7068c · commit bbff1a8
Showing 2 changed files with 29 additions and 31 deletions.
src/transformers/utils/fx.py: 4 additions, 1 deletion
@@ -998,7 +998,10 @@ def _generate_dummy_input(
         elif "inputs_embeds" in input_name:
             batch_size = shape[0]
 
-            if getattr(model.config, "embedding_size", None) is not None:
+            if (
+                getattr(model.config, "embedding_size", None) is not None
+                and model.config.model_type != "megatron-bert"
+            ):
                 embedding_size = model.config.embedding_size
             else:
                 embedding_size = model.config.hidden_size
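For context, the patched branch in _generate_dummy_input only falls back to config.embedding_size when the model is not megatron-bert; the diff implies that megatron-bert's config exposes an embedding_size attribute even though its inputs_embeds must be hidden_size wide. Below is a minimal, hypothetical sketch of that selection logic in isolation; the helper name and standalone form are illustrative and not part of the commit.

import torch

def _dummy_inputs_embeds(config, batch_size, sequence_length):
    # Illustrative helper, not from the commit: choose the feature width the
    # same way the patched _generate_dummy_input does. Configs that define an
    # embedding_size distinct from hidden_size (ELECTRA-style) expect
    # inputs_embeds of width embedding_size; megatron-bert is special-cased
    # because, per the patch, it should still receive hidden_size here.
    if getattr(config, "embedding_size", None) is not None and config.model_type != "megatron-bert":
        dim = config.embedding_size
    else:
        dim = config.hidden_size
    return torch.rand(batch_size, sequence_length, dim, dtype=torch.float)

With an ELECTRA-small style config (embedding_size=128, hidden_size=256) this returns a (batch, seq, 128) tensor; with a megatron-bert config it returns (batch, seq, hidden_size).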
tests/test_modeling_common.py: 25 additions, 30 deletions
@@ -1215,38 +1215,33 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
                     (past_mask, inputs_to_test[1]["attention_mask"]), dim=1
                 )
 
-            if (
-                "input_ids" in inspect.signature(model.forward).parameters
-                and "inputs_embeds" in inspect.signature(model.forward).parameters
-                and not model.config.is_encoder_decoder
-            ):
-                inps = copy.deepcopy(inputs_to_test[0])
-
-                embedding_size = (
-                    model.config.embedding_size
-                    if getattr(model.config, "embedding_size", None) is not None
-                    and model.config.model_type != "megatron-bert"
-                    else model.config.hidden_size
-                )
+            forward_parameters = inspect.signature(model.forward).parameters
+            if "input_ids" in forward_parameters and "inputs_embeds" in forward_parameters:
+                inps = copy.deepcopy(inputs_to_test[0])
+
+                embedding_size = (
+                    model.config.embedding_size
+                    if getattr(model.config, "embedding_size", None) is not None
+                    and model.config.model_type != "megatron-bert"
+                    else model.config.hidden_size
+                )
 
-                if (
-                    model.config.model_type in MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
-                    and model.__class__.__name__ == MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES[model.config.model_type]
-                ):
-                    batch_size = inputs[next(iter(inputs))].shape[0]
-                    num_choices = inputs[next(iter(inputs))].shape[1]
-                    sequence_length = inputs[next(iter(inputs))].shape[2]
-                    shape = (batch_size, num_choices, sequence_length, embedding_size)
-                elif inps["input_ids"].ndim == 2:
-                    batch_size = inputs[next(iter(inputs))].shape[0]
-                    sequence_length = inputs[next(iter(inputs))].shape[1]
-                    shape = (batch_size, sequence_length, embedding_size)
-                else:
-                    self.skipTest("Unknown case")
+                if (
+                    model.config.model_type in MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
+                    and model.__class__.__name__
+                    == MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES[model.config.model_type]
+                ):
+                    batch_size, num_choices, sequence_length = inputs["input_ids"].shape
+                    shape = (batch_size, num_choices, sequence_length, embedding_size)
+                elif inps["input_ids"].ndim == 2:
+                    batch_size, sequence_length = inputs["input_ids"].shape
+                    shape = (batch_size, sequence_length, embedding_size)
+                else:
+                    self.skipTest("Unknown case")
 
-                del inps["input_ids"]
-                inps["inputs_embeds"] = torch.rand(shape, dtype=torch.float, device=torch_device)
-                inputs_to_test.append(inps)
+                del inps["input_ids"]
+                inps["inputs_embeds"] = torch.rand(shape, dtype=torch.float, device=torch_device)
+                inputs_to_test.append(inps)
 
             for inps in inputs_to_test:
                 filtered_inputs = {k: v for (k, v) in inps.items() if k in input_names}
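The test-side refactor reads the dummy shape straight from inputs["input_ids"] via tuple unpacking instead of indexing whichever tensor happens to come first in the inputs dict. Below is a standalone sketch of that shape logic, assuming inputs contains an "input_ids" tensor of rank 3 for multiple-choice models and rank 2 otherwise; the function name and flag are illustrative, not part of the test suite.

import torch

def _make_dummy_inputs_embeds(inputs, embedding_size, is_multiple_choice):
    # Illustrative only: mirrors the shape selection in the updated test.
    input_ids = inputs["input_ids"]
    if is_multiple_choice and input_ids.ndim == 3:
        # Multiple-choice heads take (batch, num_choices, seq_len) token ids.
        batch_size, num_choices, sequence_length = input_ids.shape
        shape = (batch_size, num_choices, sequence_length, embedding_size)
    elif input_ids.ndim == 2:
        batch_size, sequence_length = input_ids.shape
        shape = (batch_size, sequence_length, embedding_size)
    else:
        raise ValueError(f"Unexpected input_ids rank: {input_ids.ndim}")
    return torch.rand(shape, dtype=torch.float)

For example, with input_ids of shape (2, 7) and embedding_size 32 this yields a (2, 7, 32) tensor, matching what the test assigns to inps["inputs_embeds"].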
