Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add torch frontend masked_scatter and masked_scatter_ #28783

Merged
merged 9 commits into from
Jul 10, 2024

Conversation

Daniel4078
Copy link
Contributor

@Daniel4078 Daniel4078 commented Jul 2, 2024

PR Description

Related Issue

Closes #

Checklist

  • Did you add a function?
  • Did you add the tests?
  • Did you run your tests and are your tests passing?
  • Did pre-commit not fail on any check?
  • Did you follow the steps we provided?

Socials

@Daniel4078 Daniel4078 marked this pull request as ready for review July 6, 2024 04:28
@Daniel4078
Copy link
Contributor Author

Daniel4078 commented Jul 7, 2024

@Sam-Armstrong Do you know why while the helper function I wrote generates input tensors of the same size (that follows the broadcastable requirement shown here https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter_.html#torch.Tensor.masked_scatter_), the test still sometime fail with error of unable to broadcast between two arrays.
I examine the logs and it seems like while torch masked scatter only require the source tensor to have enough element to account for all True entry in the mask tensor, other functions I used to implement it somehow requires the mask to have shape broadcastable to the source. What should I do?

@Sam-Armstrong
Copy link
Contributor

@Daniel4078 I've fixed the implementation for masked_scatter so it passes the test (aside paddle backend). Please could you implement the fix for masked_scatter_, add a test for masked_scatter_, and remove the changes you made to index_put? Then we'll be ready to merge. Thanks! 😊

…ed, during with the dtype is passed as None when converting result back to tensor for return
@Daniel4078
Copy link
Contributor Author

@Sam-Armstrong I have done the changes you asked for, now only the paddle backend tests fail probabily due to someone forgot to pass the dtype parameter when converting to tensor in the implementaiton of put_along_axis in paddle backend. Thank you!

Copy link
Contributor

@Sam-Armstrong Sam-Armstrong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm! thanks @Daniel4078 😁

@Sam-Armstrong Sam-Armstrong merged commit 8af3ae7 into ivy-llc:main Jul 10, 2024
4 of 5 checks passed
@Daniel4078 Daniel4078 deleted the torch.Tensor.masked_scatter branch July 10, 2024 13:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants