diff --git a/ivy/functional/frontends/tensorflow/math.py b/ivy/functional/frontends/tensorflow/math.py index 8fccb8cbfcc78..1f8b1b4e40806 100644 --- a/ivy/functional/frontends/tensorflow/math.py +++ b/ivy/functional/frontends/tensorflow/math.py @@ -549,6 +549,16 @@ def tanh(x, name=None): def rsqrt(x, name=None): return ivy.reciprocal(ivy.sqrt(x)) +@to_ivy_arrays_and_back +def segment_sum(data, segment_ids, name="segment_sum"): + ivy.utils.assertions.check_equal( + list(segment_ids.shape), [list(data.shape)[0]], as_array=False + ) + sum_array = ivy.zeros( + tuple([segment_ids[-1] + 1] + (list(data.shape))[1:]), dtype=ivy.dtype(data)) + for i in range((segment_ids).shape[0]): + sum_array[segment_ids[i]] = sum_array[segment_ids[i]] + data[i] + return sum_array @to_ivy_arrays_and_back def nextafter(x1, x2, name=None): diff --git a/ivy/stateful/module.py b/ivy/stateful/module.py index 6a2cd10750fc9..e8a0cee40e414 100644 --- a/ivy/stateful/module.py +++ b/ivy/stateful/module.py @@ -5,7 +5,7 @@ import os import abc import copy -import dill +# import dill from typing import Optional, Tuple, Dict # local diff --git a/ivy/utils/assertions.py b/ivy/utils/assertions.py index 17dae2302df33..3bd296b6d74d1 100644 --- a/ivy/utils/assertions.py +++ b/ivy/utils/assertions.py @@ -255,7 +255,36 @@ def check_unsorted_segment_min_valid_params(data, segment_ids, num_segments): if num_segments <= 0: raise ValueError("num_segments must be positive") +def check_segment_sum_valid_params(data, segment_ids): + valid_dtypes = [ + ivy.int32, + ivy.int64, + ] + + # if ivy.backend == "torch": + # import torch + + # valid_dtypes = [ + # torch.int32, + # torch.int64, + # ] + # elif ivy.backend == "paddle": + # import paddle + + # valid_dtypes = [ + # paddle.int32, + # paddle.int64, + # ] + + # if segment_ids.dtype not in valid_dtypes: + # raise ValueError("segment_ids must have an integer dtype") + + if data.shape[0] != segment_ids.shape[0]: + raise ValueError("The length of segment_ids should be equal to data.shape[0].") + for x in range(1, len(segment_ids)): + if segment_ids[x] < segment_ids[x-1]: + raise ivy.utils.exceptions.IvyException("Segment_ids must be sorted") # General # # ------- # diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py index 9f7f44cf4cf3f..88c6db8f58a1b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @@ -1940,6 +1940,43 @@ def test_tensorflow_rsqrt( x=x[0], ) +#segment_sum +@handle_frontend_test( + fn_tree="tensorflow.math.segment_sum", + dtype_and_data=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape = (5, 6), + ), + dtype_and_segment=helpers.dtype_and_values( + available_dtypes=["int32", "int64"], + shape = (5, ), + min_value = 0, + max_value = 4, + ), + test_with_out=st.just(False), +) +def test_tensorflow_segment_sum( + *, + dtype_and_data, + dtype_and_segment, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + data_dtype, data = dtype_and_data + segment_dtype, segment_ids = dtype_and_segment + helpers.test_frontend_function( + input_dtypes=data_dtype + segment_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + data=data[0], + segment_ids=ivy.sort(segment_ids[0]), + ) # nextafter @handle_frontend_test(