From 16984fbd91e89b5de246e7f359b214b1c4e184e2 Mon Sep 17 00:00:00 2001 From: AnnaTz <111577222+AnnaTz@users.noreply.github.com> Date: Mon, 14 Aug 2023 16:32:47 +0100 Subject: [PATCH] Assert correct class of arrays returned by ivy and frontend functions/methods (#19183) --- .../test_ivy/helpers/function_testing.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/helpers/function_testing.py b/ivy_tests/test_ivy/helpers/function_testing.py index ea5e684853b35..f2847a8ba055c 100644 --- a/ivy_tests/test_ivy/helpers/function_testing.py +++ b/ivy_tests/test_ivy/helpers/function_testing.py @@ -280,6 +280,10 @@ def test_function( **kwargs, ) + assert ivy_backend.nested_map( + ret_from_target, lambda x: ivy_backend.is_ivy_array(x) if ivy_backend.is_array(x) else True + ), "Ivy function returned non-ivy arrays: {}".format(ret_from_target) + # Assert indices of return if the indices of the out array provided if test_flags.with_out and not test_flags.test_compile: test_ret = ( @@ -382,6 +386,9 @@ def test_function( test_compile=test_flags.test_compile, **kwargs, ) + assert gt_backend.nested_map( + ret_from_gt, lambda x: gt_backend.is_ivy_array(x) if gt_backend.is_array(x) else True + ), "Ground-truth function returned non-ivy arrays: {}".format(ret_from_gt) if test_flags.with_out and not test_flags.test_compile: test_ret_from_gt = ( ret_from_gt[getattr(gt_backend.__dict__[fn_name], "out_index")] @@ -611,6 +618,10 @@ def test_frontend_function( **kwargs_for_test, ) + assert ivy_backend.nested_map( + ret, lambda x: _is_frontend_array(x) if ivy_backend.is_array(x) else True + ), "Frontend function returned non-frontend arrays: {}".format(ret) + if test_flags.with_out: if not inspect.isclass(ret): is_ret_tuple = issubclass(ret.__class__, tuple) @@ -1219,6 +1230,10 @@ def test_method( else: ret_device = None + assert ivy_backend.nested_map( + ret, lambda x: ivy_backend.is_ivy_array(x) if ivy_backend.is_array(x) else True + ), "Ivy method returned non-ivy arrays: {}".format(ret) + # Compute the return with a Ground Truth backend with update_backend(ground_truth_backend) as gt_backend: @@ -1262,7 +1277,10 @@ def test_method( test_compile=test_compile, **kwargs_gt_method, ) - + assert gt_backend.nested_map( + ret_from_gt, lambda x: gt_backend.is_ivy_array(x) if gt_backend.is_array(x) else True + ), "Ground-truth method returned non-ivy arrays: {}".format(ret_from_gt) + # TODO optimize or cache # Exhuastive replication for all examples fw_list = gradient_unsupported_dtypes(fn=ins.__getattribute__(method_name)) @@ -1580,6 +1598,10 @@ def test_frontend_method( **kwargs_method, ) + assert ivy_backend.nested_map( + ret, lambda x: _is_frontend_array(x) if ivy_backend.is_array(x) else True + ), "Frontend method returned non-frontend arrays: {}".format(ret) + # Compute the return with the native frontend framework frontend_config = get_frontend_config(frontend) args_constructor_frontend = ivy.nested_map(