Skip to content

Commit

Permalink
fix the bug ivy-llc#23044
Browse files Browse the repository at this point in the history
  • Loading branch information
Aryan8912 committed Sep 20, 2023
1 parent d154394 commit af38d41
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions ivy/functional/backends/jax/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def to_list(x: JaxArray, /) -> list:
return _to_array(x).tolist()


# ivy/utils/assertions.py

# ivy/utils/assertions.py

def get_positive_axis_for_gather(axis, ndims):
Expand All @@ -140,6 +142,39 @@ def get_positive_axis_for_gather(axis, ndims):

# ivy/functional/backends/jax/general.py

def gather(
params: JaxArray,
indices: JaxArray,
/,
*,
axis: int = -1,
batch_dims: int = 0,
out: Optional[JaxArray] = None,
) -> JaxArray:
axis = get_positive_axis_for_gather(axis, params.ndim)
batch_dims = batch_dims % len(params.shape)
ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims)
result = []
if batch_dims == 0:
result = jnp.take(params, indices, axis)
else:
for b in range(batch_dims):
if b == 0:
zip_list = [(p, i) for p, i in zip(params, indices)]
else:
zip_list = [
(p, i) for z in [zip(p1, i1) for p1, i1 in zip_list] for p, i in z
]
for z in zip_list:
p, i = z
r = jnp.take(p, i, axis - batch_dims)
result.append(r)
result = jnp.array(result)
result = result.reshape([*params.shape[0:batch_dims], *result.shape[1:]])
return result

# ivy/functional/backends/jax/general.py

def gather(
params: JaxArray,
indices: JaxArray,
Expand Down

0 comments on commit af38d41

Please sign in to comment.