Problem

A common assumption in our codebase and in discussions is that batching has little impact on the results. However, we don't test it.

If we add the test below to `ModelTesterMixin`, we'll see that it fails on many models (`py.test tests/models -k test_batching_support`). We can also see that it passes on some models, like gpt2 or llama.
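For intuition, here is a minimal standalone sketch of the property the proposed test automates; the checkpoint, prompts, and tolerance are illustrative, and it only assumes the public `transformers`/`torch` APIs. The first row of a batched forward pass should match a single-row forward pass on the same input, up to small numerical noise from different matmul shapes:

```python
# Minimal sketch (illustrative checkpoint and prompts): run the same prompt alone and
# inside a batch, then compare the first row of the batched logits to the single-row logits.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

prompt = "Batching should not change the results."
single = tokenizer([prompt], return_tensors="pt")
batched = tokenizer([prompt, prompt], return_tensors="pt")  # same length, so no padding is needed

with torch.no_grad():
    single_logits = model(**single).logits    # shape (1, seq_len, vocab_size)
    batched_logits = model(**batched).logits  # shape (2, seq_len, vocab_size)

# Maximum absolute difference between the two runs; expected to be tiny but usually
# not exactly zero, since batched matmuls use different shapes.
print(torch.max(torch.abs(batched_logits[0:1] - single_logits)))
assert torch.allclose(batched_logits[0:1], single_logits, atol=1e-5)
```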
Test Proposal
```python
def test_batching_support(self):
    """
    Tests that the model supports batching and that the output is nearly the same for the same input in
    different batch sizes.

    (Why "nearly the same" and not "exactly the same"? Batching uses different matmul shapes, which often
    leads to slightly different results: https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535)
    """

    def recursive_check(batched_object, single_row_object):
        if isinstance(batched_object, (List, Tuple)):
            for batched_object_value, single_row_object_value in zip(batched_object, single_row_object):
                recursive_check(batched_object_value, single_row_object_value)
        elif isinstance(batched_object, Dict):
            for batched_object_value, single_row_object_value in zip(
                batched_object.values(), single_row_object.values()
            ):
                recursive_check(batched_object_value, single_row_object_value)
        elif batched_object is None:
            return
        else:
            batched_row = batched_object[0:1]
            self.assertFalse(torch.isnan(batched_row).any(), "Batched output has `nan`!")
            self.assertFalse(torch.isinf(batched_row).any(), "Batched output has `inf`!")
            self.assertFalse(torch.isnan(single_row_object).any(), "Single row output has `nan`!")
            self.assertFalse(torch.isinf(single_row_object).any(), "Single row output has `inf`!")
            self.assertTrue(
                torch.allclose(batched_row, single_row_object, atol=1e-5),
                msg=(
                    "Batched and Single row outputs are not equal. Difference="
                    f"{torch.max(torch.abs(batched_row - single_row_object))}."
                ),
            )

    config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        # Test as many model outputs as possible
        config.output_hidden_states = True
        config.output_attentions = True
        model = model_class(config).to(torch_device).eval()
        model_batched_output = model(**batched_input)

        # Build a single-row version of the batched input: slice the first row of every tensor
        # whose leading dimension is the batch size, and keep everything else as-is.
        batch_size = batched_input[model.main_input_name].shape[0]
        single_row_input = {}
        for key, value in batched_input.items():
            if isinstance(value, torch.Tensor) and value.shape[0] == batch_size:
                single_row_input[key] = value[0:1]
            else:
                single_row_input[key] = value
        model_row_output = model(**single_row_input)

        for key in model_batched_output:
            try:
                recursive_check(model_batched_output[key], model_row_output[key])
            # Augment the exception with the failing key to help debugging, if something goes wrong
            except AssertionError as e:
                raise AssertionError(f"{e}\nError in key={key}")
```
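A note on the design: `recursive_check` compares `batched_object[0:1]`, the first row of the batched output, against a separate single-row forward pass, so a mismatch can't hide behind row-to-row variation inside the batch. The `nan`/`inf` assertions run before the comparison so that invalid outputs fail with a clear message instead of a generic "not equal", and `atol=1e-5` is intentionally loose to absorb the matmul-shape differences mentioned in the docstring. To iterate on a single architecture, the same test can be selected from one model folder, e.g. `py.test tests/models/llama -k test_batching_support` (the exact path is illustrative).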
Plan forward
Iterate on the test above, fixing issues with both the test and the models. This issue can be tagged as complete when we have a test for "batching has little impact on the results" running on all models. Individual model fixes can be merged in advance of the test.