diff --git a/keras/engine/data_adapter.py b/keras/engine/data_adapter.py index 9201bfe3be0..517684e7559 100644 --- a/keras/engine/data_adapter.py +++ b/keras/engine/data_adapter.py @@ -1271,6 +1271,13 @@ def __init__( self._insufficient_data = False self._model = model + if steps_per_epoch == 0: + raise ValueError( + "Unexpected value for `steps_per_epoch`. Received value is 0. " + "Please check the docstring for `model.fit()` for supported " + "values." + ) + self._steps_per_epoch = steps_per_epoch # `steps_per_execution_value` is the cached initial value. @@ -1308,6 +1315,9 @@ def __init__( strategy, x, steps_per_epoch, class_weight, distribute ) + if self._inferred_steps == 0: + raise ValueError("Expected input data to be non-empty.") + def _configure_dataset_and_inferred_steps( self, strategy, x, steps_per_epoch, class_weight, distribute ): diff --git a/keras/engine/data_adapter_test.py b/keras/engine/data_adapter_test.py index 5878e887f9b..2a480b385b9 100644 --- a/keras/engine/data_adapter_test.py +++ b/keras/engine/data_adapter_test.py @@ -1442,6 +1442,37 @@ def test_single_x_input_no_tuple_wrapping(self, use_numpy): # Check that single x input is not wrapped in a tuple. self.assertIsInstance(next(iterator), tf.Tensor) + def test_error_if_zero_steps_per_epoch(self): + data = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3]).batch(1) + + with self.assertRaisesRegex( + ValueError, + "Unexpected value for `steps_per_epoch`. Received value is 0.", + ): + data_adapter.DataHandler( + data, initial_epoch=0, epochs=2, steps_per_epoch=0 + ) + + def test_error_if_empty_array_input_data(self): + x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) + y = np.array([0, 1, 1, 0]) + idx = [] + + with self.assertRaisesWithLiteralMatch( + ValueError, + "Expected input data to be non-empty.", + ): + data_adapter.DataHandler(x[idx], y[idx]) + + def test_error_if_empty_dataset_input_data(self): + data = tf.data.Dataset.from_tensor_slices([]).batch(1) + + with self.assertRaisesWithLiteralMatch( + ValueError, + "Expected input data to be non-empty.", + ): + data_adapter.DataHandler(data) + class TestValidationSplit(test_combinations.TestCase): @parameterized.named_parameters(("numpy_arrays", True), ("tensors", False)) diff --git a/keras/engine/training_test.py b/keras/engine/training_test.py index 0c6dc9d66ad..ea040ac65b0 100644 --- a/keras/engine/training_test.py +++ b/keras/engine/training_test.py @@ -94,7 +94,7 @@ def test_fit_on_empty(self): model = sequential.Sequential([layers_module.Dense(1)]) model.compile("sgd", "mse", run_eagerly=test_utils.should_run_eagerly()) with self.assertRaisesRegex( - ValueError, "Unexpected result of `train_function`.*" + ValueError, "Expected input data to be non-empty." ): model.fit(x=np.array([]), y=np.array([])) @@ -2534,7 +2534,7 @@ def test_predict_error_with_empty_x(self): model.compile(loss="mse") with self.assertRaisesRegex( - ValueError, "Unexpected result of `predict_function`.*" + ValueError, "Expected input data to be non-empty." ): model.predict(np.array([]))