Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added segment_sum function to math.py #21907

Closed
wants to merge 28 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
cc02bea
Update math.py
soumesh113 Aug 15, 2023
16f1fd6
Update test_math.py
soumesh113 Aug 15, 2023
85c6628
Update test_math.py
soumesh113 Aug 15, 2023
cc0424f
Merge branch 'unifyai:main' into segment_sum
soumesh113 Aug 19, 2023
7238e07
Update test_math.py
soumesh113 Aug 19, 2023
263602d
Update math.py
soumesh113 Aug 19, 2023
8b11c42
Update test_math.py
soumesh113 Aug 19, 2023
ae97f86
Update math.py
soumesh113 Aug 19, 2023
530cf8f
Update test_math.py
soumesh113 Aug 19, 2023
36e0bee
Update test_math.py
soumesh113 Aug 19, 2023
e71e241
Update math.py
soumesh113 Aug 19, 2023
125b9dc
Merge branch 'unifyai:main' into segment_sum
soumesh113 Aug 19, 2023
8edfec7
Update test_math.py
soumesh113 Aug 19, 2023
a4e5340
Update math.py
soumesh113 Aug 19, 2023
bbb0793
Merge branch 'unifyai:main' into segment_sum
soumesh113 Aug 20, 2023
c96f16c
Update module.py
soumesh113 Aug 20, 2023
4f1d84d
Update math.py
soumesh113 Aug 20, 2023
7e5b24b
Update assertions.py
soumesh113 Aug 20, 2023
c196ef2
Update math.py
soumesh113 Aug 20, 2023
a5a0fb6
Update assertions.py
soumesh113 Aug 20, 2023
1c9a7d3
Update test_math.py
soumesh113 Aug 20, 2023
0f5b384
Update assertions.py
soumesh113 Aug 20, 2023
cfe8810
Update test_math.py
soumesh113 Aug 20, 2023
d5c6fa9
Update test_math.py
soumesh113 Aug 20, 2023
baca3df
Merge branch 'unifyai:main' into segment_sum
soumesh113 Aug 20, 2023
55d364f
Update test_math.py
soumesh113 Aug 20, 2023
c3f214a
Update math.py
soumesh113 Aug 20, 2023
b78280b
Update test_math.py
soumesh113 Aug 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions ivy/functional/frontends/tensorflow/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,16 @@ def tanh(x, name=None):
def rsqrt(x, name=None):
return ivy.reciprocal(ivy.sqrt(x))

@to_ivy_arrays_and_back
def segment_sum(data, segment_ids, name="segment_sum"):
ivy.utils.assertions.check_equal(
list(segment_ids.shape), [list(data.shape)[0]], as_array=False
)
sum_array = ivy.zeros(
tuple([segment_ids[-1] + 1] + (list(data.shape))[1:]), dtype=ivy.dtype(data))
for i in range((segment_ids).shape[0]):
sum_array[segment_ids[i]] = sum_array[segment_ids[i]] + data[i]
return sum_array

@to_ivy_arrays_and_back
def nextafter(x1, x2, name=None):
Expand Down
2 changes: 1 addition & 1 deletion ivy/stateful/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import abc
import copy
import dill
# import dill
from typing import Optional, Tuple, Dict

# local
Expand Down
29 changes: 29 additions & 0 deletions ivy/utils/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,36 @@ def check_unsorted_segment_min_valid_params(data, segment_ids, num_segments):
if num_segments <= 0:
raise ValueError("num_segments must be positive")

def check_segment_sum_valid_params(data, segment_ids):

valid_dtypes = [
ivy.int32,
ivy.int64,
]

# if ivy.backend == "torch":
# import torch

# valid_dtypes = [
# torch.int32,
# torch.int64,
# ]
# elif ivy.backend == "paddle":
# import paddle

# valid_dtypes = [
# paddle.int32,
# paddle.int64,
# ]

# if segment_ids.dtype not in valid_dtypes:
# raise ValueError("segment_ids must have an integer dtype")

if data.shape[0] != segment_ids.shape[0]:
raise ValueError("The length of segment_ids should be equal to data.shape[0].")
for x in range(1, len(segment_ids)):
if segment_ids[x] < segment_ids[x-1]:
raise ivy.utils.exceptions.IvyException("Segment_ids must be sorted")
# General #
# ------- #

Expand Down
37 changes: 37 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1940,6 +1940,43 @@ def test_tensorflow_rsqrt(
x=x[0],
)

#segment_sum
@handle_frontend_test(
fn_tree="tensorflow.math.segment_sum",
dtype_and_data=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("numeric"),
shape = (5, 6),
),
dtype_and_segment=helpers.dtype_and_values(
available_dtypes=["int32", "int64"],
shape = (5, ),
min_value = 0,
max_value = 4,
),
test_with_out=st.just(False),
)
def test_tensorflow_segment_sum(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello! The test is failing for all backends. You can see the stacktrace and error in the details section of the intelligent-tests (1)

Let me know if you have any questions!

*,
dtype_and_data,
dtype_and_segment,
frontend,
test_flags,
fn_tree,
backend_fw,
on_device,
):
data_dtype, data = dtype_and_data
segment_dtype, segment_ids = dtype_and_segment
helpers.test_frontend_function(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests are still failing. It shows TypeError: 'Array' object cannot be interpreted as an integer
You don't need to send the dtyès as an array. Check out this call as an example

      helpers.test_frontend_function(
        input_dtypes=input_dtype,
        backend_to_test=backend_fw,
        frontend=frontend,
        test_flags=test_flags,
        fn_tree=fn_tree,
        on_device=on_device,
        x1=x[0],
        x2=x[1],
    )

To the kwarg input_dtypes you only need to pass the data_dtype. And you pass the inputs at the end of the function call. And so on. Let me know if this makes sense!

input_dtypes=data_dtype + segment_dtype,
frontend=frontend,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
data=data[0],
segment_ids=ivy.sort(segment_ids[0]),
)

# nextafter
@handle_frontend_test(
Expand Down
Loading