From cc02bea5339c0516be5172060cd02cd86b750fce Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Tue, 15 Aug 2023 17:45:09 +0530 Subject: [PATCH 01/24] Update math.py --- ivy/functional/frontends/tensorflow/math.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/ivy/functional/frontends/tensorflow/math.py b/ivy/functional/frontends/tensorflow/math.py index 8fccb8cbfcc78..697690346aee0 100644 --- a/ivy/functional/frontends/tensorflow/math.py +++ b/ivy/functional/frontends/tensorflow/math.py @@ -549,6 +549,19 @@ def tanh(x, name=None): def rsqrt(x, name=None): return ivy.reciprocal(ivy.sqrt(x)) +@to_ivy_arrays_and_back +def segment_sum(data, segment_ids, name= "segment_sum"): + data = ivy.array(data) + segment_ids = ivy.array(segment_ids) + ivy.utils.assertions.check_equal( + list(segment_ids.shape), [list(data.shape)[0]], as_array=False + ) + sum_array = ivy.zeros( + tuple([segment_ids[-1]+1] + (list(data.shape))[1:]), dtype = ivy.int32 + ) + for i in range((segment_ids).shape[0]): + sum_array[segment_ids[i]] = sum_array[segment_ids[i]] + data[i] + return sum_array @to_ivy_arrays_and_back def nextafter(x1, x2, name=None): From 16f1fd6516901357bc081c925f23e216bd6a7d2d Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Tue, 15 Aug 2023 17:49:09 +0530 Subject: [PATCH 02/24] Update test_math.py --- .../test_tensorflow/test_math.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py index 9f7f44cf4cf3f..2e8b493222096 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @@ -1940,6 +1940,35 @@ def test_tensorflow_rsqrt( x=x[0], ) +#segment_sum +@handle_frontend_test( + fn_tree="tensorflow.math.unsorted_segment_sum", + data=helpers.array_values(dtype=ivy.int32, shape=(5, 6), min_value=1, max_value=9), + segment_ids=helpers.array_values( + dtype=ivy.int32, shape=(5,), min_value=0, max_value=4 + ), + test_with_out=st.just(False), +) +def test_tensorflow_unsorted_segment_sum( + *, + data, + segment_ids, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + helpers.test_frontend_function( + input_dtypes=["int32", "int64"], + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + data=data, + segment_ids=segment_ids, + ) # nextafter @handle_frontend_test( From 85c662843aa01baf89acf20285801b278259c867 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Tue, 15 Aug 2023 17:52:59 +0530 Subject: [PATCH 03/24] Update test_math.py --- .../test_ivy/test_frontends/test_tensorflow/test_math.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py index 2e8b493222096..f4478bac40f97 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @@ -1942,14 +1942,14 @@ def test_tensorflow_rsqrt( #segment_sum @handle_frontend_test( - fn_tree="tensorflow.math.unsorted_segment_sum", + fn_tree="tensorflow.math.segment_sum", data=helpers.array_values(dtype=ivy.int32, shape=(5, 6), min_value=1, max_value=9), segment_ids=helpers.array_values( dtype=ivy.int32, shape=(5,), min_value=0, max_value=4 ), test_with_out=st.just(False), ) -def test_tensorflow_unsorted_segment_sum( +def test_tensorflow_segment_sum( *, data, segment_ids, From 7238e077a031b165fadc0d4a2880cc4ee4b3ffb7 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sat, 19 Aug 2023 16:26:47 +0530 Subject: [PATCH 04/24] Update test_math.py --- .../test_tensorflow/test_math.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py index f4478bac40f97..3971216e9b910 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @@ -1943,31 +1943,34 @@ def test_tensorflow_rsqrt( #segment_sum @handle_frontend_test( fn_tree="tensorflow.math.segment_sum", - data=helpers.array_values(dtype=ivy.int32, shape=(5, 6), min_value=1, max_value=9), - segment_ids=helpers.array_values( - dtype=ivy.int32, shape=(5,), min_value=0, max_value=4 + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + min_num_dims=1, + max_num_dims=3, ), test_with_out=st.just(False), ) def test_tensorflow_segment_sum( *, - data, - segment_ids, + dtype_and_x frontend, test_flags, fn_tree, backend_fw, on_device, ): + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=["int32", "int64"], + input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - data=data, - segment_ids=segment_ids, + data=x[0], + segment_ids=x[1], ) # nextafter From 263602dca4d08325e32e03e61cf784befc92afb4 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sat, 19 Aug 2023 16:41:06 +0530 Subject: [PATCH 05/24] Update math.py --- ivy/functional/frontends/tensorflow/math.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ivy/functional/frontends/tensorflow/math.py b/ivy/functional/frontends/tensorflow/math.py index 697690346aee0..ab270a8bdd34c 100644 --- a/ivy/functional/frontends/tensorflow/math.py +++ b/ivy/functional/frontends/tensorflow/math.py @@ -550,17 +550,15 @@ def rsqrt(x, name=None): return ivy.reciprocal(ivy.sqrt(x)) @to_ivy_arrays_and_back -def segment_sum(data, segment_ids, name= "segment_sum"): - data = ivy.array(data) - segment_ids = ivy.array(segment_ids) +def segment_sum(data, segment_ids, name="segment_sum"): ivy.utils.assertions.check_equal( list(segment_ids.shape), [list(data.shape)[0]], as_array=False ) sum_array = ivy.zeros( - tuple([segment_ids[-1]+1] + (list(data.shape))[1:]), dtype = ivy.int32 + tuple([segment_ids[-1] + 1] + (list(data.shape))[1:]), dtype = ivy.int32 ) for i in range((segment_ids).shape[0]): - sum_array[segment_ids[i]] = sum_array[segment_ids[i]] + data[i] + sum_array[segment_ids[i]] = sum_array[segment_ids[i]] + data[i] return sum_array @to_ivy_arrays_and_back From 8b11c42c58d1a7c212ccba2ba1d37e8cb5810023 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sat, 19 Aug 2023 16:45:10 +0530 Subject: [PATCH 06/24] Update test_math.py --- ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py index 3971216e9b910..fdfb9f5836e72 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @@ -1944,7 +1944,7 @@ def test_tensorflow_rsqrt( @handle_frontend_test( fn_tree="tensorflow.math.segment_sum", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.tuple([ivy.int64, ivy.int32]), num_arrays=2, shared_dtype=True, min_num_dims=1, From ae97f86779c1ae43c9af99dee761cd6567b3e83d Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sat, 19 Aug 2023 16:46:36 +0530 Subject: [PATCH 07/24] Update math.py --- ivy/functional/frontends/tensorflow/math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/frontends/tensorflow/math.py b/ivy/functional/frontends/tensorflow/math.py index ab270a8bdd34c..1a30e27fd117d 100644 --- a/ivy/functional/frontends/tensorflow/math.py +++ b/ivy/functional/frontends/tensorflow/math.py @@ -555,7 +555,7 @@ def segment_sum(data, segment_ids, name="segment_sum"): list(segment_ids.shape), [list(data.shape)[0]], as_array=False ) sum_array = ivy.zeros( - tuple([segment_ids[-1] + 1] + (list(data.shape))[1:]), dtype = ivy.int32 + tuple([segment_ids[-1] + 1] + (list(data.shape))[1:]), dtype=ivy.int32 ) for i in range((segment_ids).shape[0]): sum_array[segment_ids[i]] = sum_array[segment_ids[i]] + data[i] From 530cf8f1a552fd7e5fbda17c9a796a4f1c2f5c73 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sat, 19 Aug 2023 17:30:36 +0530 Subject: [PATCH 08/24] Update test_math.py --- .../test_tensorflow/test_math.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py index fdfb9f5836e72..0caffa1c501bb 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @@ -1943,34 +1943,40 @@ def test_tensorflow_rsqrt( #segment_sum @handle_frontend_test( fn_tree="tensorflow.math.segment_sum", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.tuple([ivy.int64, ivy.int32]), - num_arrays=2, - shared_dtype=True, + dtype_and_data=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + max_num_dims=2, + + ), + dtype_and_segment=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1, - max_num_dims=3, + max_num_dims=1, ), test_with_out=st.just(False), ) def test_tensorflow_segment_sum( *, - dtype_and_x + dtype_and_data, + dtype_and_segment, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x = dtype_and_x + data_dtype, data = dtype_and_data + segment_dtype, segment_ids = dtype_and_segment helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=data_dtype + segment_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - data=x[0], - segment_ids=x[1], + data=data, + segment_ids=segment_ids, ) # nextafter From 36e0bee94a7375e501aab331872f7dab104019f7 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sat, 19 Aug 2023 17:48:41 +0530 Subject: [PATCH 09/24] Update test_math.py --- .../test_frontends/test_tensorflow/test_math.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py index 0caffa1c501bb..d7bc1d72177ec 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @@ -1945,14 +1945,11 @@ def test_tensorflow_rsqrt( fn_tree="tensorflow.math.segment_sum", dtype_and_data=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - max_num_dims=2, - + shape = (5, 6), ), dtype_and_segment=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=1, + shape = (5, ), ), test_with_out=st.just(False), ) @@ -1975,8 +1972,8 @@ def test_tensorflow_segment_sum( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - data=data, - segment_ids=segment_ids, + data=data[0], + segment_ids=segment_ids[0], ) # nextafter From e71e24143e7fd80620253f7974f80baf46add281 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sat, 19 Aug 2023 17:51:29 +0530 Subject: [PATCH 10/24] Update math.py --- ivy/functional/frontends/tensorflow/math.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ivy/functional/frontends/tensorflow/math.py b/ivy/functional/frontends/tensorflow/math.py index 1a30e27fd117d..d2c05dae15743 100644 --- a/ivy/functional/frontends/tensorflow/math.py +++ b/ivy/functional/frontends/tensorflow/math.py @@ -551,6 +551,7 @@ def rsqrt(x, name=None): @to_ivy_arrays_and_back def segment_sum(data, segment_ids, name="segment_sum"): + segment_ids = ivy.sort(segment_ids) ivy.utils.assertions.check_equal( list(segment_ids.shape), [list(data.shape)[0]], as_array=False ) From 8edfec740035542b63dd00f18034f6ac476b9fc1 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sat, 19 Aug 2023 19:09:17 +0530 Subject: [PATCH 11/24] Update test_math.py --- .../test_ivy/test_frontends/test_tensorflow/test_math.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py index d7bc1d72177ec..9d288ab7534f7 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @@ -1948,8 +1948,10 @@ def test_tensorflow_rsqrt( shape = (5, 6), ), dtype_and_segment=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=["int32", "int64"], shape = (5, ), + min_value = 0, + max_value = 4, ), test_with_out=st.just(False), ) From a4e53403175636cec9ab386ee52bde4f8393b5b6 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sat, 19 Aug 2023 20:10:07 +0530 Subject: [PATCH 12/24] Update math.py --- ivy/functional/frontends/tensorflow/math.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ivy/functional/frontends/tensorflow/math.py b/ivy/functional/frontends/tensorflow/math.py index d2c05dae15743..fb5e1e48c8eed 100644 --- a/ivy/functional/frontends/tensorflow/math.py +++ b/ivy/functional/frontends/tensorflow/math.py @@ -556,8 +556,7 @@ def segment_sum(data, segment_ids, name="segment_sum"): list(segment_ids.shape), [list(data.shape)[0]], as_array=False ) sum_array = ivy.zeros( - tuple([segment_ids[-1] + 1] + (list(data.shape))[1:]), dtype=ivy.int32 - ) + tuple([segment_ids[-1] + 1] + (list(data.shape))[1:])) for i in range((segment_ids).shape[0]): sum_array[segment_ids[i]] = sum_array[segment_ids[i]] + data[i] return sum_array From c96f16c7b3e210ab5bf427fa6d5dfd94208a892b Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sun, 20 Aug 2023 12:43:59 +0530 Subject: [PATCH 13/24] Update module.py --- ivy/stateful/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/stateful/module.py b/ivy/stateful/module.py index 6a2cd10750fc9..e8a0cee40e414 100644 --- a/ivy/stateful/module.py +++ b/ivy/stateful/module.py @@ -5,7 +5,7 @@ import os import abc import copy -import dill +# import dill from typing import Optional, Tuple, Dict # local From 4f1d84df2d8acba685e0051792c439e5a56c624e Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sun, 20 Aug 2023 12:54:42 +0530 Subject: [PATCH 14/24] Update math.py --- ivy/functional/frontends/tensorflow/math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/frontends/tensorflow/math.py b/ivy/functional/frontends/tensorflow/math.py index fb5e1e48c8eed..cbbd1e0c81427 100644 --- a/ivy/functional/frontends/tensorflow/math.py +++ b/ivy/functional/frontends/tensorflow/math.py @@ -556,7 +556,7 @@ def segment_sum(data, segment_ids, name="segment_sum"): list(segment_ids.shape), [list(data.shape)[0]], as_array=False ) sum_array = ivy.zeros( - tuple([segment_ids[-1] + 1] + (list(data.shape))[1:])) + tuple([segment_ids[-1] + 1] + (list(data.shape))[1:]), dtype=ivy.dtype(data)) for i in range((segment_ids).shape[0]): sum_array[segment_ids[i]] = sum_array[segment_ids[i]] + data[i] return sum_array From 7e5b24b5945587836684ab40cdce423b778ff35b Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sun, 20 Aug 2023 13:23:51 +0530 Subject: [PATCH 15/24] Update assertions.py --- ivy/utils/assertions.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/ivy/utils/assertions.py b/ivy/utils/assertions.py index 17dae2302df33..d5bca848ec7f8 100644 --- a/ivy/utils/assertions.py +++ b/ivy/utils/assertions.py @@ -255,7 +255,36 @@ def check_unsorted_segment_min_valid_params(data, segment_ids, num_segments): if num_segments <= 0: raise ValueError("num_segments must be positive") +def check_segment_sum_valid_params(data, segment_ids): + valid_dtypes = [ + ivy.int32, + ivy.int64, + ] + + if ivy.backend == "torch": + import torch + + valid_dtypes = [ + torch.int32, + torch.int64, + ] + elif ivy.backend == "paddle": + import paddle + + valid_dtypes = [ + paddle.int32, + paddle.int64, + ] + + if segment_ids.dtype not in valid_dtypes: + raise ValueError("segment_ids must have an integer dtype") + + if data.shape[0] != segment_ids.shape[0]: + raise ValueError("The length of segment_ids should be equal to data.shape[0].") + for x in range(1, len(segment_ids)): + if segment_ids[x] < segment_ids[x-1]: + raise InvalidArgumentError("Segment_ids must be sorted") # General # # ------- # From c196ef2e4faee0bce9560726300705de34683189 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sun, 20 Aug 2023 13:26:32 +0530 Subject: [PATCH 16/24] Update math.py --- ivy/functional/frontends/tensorflow/math.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ivy/functional/frontends/tensorflow/math.py b/ivy/functional/frontends/tensorflow/math.py index cbbd1e0c81427..a416cae6059e3 100644 --- a/ivy/functional/frontends/tensorflow/math.py +++ b/ivy/functional/frontends/tensorflow/math.py @@ -551,10 +551,10 @@ def rsqrt(x, name=None): @to_ivy_arrays_and_back def segment_sum(data, segment_ids, name="segment_sum"): - segment_ids = ivy.sort(segment_ids) - ivy.utils.assertions.check_equal( - list(segment_ids.shape), [list(data.shape)[0]], as_array=False - ) + # ivy.utils.assertions.check_equal( + # list(segment_ids.shape), [list(data.shape)[0]], as_array=False + # ) + ivy.utils.assertions.check_segment_sum_valid_params(data, segment_ids) sum_array = ivy.zeros( tuple([segment_ids[-1] + 1] + (list(data.shape))[1:]), dtype=ivy.dtype(data)) for i in range((segment_ids).shape[0]): From a5a0fb6b0835f01b7f78055c760917f6433ec3eb Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sun, 20 Aug 2023 13:35:14 +0530 Subject: [PATCH 17/24] Update assertions.py --- ivy/utils/assertions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/utils/assertions.py b/ivy/utils/assertions.py index d5bca848ec7f8..683eb24ff1c48 100644 --- a/ivy/utils/assertions.py +++ b/ivy/utils/assertions.py @@ -284,7 +284,7 @@ def check_segment_sum_valid_params(data, segment_ids): raise ValueError("The length of segment_ids should be equal to data.shape[0].") for x in range(1, len(segment_ids)): if segment_ids[x] < segment_ids[x-1]: - raise InvalidArgumentError("Segment_ids must be sorted") + raise ivy.utils.exceptions.IvyException("Segment_ids must be sorted") # General # # ------- # From 1c9a7d3cdbaec4f3a913cdec7ffe784a8f207a29 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sun, 20 Aug 2023 13:51:04 +0530 Subject: [PATCH 18/24] Update test_math.py --- ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py index 9d288ab7534f7..88c6db8f58a1b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @@ -1975,7 +1975,7 @@ def test_tensorflow_segment_sum( fn_tree=fn_tree, on_device=on_device, data=data[0], - segment_ids=segment_ids[0], + segment_ids=ivy.sort(segment_ids[0]), ) # nextafter From 0f5b384c3cfca92a48355e82815a64cff082e827 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sun, 20 Aug 2023 13:53:23 +0530 Subject: [PATCH 19/24] Update assertions.py --- ivy/utils/assertions.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/ivy/utils/assertions.py b/ivy/utils/assertions.py index 683eb24ff1c48..3bd296b6d74d1 100644 --- a/ivy/utils/assertions.py +++ b/ivy/utils/assertions.py @@ -262,23 +262,23 @@ def check_segment_sum_valid_params(data, segment_ids): ivy.int64, ] - if ivy.backend == "torch": - import torch - - valid_dtypes = [ - torch.int32, - torch.int64, - ] - elif ivy.backend == "paddle": - import paddle - - valid_dtypes = [ - paddle.int32, - paddle.int64, - ] - - if segment_ids.dtype not in valid_dtypes: - raise ValueError("segment_ids must have an integer dtype") + # if ivy.backend == "torch": + # import torch + + # valid_dtypes = [ + # torch.int32, + # torch.int64, + # ] + # elif ivy.backend == "paddle": + # import paddle + + # valid_dtypes = [ + # paddle.int32, + # paddle.int64, + # ] + + # if segment_ids.dtype not in valid_dtypes: + # raise ValueError("segment_ids must have an integer dtype") if data.shape[0] != segment_ids.shape[0]: raise ValueError("The length of segment_ids should be equal to data.shape[0].") From cfe881015fb220671b7b07500e77613578682295 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sun, 20 Aug 2023 14:14:29 +0530 Subject: [PATCH 20/24] Update test_math.py --- .../test_ivy/test_frontends/test_tensorflow/test_math.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py index 88c6db8f58a1b..9cb77ce0d3317 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @@ -1945,11 +1945,11 @@ def test_tensorflow_rsqrt( fn_tree="tensorflow.math.segment_sum", dtype_and_data=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - shape = (5, 6), + shape = ivy.array((5, 6)), ), dtype_and_segment=helpers.dtype_and_values( available_dtypes=["int32", "int64"], - shape = (5, ), + shape = ivy.array((5, )), min_value = 0, max_value = 4, ), From d5c6fa92f5629bfd8fc4a8310619f812763b3f6c Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sun, 20 Aug 2023 14:22:03 +0530 Subject: [PATCH 21/24] Update test_math.py --- .../test_ivy/test_frontends/test_tensorflow/test_math.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py index 9cb77ce0d3317..88c6db8f58a1b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @@ -1945,11 +1945,11 @@ def test_tensorflow_rsqrt( fn_tree="tensorflow.math.segment_sum", dtype_and_data=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - shape = ivy.array((5, 6)), + shape = (5, 6), ), dtype_and_segment=helpers.dtype_and_values( available_dtypes=["int32", "int64"], - shape = ivy.array((5, )), + shape = (5, ), min_value = 0, max_value = 4, ), From 55d364f3d1feaaf4b3f85306a670beea961a9c01 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sun, 20 Aug 2023 14:40:12 +0530 Subject: [PATCH 22/24] Update test_math.py --- ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py index 88c6db8f58a1b..cac8cc10fe085 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @@ -1968,7 +1968,7 @@ def test_tensorflow_segment_sum( data_dtype, data = dtype_and_data segment_dtype, segment_ids = dtype_and_segment helpers.test_frontend_function( - input_dtypes=data_dtype + segment_dtype, + input_dtypes=data_dtype + segment_dtype frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, From c3f214a3fe7878f85586a0b64e66b0c77bbe0129 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sun, 20 Aug 2023 14:43:48 +0530 Subject: [PATCH 23/24] Update math.py --- ivy/functional/frontends/tensorflow/math.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ivy/functional/frontends/tensorflow/math.py b/ivy/functional/frontends/tensorflow/math.py index a416cae6059e3..1f8b1b4e40806 100644 --- a/ivy/functional/frontends/tensorflow/math.py +++ b/ivy/functional/frontends/tensorflow/math.py @@ -551,10 +551,9 @@ def rsqrt(x, name=None): @to_ivy_arrays_and_back def segment_sum(data, segment_ids, name="segment_sum"): - # ivy.utils.assertions.check_equal( - # list(segment_ids.shape), [list(data.shape)[0]], as_array=False - # ) - ivy.utils.assertions.check_segment_sum_valid_params(data, segment_ids) + ivy.utils.assertions.check_equal( + list(segment_ids.shape), [list(data.shape)[0]], as_array=False + ) sum_array = ivy.zeros( tuple([segment_ids[-1] + 1] + (list(data.shape))[1:]), dtype=ivy.dtype(data)) for i in range((segment_ids).shape[0]): From b78280b2fe59924c3cd53a47c7fbee6069ff1712 Mon Sep 17 00:00:00 2001 From: soumesh113 <58463935+soumesh113@users.noreply.github.com> Date: Sun, 20 Aug 2023 14:48:13 +0530 Subject: [PATCH 24/24] Update test_math.py --- ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py index cac8cc10fe085..88c6db8f58a1b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @@ -1968,7 +1968,7 @@ def test_tensorflow_segment_sum( data_dtype, data = dtype_and_data segment_dtype, segment_ids = dtype_and_segment helpers.test_frontend_function( - input_dtypes=data_dtype + segment_dtype + input_dtypes=data_dtype + segment_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags,