Skip to content

Commit

Permalink
fix: ensure correct offsets dtype and rank for ivy gather call in nea…
Browse files Browse the repository at this point in the history
…rest interpolate function

torch requires int64 for indices and the indices need to be same rank as x.
  • Loading branch information
Ishticode committed Feb 8, 2024
1 parent 4173945 commit 4dee494
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion ivy/functional/ivy/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,7 +1473,11 @@ def nearest_interpolate(x, dims, size, scale, exact):
for d in range(dims):
n = size[d]
offsets = (ivy.arange(n, dtype="float32") + off) * scale[d]
offsets = ivy.astype(ivy.floor(ivy.astype(offsets, "float32")), "int32")
offsets = ivy.astype(ivy.floor(ivy.astype(offsets, "float32")), "int64")
num_dims_to_add = x.ndim - offsets.ndim
if num_dims_to_add > 0:
for _ in range(num_dims_to_add):
offsets = ivy.expand_dims(offsets, axis=0)
x = ivy.gather(x, offsets, axis=d + 2)
return x

Expand Down

0 comments on commit 4dee494

Please sign in to comment.