Overhaul TF serving signatures + dummy inputs (#23234)
* Let's try autodetecting serving sigs

* Don't clobber existing sigs

* Change shapes for multiplechoice models

* Make default dummy inputs smarter too

* Fix missing f-string

* Let's YOLO a serving output too

* Read __class__.__name__ properly

* Don't just pass naked lists in there and expect it to be okay

* Code cleanup

* Update default serving sig

* Clearer error messages

* Further updates to the default serving output

* make fixup

* Update the serving output a bit more

* Cleanups and renames, raise errors appropriately when we can't infer inputs

* More renames

* we're building in a functional context again, yolo

* import DUMMY_INPUTS from the right place

* import DUMMY_INPUTS from the right place

* Support cross-attention in the dummies

* Support cross-attention in the dummies

* Complete removal of dummy/serving overrides in BERT

* Complete removal of dummy/serving overrides in RoBERTa

* Obliterate lots and lots of serving sig and dummy overrides

* merge type hint changes

* Fix for token_type_ids with vocab_size 1

* Add missing property decorator

* Fix T5 and hopefully some models that take conv inputs

* More signature pruning

* Fix T5's signature

* Fix Wav2Vec2 signature

* Fix LongformerForMultipleChoice input signature

* Fix BLIP and LED

* Better default serving output error handling

* Fix BART dummies

* Fix dummies for cross-attention, esp encoder-decoder models

* Fix visionencoderdecoder signature

* Fix BLIP serving output

* Small tweak to BART dummies

* Cleanup the ugly parameter inspection line that I used in a few places

* committed a breakpoint again

* Move the text_dims check

* Remove blip_text serving_output

* Add decoder_input_ids to the default input sig

* Remove all the manual overrides for encoder-decoder model signatures

* Tweak longformer/led input sigs

* Tweak default serving output

* output.keys() -> output

* make fixup
Rocketknight1 authored May 24, 2023
1 parent 3d7baef commit 814de8f
Showing 65 changed files with 275 additions and 4,144 deletions.
131 changes: 98 additions & 33 deletions src/transformers/modeling_tf_utils.py
@@ -42,7 +42,6 @@
from .generation import GenerationConfig, TFGenerationMixin
from .tf_utils import expand_1d, load_attributes_from_hdf5_group, save_attributes_to_hdf5_group, shape_list
from .utils import (
DUMMY_INPUTS,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
TF2_WEIGHTS_INDEX_NAME,
@@ -1114,9 +1113,25 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
return {
"input_ids": tf.constant(DUMMY_INPUTS, dtype=tf.int32),
}
dummies = {}
sig = self._prune_signature(self.input_signature)
for key, spec in sig.items():
# 3 is the most correct arbitrary size. I will not be taking questions
dummies[key] = tf.ones(shape=[dim if dim is not None else 3 for dim in spec.shape], dtype=spec.dtype)
if key == "token_type_ids":
# Some models have token_type_ids but with a vocab_size of 1
dummies[key] = tf.zeros_like(dummies[key])
if self.config.add_cross_attention and "encoder_hidden_states" in inspect.signature(self.call).parameters:
if "encoder_hidden_states" not in dummies:
if self.main_input_name == "input_ids":
dummies["encoder_hidden_states"] = tf.ones(
shape=(3, 3, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
)
else:
raise NotImplementedError(
"Model has cross-attention but we couldn't infer the shape for the encoder hidden states. Please manually override dummy_inputs!"
)
return dummies
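
For reference, a minimal standalone sketch of the new dummy-input generation (not the library code verbatim; the signature dict below is a stand-in for what `input_signature` would infer for a BERT-like text model):

```python
import tensorflow as tf

# Stand-in for a pruned input signature as inferred by `input_signature`
signature = {
    "input_ids": tf.TensorSpec([None, None], tf.int32, name="input_ids"),
    "attention_mask": tf.TensorSpec([None, None], tf.int32, name="attention_mask"),
    "token_type_ids": tf.TensorSpec([None, None], tf.int32, name="token_type_ids"),
}

dummies = {}
for key, spec in signature.items():
    # Unknown (None) dimensions are filled with the arbitrary size 3
    shape = [dim if dim is not None else 3 for dim in spec.shape]
    dummies[key] = tf.ones(shape=shape, dtype=spec.dtype)
    if key == "token_type_ids":
        # Zeros stay valid even when type_vocab_size == 1
        dummies[key] = tf.zeros_like(dummies[key])

print({key: tuple(value.shape) for key, value in dummies.items()})
# {'input_ids': (3, 3), 'attention_mask': (3, 3), 'token_type_ids': (3, 3)}
```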

@property
def framework(self) -> str:
@@ -1137,6 +1152,10 @@ def __init__(self, config, *inputs, **kwargs):
self.config = config
self.name_or_path = config.name_or_path
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
if not hasattr(self, "serving"): # Don't overwrite existing serving signatures
self.serving = tf.function(
self.eager_serving, input_signature=[self._prune_signature(self.input_signature)]
)
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
self._set_save_spec(self.serving.input_signature[0])
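
The registration pattern in `__init__` generalizes to any Keras model; a hedged sketch (with a hypothetical `TinyModel`, not a transformers class) of how wrapping an eager method in `tf.function` with a fixed input signature keeps the export spec independent of the dummy shapes:

```python
import tensorflow as tf

class TinyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(2)
        if not hasattr(self, "serving"):  # don't clobber a subclass override
            self.serving = tf.function(
                self.eager_serving,
                input_signature=[{"x": tf.TensorSpec([None, 4], tf.float32, name="x")}],
            )

    def eager_serving(self, inputs):
        return {"logits": self.dense(inputs["x"])}

model = TinyModel()
print(model.serving({"x": tf.ones((1, 4))})["logits"].shape)  # (1, 2)
```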

@@ -1201,36 +1220,82 @@ def eager_serving(self, inputs):

return self.serving_output(output)

@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
}
]
)
def serving(self, inputs):
@property
def input_signature(self) -> Dict[str, tf.TensorSpec]:
"""
Method used for serving the model.
Args:
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
This property should return a dict mapping input names to tf.TensorSpec objects, representing the expected
shape and dtype for model inputs. It is used for both serving and for generating the dummy inputs used to build
the model.
"""
output = self.call(inputs)
model_inputs = list(inspect.signature(self.call).parameters)
sig = {}
if "input_ids" in model_inputs:
if self.__class__.__name__.endswith("ForMultipleChoice"):
text_dims = 3
else:
text_dims = 2
for input_name in (
"input_ids",
"attention_mask",
"token_type_ids",
"decoder_input_ids",
"decoder_attention_mask",
):
if input_name in model_inputs:
sig[input_name] = tf.TensorSpec([None] * text_dims, tf.int32, name=input_name)
if "pixel_values" in model_inputs:
pixel_values_shape = [None, None, None, None]
if hasattr(self.config, "vision_config"):
vision_config = self.config.vision_config
else:
vision_config = self.config
if hasattr(vision_config, "num_channels"):
pixel_values_shape[1] = vision_config.num_channels
else:
raise NotImplementedError(
"Could not infer number of channels from config, please override input_signature to specify input shapes."
)
if hasattr(vision_config, "image_size"):
pixel_values_shape[2] = pixel_values_shape[3] = vision_config.image_size
elif hasattr(vision_config, "input_size"):
pixel_values_shape[2] = pixel_values_shape[3] = vision_config.input_size
else:
raise NotImplementedError(
"Could not infer input image shape from config, please override input_signature to specify input shapes."
)
sig["pixel_values"] = tf.TensorSpec(pixel_values_shape, tf.float32, name="pixel_values")
if "input_features" in model_inputs:
raise NotImplementedError("Audio models need a manually defined input_signature")
return sig
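
A hedged illustration of the inference rules above, using a hypothetical `call` signature and a fake vision config rather than real model classes:

```python
import inspect
import tensorflow as tf

def text_signature(call_fn, class_name):
    # Mirrors the text branch above: 3 dims for multiple choice, else 2
    model_inputs = list(inspect.signature(call_fn).parameters)
    text_dims = 3 if class_name.endswith("ForMultipleChoice") else 2
    return {
        name: tf.TensorSpec([None] * text_dims, tf.int32, name=name)
        for name in ("input_ids", "attention_mask", "token_type_ids")
        if name in model_inputs
    }

def call(self, input_ids=None, attention_mask=None, token_type_ids=None):
    ...

print(text_signature(call, "TFAlbertForMultipleChoice"))
# each spec: shape=(None, None, None), dtype=tf.int32

# The vision branch fills an NCHW shape from the config:
class FakeVisionConfig:  # hypothetical, mimics vision_config attributes
    num_channels = 3
    image_size = 224

shape = [None, FakeVisionConfig.num_channels,
         FakeVisionConfig.image_size, FakeVisionConfig.image_size]
print(tf.TensorSpec(shape, tf.float32, name="pixel_values"))
# TensorSpec(shape=(None, 3, 224, 224), dtype=tf.float32, name='pixel_values')
```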

return self.serving_output(output)
def _prune_signature(self, signature):
"""Keeps only the keys of a given input signature that are valid for this model."""
model_inputs = list(inspect.signature(self.call).parameters)
return {key: val for key, val in signature.items() if key in model_inputs}
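
`_prune_signature` simply intersects a signature with the model's `call` arguments; for example (hypothetical `call` below):

```python
import inspect
import tensorflow as tf

def call(self, input_ids=None, attention_mask=None):  # no token_type_ids
    ...

signature = {
    "input_ids": tf.TensorSpec([None, None], tf.int32),
    "token_type_ids": tf.TensorSpec([None, None], tf.int32),
}
model_inputs = list(inspect.signature(call).parameters)
pruned = {key: val for key, val in signature.items() if key in model_inputs}
print(list(pruned))  # ['input_ids'] — token_type_ids was dropped
```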

def serving_output(self, output):
"""
Prepare the output of the saved model. Each model must implement this function.
Args:
output ([`TFBaseModelOutput`]):
The output returned by the model.
"""
raise NotImplementedError
Prepare the output of the saved model. Can be overridden if specific serving modifications are required.
"""
if not isinstance(output, ModelOutput):
return output
for key in output:
if key.endswith("hidden_states") and not getattr(self.config, "output_hidden_states", False):
output[key] = None
elif key.endswith("attentions") and not getattr(self.config, "output_attentions", False):
output[key] = None
elif key == "past_key_values" and not getattr(self.config, "use_cache", False):
output[key] = None
elif key == "cross_attentions" and not (
getattr(self.config, "output_attentions", False) and getattr(self.config, "add_cross_attention", False)
):
output[key] = None
if isinstance(output[key], (tuple, list)):
try:
output[key] = tf.convert_to_tensor(output[key])
except (ValueError, tf.errors.InvalidArgumentError):
pass # Layers may not have the same dimensions
return output
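
To see what the new default does, here is a toy run of the same pruning logic over a plain dict (an assumption standing in for a real `ModelOutput` and config):

```python
import tensorflow as tf

class Config:  # minimal stand-in for a model config
    output_hidden_states = False
    output_attentions = False

config = Config()
output = {
    "logits": tf.ones((1, 2)),
    "hidden_states": (tf.ones((1, 3)), tf.ones((1, 3))),  # per-layer tuple
    "attentions": (tf.ones((1, 1, 3, 3)),),
}

for key in output:
    if key.endswith("hidden_states") and not getattr(config, "output_hidden_states", False):
        output[key] = None
    elif key.endswith("attentions") and not getattr(config, "output_attentions", False):
        output[key] = None
    if isinstance(output[key], (tuple, list)):
        try:
            output[key] = tf.convert_to_tensor(output[key])  # stack per-layer tuples
        except (ValueError, tf.errors.InvalidArgumentError):
            pass  # layers may not share dimensions

print(output["logits"].shape, output["hidden_states"], output["attentions"])
# (1, 2) None None
```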

def can_generate(self) -> bool:
"""
@@ -1384,7 +1449,7 @@ def prepare_tf_dataset(

if not isinstance(dataset, datasets.Dataset):
raise TypeError("Dataset argument should be a datasets.Dataset!")
model_inputs = list(dict(inspect.signature(self.call).parameters).keys())
model_inputs = list(inspect.signature(self.call).parameters)
model_labels = find_labels(self.__class__)
if "cols_to_retain" in list(inspect.signature(dataset._get_output_signature).parameters.keys()):
output_signature, _ = dataset._get_output_signature(
@@ -1496,7 +1561,7 @@ def compute_loss(self, *args, **kwargs):
return self.hf_compute_loss(*args, **kwargs)

def get_label_to_output_name_mapping(self):
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
arg_names = list(inspect.signature(self.call).parameters)
if self._label_to_output_map is not None:
return self._label_to_output_map
elif "start_positions" in arg_names:
@@ -1519,7 +1584,7 @@ def train_step(self, data):
"""

# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
arg_names = list(inspect.signature(self.call).parameters)
label_kwargs = find_labels(self.__class__)
label_to_output = self.get_label_to_output_name_mapping()
output_to_label = {val: key for key, val in label_to_output.items()}
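
The repeated cleanup in these hunks (`list(dict(sig.parameters).keys())` → `list(sig.parameters)`) is a pure simplification: `Signature.parameters` is already an ordered mapping of argument names, so listing it directly is equivalent.

```python
import inspect

def call(self, input_ids=None, attention_mask=None, labels=None):
    ...

sig = inspect.signature(call)
assert list(sig.parameters) == list(dict(sig.parameters).keys())
print(list(sig.parameters))  # ['self', 'input_ids', 'attention_mask', 'labels']
```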
@@ -1626,7 +1691,7 @@ def test_step(self, data):
that they are available to the model during the forward pass.
"""
# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
arg_names = list(inspect.signature(self.call).parameters)
label_kwargs = find_labels(self.__class__)
label_to_output = self.get_label_to_output_name_mapping()
output_to_label = {val: key for key, val in label_to_output.items()}
@@ -1645,7 +1710,7 @@ def test_step(self, data):
# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
# if those keys are not already present in the input dict
if self._using_dummy_loss and y is not None:
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
arg_names = list(inspect.signature(self.call).parameters)
# If y is a tensor and the model only has one label-like input, map y to that input
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
if isinstance(x, tf.Tensor):
85 changes: 0 additions & 85 deletions src/transformers/models/albert/modeling_tf_albert.py
@@ -49,7 +49,6 @@
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
@@ -826,17 +825,6 @@ def call(

return outputs

def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFBaseModelOutputWithPooling(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
hidden_states=hs,
attentions=attns,
)


@add_start_docstrings(
"""
@@ -933,17 +921,6 @@ def call(
attentions=outputs.attentions,
)

def serving_output(self, output: TFAlbertForPreTrainingOutput) -> TFAlbertForPreTrainingOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFAlbertForPreTrainingOutput(
prediction_logits=output.prediction_logits,
sop_logits=output.sop_logits,
hidden_states=hs,
attentions=attns,
)


class TFAlbertSOPHead(tf.keras.layers.Layer):
def __init__(self, config: AlbertConfig, **kwargs):
@@ -1058,13 +1035,6 @@ def call(
attentions=outputs.attentions,
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)


@add_start_docstrings(
"""
@@ -1147,13 +1117,6 @@ def call(
attentions=outputs.attentions,
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)


@add_start_docstrings(
"""
@@ -1237,13 +1200,6 @@ def call(
attentions=outputs.attentions,
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFTokenClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)


@add_start_docstrings(
"""
@@ -1339,15 +1295,6 @@ def call(
attentions=outputs.attentions,
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFQuestionAnsweringModelOutput(
start_logits=output.start_logits, end_logits=output.end_logits, hidden_states=hs, attentions=attns
)


@add_start_docstrings(
"""
@@ -1370,16 +1317,6 @@ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)

@property
def dummy_inputs(self):
"""
Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32)}

@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
@@ -1457,25 +1394,3 @@ def call(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
}
]
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving
def serving(self, inputs: Dict[str, tf.Tensor]) -> TFMultipleChoiceModelOutput:
output = self.call(input_ids=inputs)

return self.serving_output(output)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFMultipleChoiceModelOutput(logits=output.logits, hidden_states=hs, attentions=attns)
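
With these per-model overrides removed, export should rely entirely on the inherited machinery. A hedged sketch of the expected workflow after this change (checkpoint name is illustrative; exact export entry points may differ):

```python
import tensorflow as tf
from transformers import TFAlbertForMultipleChoice

model = TFAlbertForMultipleChoice.from_pretrained("albert-base-v2")
# Inferred automatically now: (None, None, None) int32 specs for text inputs
print(model.input_signature)
# `serving` was attached in __init__, so no per-model override is needed
tf.saved_model.save(model, "albert_mc_saved_model", signatures=model.serving)
```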