Skip to content

Commit

Permalink
Reimplement eye op on TF to support tensors as inputs. (#20280)
Browse files Browse the repository at this point in the history
`N`, `M` and `k` can now be tensors and not just ints, which is useful when these parameters come from dynamic dimensions in a shape.
  • Loading branch information
hertschuh authored Sep 24, 2024
1 parent 0acd575 commit f3fa48a
Showing 1 changed file with 10 additions and 22 deletions.
32 changes: 10 additions & 22 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2443,29 +2443,17 @@ def sum(x, axis=None, keepdims=False):

def eye(N, M=None, k=0, dtype=None):
dtype = dtype or config.floatx()
if not M:
M = N
# Making sure N, M and k are `int`
N, M, k = int(N), int(M), int(k)
if k >= M or -k >= N:
# tf.linalg.diag will raise an error in this case
return zeros([N, M], dtype=dtype)
if k == 0:
M = N if M is None else M
if isinstance(k, int) and k == 0:
return tf.eye(N, M, dtype=dtype)
# We need the precise length, otherwise tf.linalg.diag will raise an error
diag_len = builtins.min(N, M)
if k > 0:
if N >= M:
diag_len -= k
elif N + k > M:
diag_len = M - k
elif k <= 0:
if M >= N:
diag_len += k
elif M - k > N:
diag_len = N + k
diagonal_ = tf.ones([diag_len], dtype=dtype)
return tf.linalg.diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k)
# Create a smaller square eye and pad appropriately.
return tf.pad(
tf.eye(tf.minimum(M - k, N + k), dtype=dtype),
paddings=(
(tf.maximum(-k, 0), tf.maximum(N - M + k, 0)),
(tf.maximum(k, 0), tf.maximum(M - N - k, 0)),
),
)


def floor_divide(x1, x2):
Expand Down

0 comments on commit f3fa48a

Please sign in to comment.