Skip to content

Commit

Permalink
Merge pull request #18042 from tomrtk:fit-error-msg
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 547546303
  • Loading branch information
tensorflower-gardener committed Jul 12, 2023
2 parents 37c1909 + 8c0bd19 commit 8582c5f
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
10 changes: 10 additions & 0 deletions keras/engine/data_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,13 @@ def __init__(
self._insufficient_data = False
self._model = model

if steps_per_epoch == 0:
raise ValueError(
"Unexpected value for `steps_per_epoch`. Received value is 0. "
"Please check the docstring for `model.fit()` for supported "
"values."
)

self._steps_per_epoch = steps_per_epoch

# `steps_per_execution_value` is the cached initial value.
Expand Down Expand Up @@ -1308,6 +1315,9 @@ def __init__(
strategy, x, steps_per_epoch, class_weight, distribute
)

if self._inferred_steps == 0:
raise ValueError("Expected input data to be non-empty.")

def _configure_dataset_and_inferred_steps(
self, strategy, x, steps_per_epoch, class_weight, distribute
):
Expand Down
31 changes: 31 additions & 0 deletions keras/engine/data_adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1442,6 +1442,37 @@ def test_single_x_input_no_tuple_wrapping(self, use_numpy):
# Check that single x input is not wrapped in a tuple.
self.assertIsInstance(next(iterator), tf.Tensor)

def test_error_if_zero_steps_per_epoch(self):
data = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3]).batch(1)

with self.assertRaisesRegex(
ValueError,
"Unexpected value for `steps_per_epoch`. Received value is 0.",
):
data_adapter.DataHandler(
data, initial_epoch=0, epochs=2, steps_per_epoch=0
)

def test_error_if_empty_array_input_data(self):
x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([0, 1, 1, 0])
idx = []

with self.assertRaisesWithLiteralMatch(
ValueError,
"Expected input data to be non-empty.",
):
data_adapter.DataHandler(x[idx], y[idx])

def test_error_if_empty_dataset_input_data(self):
data = tf.data.Dataset.from_tensor_slices([]).batch(1)

with self.assertRaisesWithLiteralMatch(
ValueError,
"Expected input data to be non-empty.",
):
data_adapter.DataHandler(data)


class TestValidationSplit(test_combinations.TestCase):
@parameterized.named_parameters(("numpy_arrays", True), ("tensors", False))
Expand Down
4 changes: 2 additions & 2 deletions keras/engine/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_fit_on_empty(self):
model = sequential.Sequential([layers_module.Dense(1)])
model.compile("sgd", "mse", run_eagerly=test_utils.should_run_eagerly())
with self.assertRaisesRegex(
ValueError, "Unexpected result of `train_function`.*"
ValueError, "Expected input data to be non-empty."
):
model.fit(x=np.array([]), y=np.array([]))

Expand Down Expand Up @@ -2534,7 +2534,7 @@ def test_predict_error_with_empty_x(self):
model.compile(loss="mse")

with self.assertRaisesRegex(
ValueError, "Unexpected result of `predict_function`.*"
ValueError, "Expected input data to be non-empty."
):
model.predict(np.array([]))

Expand Down

0 comments on commit 8582c5f

Please sign in to comment.