From 92f6d3b55726d7610af74c76fb3bae2173d54890 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Mon, 24 Jun 2024 16:53:32 +0200
Subject: [PATCH] symbolic trace supports inputs_embeds

---
 src/transformers/utils/fx.py  |  7 +++++++
 tests/test_modeling_common.py | 19 +++++++++++++++++--
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index c3687c035c5837..6a8d99d672cbd0 100755
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -995,6 +995,13 @@ def _generate_dummy_input(
         inputs_dict[input_name] = torch.zeros(
             *shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
         )
+    elif "inputs_embeds" in input_name:
+        batch_size = shape[0]
+        sequence_length = shape[1]
+
+        inputs_dict[input_name] = torch.zeros(
+            batch_size, sequence_length, model.config.hidden_size, dtype=torch.float, device=device
+        )
     elif "visual_feats" in input_name:
         inputs_dict[input_name] = torch.zeros(
             shape
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index f7f0db79a8d7ce..3d1e37ebba3ff8 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -1271,6 +1271,7 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
             "input_features",
             "input_ids",
             "input_values",
+            "inputs_embeds",
             "pixel_values",
             "token_type_ids",
             "visual_feats",
@@ -1327,16 +1328,30 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
                     (past_mask, inputs_to_test[1]["attention_mask"]), dim=1
                 )
 
+        if "inputs_embeds" in inspect.signature(model.forward).parameters:
+            inputs_to_test.append(
+                {
+                    "inputs_embeds": torch.rand(
+                        3, 5, model.config.hidden_size, dtype=torch.float, device=torch_device
+                    )
+                }
+            )
+
         for inps in inputs_to_test:
             filtered_inputs = {k: v for (k, v) in inps.items() if k in input_names}
-            input_names = list(filtered_inputs.keys())
+            input_names_to_trace = list(filtered_inputs.keys())
 
             if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
                 not hasattr(model.config, "problem_type") or model.config.problem_type is None
             ):
                 model.config.problem_type = "single_label_classification"
 
-            traced_model = symbolic_trace(model, input_names)
+            if "past_key_values" in input_names_to_trace:
+                model.config.use_cache = True
+            else:
+                model.config.use_cache = False
+
+            traced_model = symbolic_trace(model, input_names_to_trace)
 
             with torch.no_grad():
                 traced_output = traced_model(**filtered_inputs)
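
Not part of the patch itself: the snippet below is a minimal sketch of how the new inputs_embeds tracing path can be exercised once this change lands. The BertConfig sizes and tensor shapes are illustrative assumptions, not values taken from the patch.

# Minimal usage sketch (assumption: BERT stands in for any fx-supported model,
# and the small config values are chosen only to keep the example fast).
import torch
from transformers import BertConfig, BertModel
from transformers.utils.fx import symbolic_trace

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = BertModel(config).eval()

# Trace on inputs_embeds instead of input_ids; with this patch the tracer builds a
# dummy (batch_size, sequence_length, hidden_size) float tensor for this input.
traced = symbolic_trace(model, input_names=["inputs_embeds", "attention_mask"])

batch_size, seq_len = 2, 8
inputs_embeds = torch.rand(batch_size, seq_len, config.hidden_size)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

with torch.no_grad():
    out = traced(inputs_embeds=inputs_embeds, attention_mask=attention_mask)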