-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added segment_sum function to math.py #21907
Conversation
Thanks for contributing to Ivy! 😊👏 |
Hi @ZoeCD , could you please review this PR and suggest any changes to be made? |
@@ -549,6 +549,19 @@ def tanh(x, name=None): | |||
def rsqrt(x, name=None): | |||
return ivy.reciprocal(ivy.sqrt(x)) | |||
|
|||
@to_ivy_arrays_and_back | |||
def segment_sum(data, segment_ids, name= "segment_sum"): | |||
data = ivy.array(data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need to change the arrays to ivy arrays. The to_ivy_arrays_and_back
decorator handles that
on_device, | ||
): | ||
helpers.test_frontend_function( | ||
input_dtypes=["int32", "int64"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here you should be passing the dtype that your data has. I recommend checking other tests to see how it is done
), | ||
test_with_out=st.just(False), | ||
) | ||
def test_tensorflow_segment_sum( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello! The test is failing for all backends. You can see the stacktrace and error in the details section of the intelligent-tests (1)
Let me know if you have any questions!
): | ||
data_dtype, data = dtype_and_data | ||
segment_dtype, segment_ids = dtype_and_segment | ||
helpers.test_frontend_function( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tests are still failing. It shows TypeError: 'Array' object cannot be interpreted as an integer
You don't need to send the dtyès as an array. Check out this call as an example
helpers.test_frontend_function(
input_dtypes=input_dtype,
backend_to_test=backend_fw,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
x1=x[0],
x2=x[1],
)
To the kwarg input_dtypes
you only need to pass the data_dtype
. And you pass the inputs at the end of the function call. And so on. Let me know if this makes sense!
This PR has been labelled as stale because it has been inactive for more than 7 days. If you would like to continue working on this PR, then please add another comment or this PR will be closed in 7 days. |
Close #21903