Test batched equivalence #29179
Closed
gante opened this issue Feb 21, 2024 · 1 comment · Fixed by #29297

gante commented Feb 21, 2024

Problem

A common assumption in our codebase and discussions is that batching has little impact on the results. However, we don't test it.

If we add the test below to ModelTesterMixin, we'll see that it fails on many models (pytest tests/models -k test_batching_support). We can also see that it passes on some models, like gpt2 or llama.

Test Proposal
  def test_batching_support(self):
      """
      Tests that the model supports batching and that the output is nearly the same for the same input in
      different batch sizes.
      (Why "nearly the same" and not "exactly the same"? Batching uses different matmul shapes, which often
      leads to different results: https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535)
      """

      def recursive_check(batched_object, single_row_object):
          if isinstance(batched_object, (list, tuple)):
              for batched_object_value, single_row_object_value in zip(batched_object, single_row_object):
                  recursive_check(batched_object_value, single_row_object_value)
          elif isinstance(batched_object, dict):
              for batched_object_value, single_row_object_value in zip(
                  batched_object.values(), single_row_object.values()
              ):
                  recursive_check(batched_object_value, single_row_object_value)
          elif batched_object is None:
              return
          else:
              batched_row = batched_object[0:1]
              self.assertFalse(torch.isnan(batched_row).any(), "Batched output has `nan`!")
              self.assertFalse(torch.isinf(batched_row).any(), "Batched output has `inf`!")
              self.assertFalse(torch.isnan(single_row_object).any(), "Single row output has `nan`!")
              self.assertFalse(torch.isinf(single_row_object).any(), "Single row output has `inf`!")
              self.assertTrue(
                  torch.allclose(batched_row, single_row_object, atol=1e-5),
                  msg=(
                      "Batched and Single row outputs are not equal. Difference="
                      f"{torch.max(torch.abs(batched_row - single_row_object))}."
                  ),
              )

      config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()

      for model_class in self.all_model_classes:
          # Test as many model outputs as possible
          config.output_hidden_states = True
          config.output_attentions = True

          model = model_class(config).to(torch_device).eval()
          model_batched_output = model(**batched_input)

          batch_size = batched_input[model.main_input_name].shape[0]
          single_row_input = {}
          for key, value in batched_input.items():
              if isinstance(value, torch.Tensor) and value.shape[0] == batch_size:
                  single_row_input[key] = value[0:1]
              else:
                  single_row_input[key] = value
          model_row_output = model(**single_row_input)

          for key in model_batched_output:
              try:
                  recursive_check(model_batched_output[key], model_row_output[key])
              # Augment the exception with the failing key to make debugging easier
              except AssertionError as e:
                  raise AssertionError(
                      f"{e}\nError in key={key}"
                  )

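As a standalone illustration of the "nearly the same" tolerance in the docstring above, here is a minimal sketch (not part of the proposed test; the layer and tensor sizes are arbitrary) showing how the same row can produce slightly different values depending on whether it is processed alone or inside a batch, which is exactly the drift the atol=1e-5 check is meant to absorb:

  import torch

  torch.manual_seed(0)
  layer = torch.nn.Linear(256, 256)  # a single linear layer is enough to exercise batched matmul

  x = torch.randn(8, 256)  # a batch of 8 rows
  with torch.no_grad():
      batched_out = layer(x)      # forward pass over the full (8, 256) batch
      single_out = layer(x[0:1])  # same first row, processed on its own

  # Depending on the backend and kernel selection, the two results may not be
  # bit-identical, but they should stay within a small absolute tolerance.
  max_diff = (batched_out[0:1] - single_out).abs().max().item()
  print(f"max abs difference: {max_diff:.2e}")
  assert torch.allclose(batched_out[0:1], single_out, atol=1e-5)
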
Plan forward

Iterate on the test above, fixing issues with both the test and the models. This issue can be closed when we have a test for "batching has little impact on the results" running on all models. Individual model fixes can be merged ahead of the test.
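For context on "running on all models": tests defined on ModelTesterMixin (in tests/test_modeling_common.py) are picked up by every model's test class through inheritance. A rough sketch of that wiring, with illustrative class names and attribute values rather than code copied from the repo:

  import unittest

  class ModelTesterMixin:
      # Each concrete test class overrides this with the model classes it covers.
      all_model_classes = ()

      def test_batching_support(self):
          ...  # the proposed test body above goes here

  # Illustrative subclass; the real per-model test classes also mix in other helpers.
  class Gpt2ModelTest(ModelTesterMixin, unittest.TestCase):
      all_model_classes = ()  # e.g. (GPT2Model, GPT2LMHeadModel) in the real suite

With that layout, pytest tests/models -k test_batching_support collects the method once per model test class.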

gante self-assigned this Feb 21, 2024

gante commented Feb 21, 2024

cc @zucchini-nlp
