Skip to content

Commit

Permalink
Fix keras_core.operations.Repeat op's compute_output_spec method (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tirthasheshpatel committed Jun 21, 2023
1 parent 62b2ae6 commit 4c435b2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
5 changes: 4 additions & 1 deletion keras_core/operations/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2665,7 +2665,10 @@ def compute_output_spec(self, x):
size_on_ax = x_shape[self.axis]
output_shape = x_shape
if isinstance(self.repeats, int):
output_shape[self.axis] = size_on_ax * self.repeats
if size_on_ax is None:
output_shape[self.axis] = None
else:
output_shape[self.axis] = size_on_ax * self.repeats
else:
output_shape[self.axis] = int(np.sum(self.repeats))
return KerasTensor(output_shape, dtype=x.dtype)
Expand Down
1 change: 1 addition & 0 deletions keras_core/operations/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,7 @@ def test_repeat(self):
self.assertEqual(knp.repeat(x, 2).shape, (None,))
self.assertEqual(knp.repeat(x, 3, axis=1).shape, (None, 9))
self.assertEqual(knp.repeat(x, [1, 2], axis=0).shape, (3, 3))
self.assertEqual(knp.repeat(x, 2, axis=0).shape, (None, 3))

def test_reshape(self):
x = KerasTensor([None, 3])
Expand Down

0 comments on commit 4c435b2

Please sign in to comment.