Skip to content

Commit

Permalink
Assert correct class of arrays returned by ivy and frontend functions…
Browse files Browse the repository at this point in the history
…/methods (#19183)
  • Loading branch information
AnnaTz committed Aug 14, 2023
1 parent 35f36db commit 16984fb
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion ivy_tests/test_ivy/helpers/function_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,10 @@ def test_function(
**kwargs,
)

assert ivy_backend.nested_map(
ret_from_target, lambda x: ivy_backend.is_ivy_array(x) if ivy_backend.is_array(x) else True
), "Ivy function returned non-ivy arrays: {}".format(ret_from_target)

# Assert indices of return if the indices of the out array provided
if test_flags.with_out and not test_flags.test_compile:
test_ret = (
Expand Down Expand Up @@ -382,6 +386,9 @@ def test_function(
test_compile=test_flags.test_compile,
**kwargs,
)
assert gt_backend.nested_map(
ret_from_gt, lambda x: gt_backend.is_ivy_array(x) if gt_backend.is_array(x) else True
), "Ground-truth function returned non-ivy arrays: {}".format(ret_from_gt)
if test_flags.with_out and not test_flags.test_compile:
test_ret_from_gt = (
ret_from_gt[getattr(gt_backend.__dict__[fn_name], "out_index")]
Expand Down Expand Up @@ -611,6 +618,10 @@ def test_frontend_function(
**kwargs_for_test,
)

assert ivy_backend.nested_map(
ret, lambda x: _is_frontend_array(x) if ivy_backend.is_array(x) else True
), "Frontend function returned non-frontend arrays: {}".format(ret)

if test_flags.with_out:
if not inspect.isclass(ret):
is_ret_tuple = issubclass(ret.__class__, tuple)
Expand Down Expand Up @@ -1219,6 +1230,10 @@ def test_method(
else:
ret_device = None

assert ivy_backend.nested_map(
ret, lambda x: ivy_backend.is_ivy_array(x) if ivy_backend.is_array(x) else True
), "Ivy method returned non-ivy arrays: {}".format(ret)

# Compute the return with a Ground Truth backend

with update_backend(ground_truth_backend) as gt_backend:
Expand Down Expand Up @@ -1262,7 +1277,10 @@ def test_method(
test_compile=test_compile,
**kwargs_gt_method,
)

assert gt_backend.nested_map(
ret_from_gt, lambda x: gt_backend.is_ivy_array(x) if gt_backend.is_array(x) else True
), "Ground-truth method returned non-ivy arrays: {}".format(ret_from_gt)

# TODO optimize or cache
# Exhuastive replication for all examples
fw_list = gradient_unsupported_dtypes(fn=ins.__getattribute__(method_name))
Expand Down Expand Up @@ -1580,6 +1598,10 @@ def test_frontend_method(
**kwargs_method,
)

assert ivy_backend.nested_map(
ret, lambda x: _is_frontend_array(x) if ivy_backend.is_array(x) else True
), "Frontend method returned non-frontend arrays: {}".format(ret)

# Compute the return with the native frontend framework
frontend_config = get_frontend_config(frontend)
args_constructor_frontend = ivy.nested_map(
Expand Down

0 comments on commit 16984fb

Please sign in to comment.