-
Notifications
You must be signed in to change notification settings - Fork 22.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MPS] add support for sgn to MPS backend #110829
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110829
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit fb8831c with merge base fa8e4ea (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but please explain why _efficientzerotensor_mps
is needed. (I'm totally fine with adding it, but if its not needed for enabling sgn
, would appreciate adding it as separate PR.
Also, complex part of the sign.out
would not be tested until one adds it to SUPPORTED_COMPLEX_OPS
list in test_mps.py
, would it?
Thanks @malfet! Re _efficientzerotensor, the sgn grad function seems to use it:
Re SUPPORTED_COMPLEX_OPS, yeah, but, for some reason, someone had already added it to the list. Also, sorry, but I just did some refactoring of the sgn MPSgraph block (I hadn't looked at Berzeg's implementation very closely before, other than checking that it was doing the correct calculation). The new version is roughly 50% faster on my machine with large inputs, mostly because it doesn't split the input tensor into two separate tensors and repeat operations on both of them. |
Tensor output_copy = output.alias(); | ||
at::sign_out(output_copy, self); | ||
output.copy_(output_copy); | ||
return; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be much faster, shouldn't it? (I.e. it avoids dispatch, also, feels like sign_out
is invoked in inverse right now, isn't it?
Tensor output_copy = output.alias(); | |
at::sign_out(output_copy, self); | |
output.copy_(output_copy); | |
return; | |
return sign_out_mps(self, output); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This causes an error, I think because sign_out_mps is declared with the TORCH_IMPL_FUNC macro. But maybe I'm missing something? Is there a way to directly call a function declared in the T_I_F macro?
Re the reverse order of the arguments, I was surprised too! For some reason, at::sign_out takes the arguments in reverse! See here:
// aten::sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
inline at::Tensor & sign_out(at::Tensor & out, const at::Tensor & self) {
return at::_ops::sign_out::call(self, out);
}
// aten::sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
inline at::Tensor & sign_outf(const at::Tensor & self, at::Tensor & out) {
return at::_ops::sign_out::call(self, out);
}
I was also trying to get rid of the alias() and subsequent copy_() call. The obstacle as I understand it is that output is const and the available sign_out functions take non-const arguments for output.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should work, but I'm not sure if it's preferable:
output.copy_(at::sign(self));
return;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@igm503, no this would be much slower, as it would incur extraneous copy.
Anyway, let's land as is and I'll submit a follow up PR to fix this.
@pytorchbot merge -f "MPS tests are green" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Why was this closed without merge when it seemed everyone was on the same page to merge it? -- edit -- Never mind. That's one strange way to handle merging, but alright. |
Yeah, it confused me at first, too |
Not a Meta employee but long story short, they use a merge bot to commit the PR to main on your behalf without touching the GitHub built-in merge mechanism to be more flexible on controlling the merging process and ensuring the code quality. So after your PR is committed to main by the bot, |
As of Dec. 21, using the latest version, v2.1.2, I am still encountering this problem. Why is the implication of the discussion above and the status of this issue (closed) that the problem has been solved? It doesn't seem that way based on my current usage. |
Hey @kebwi, the |
Seems like its also not included in the nightly yet, i get the same error as with stable |
I'm confused. @igm503 would you mind taking a look at it? |
@qqaatw Sure. |
What I've found:
@pattplatt If you're still running into issues, can you share more info, like which exact version of torch you're using and what code is producing the not supported error message? |
@igm503 thanks, your solution with conda works |
@igm503 can it works using venv also ? Or must be conda for some reason ? |
@x4080 yes, you could use a venv. All I was saying regarding pip vs. conda in that last post was that the pip install command on pytorch.com get was pointing to an out of date nightly. It seems to work fine now, however, when I test it, and it may just have been a problem with my locally cached package downloads. The takeaway here is just to manually confirm that you have the right version of pytorch installed when having trouble with new features, since there's often a bit of lag between when the features are merged into the library and when they are available in release versions of the library. For this op (sgn.out on MPS backend), just make sure your pytorch version is >= 2.2 |
@igm503 Thanks |
Fixes #86805
Adds support for sgn to MPS backend.
Notes:
@malfet self-assigned this when he was working on implementing polar, but from what I can tell, he didn't end up needing to implement it.
@Berzeg implemented this last year, before view_as_complex was supported. Because of @malfet recent contributions, however, @Berzeg 's implementation works. I've removed the part of his implementation that dealt with non-complex dtypes (since these can just be passed to at::sign), matched the more recent pattern we've been using in UnaryOps.mm, and thrown in a simple implementation of _efficientzerotensor for mps, so that the backward function works.
@Berzeg deserves a good bit of credit for this, so let me know if there's a way to assign him some without jamming up the pr (he seems to be AWOL since last working on this)