Skip to content

Commit

Permalink
add numpy.asscalar() (#22947)
Browse files Browse the repository at this point in the history
  • Loading branch information
duspic authored Sep 4, 2023
1 parent e1f93d2 commit 2ec34e2
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@

def asmatrix(data, dtype=None):
return np_frontend.matrix(ivy.array(data), dtype=dtype, copy=False)


def asscalar(a):
return a.item()
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,13 @@ def test_numpy_asmatrix(arr, backend_fw):
ret_gt = np.asmatrix(x[0])
assert ret.shape == ret_gt.shape
assert ivy_backend.all(ivy_backend.flatten(ret._data) == np.ravel(ret_gt))


@handle_frontend_test(
fn_tree="numpy.asscalar",
arr=helpers.array_values(dtype=helpers.get_dtypes("numeric"), shape=1),
)
def test_numpy_asscalar(arr: np.ndarray):
ret_1 = arr.item()
ret_2 = np_frontend.asscalar(arr)
assert ret_1 == ret_2

0 comments on commit 2ec34e2

Please sign in to comment.