diff --git a/keras_core/operations/numpy.py b/keras_core/operations/numpy.py index cf6881b0e..2d412f4b7 100644 --- a/keras_core/operations/numpy.py +++ b/keras_core/operations/numpy.py @@ -2665,7 +2665,10 @@ def compute_output_spec(self, x): size_on_ax = x_shape[self.axis] output_shape = x_shape if isinstance(self.repeats, int): - output_shape[self.axis] = size_on_ax * self.repeats + if size_on_ax is None: + output_shape[self.axis] = None + else: + output_shape[self.axis] = size_on_ax * self.repeats else: output_shape[self.axis] = int(np.sum(self.repeats)) return KerasTensor(output_shape, dtype=x.dtype) diff --git a/keras_core/operations/numpy_test.py b/keras_core/operations/numpy_test.py index 580e7674d..909769cab 100644 --- a/keras_core/operations/numpy_test.py +++ b/keras_core/operations/numpy_test.py @@ -1027,6 +1027,7 @@ def test_repeat(self): self.assertEqual(knp.repeat(x, 2).shape, (None,)) self.assertEqual(knp.repeat(x, 3, axis=1).shape, (None, 9)) self.assertEqual(knp.repeat(x, [1, 2], axis=0).shape, (3, 3)) + self.assertEqual(knp.repeat(x, 2, axis=0).shape, (None, 3)) def test_reshape(self): x = KerasTensor([None, 3])