
[MPS] add support for sgn to MPS backend #110829

Closed · wants to merge 3 commits

Conversation

@igm503 (Contributor) commented Oct 8, 2023

Fixes #86805

Adds support for sgn to the MPS backend; see the sketch after the notes below for what the op computes.

Notes:

  1. @malfet self-assigned this when he was working on implementing polar, but from what I can tell, he didn't end up needing to implement it.

  2. @Berzeg implemented this last year, before view_as_complex was supported. Because of @malfet's recent contributions, however, @Berzeg's implementation now works. I've removed the part of his implementation that dealt with non-complex dtypes (since those can just be passed to at::sign), matched the more recent pattern we've been using in UnaryOps.mm, and thrown in a simple implementation of _efficientzerotensor for MPS so that the backward function works.

  3. @Berzeg deserves a good bit of credit for this, so let me know if there's a way to assign him some without jamming up the PR (he seems to have been AWOL since last working on this).
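
For context, on complex inputs sgn(z) computes z / |z| (and returns 0 where z == 0); on real inputs it matches torch.sign. A minimal sketch of the behavior this PR enables, assuming a build that includes it and an MPS-capable machine:

import torch

# sgn(z) = z / |z| for nonzero complex z, 0 where z == 0
z = torch.tensor([3 + 4j, 0j, -2 + 0j], device="mps")
print(torch.sgn(z))  # expect tensor([0.6000+0.8000j, 0.0000+0.0000j, -1.0000+0.0000j], device='mps:0')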

@igm503 igm503 requested a review from kulinseth as a code owner October 8, 2023 19:06
@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Oct 8, 2023
@pytorch-bot bot commented Oct 8, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110829

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fb8831c with merge base fa8e4ea:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@malfet (Contributor) left a comment

LGTM, but please explain why _efficientzerotensor_mps is needed. (I'm totally fine with adding it, but if it's not needed for enabling sgn, I would appreciate adding it as a separate PR.)
Also, the complex part of sign.out would not be tested until one adds it to the SUPPORTED_COMPLEX_OPS list in test_mps.py, would it?

@igm503 (Contributor, Author) commented Oct 9, 2023

LGTM, but please explain why _efficientzerotensor_mps is needed. (I'm totally fine with adding it, but if it's not needed for enabling sgn, I would appreciate adding it as a separate PR.) Also, the complex part of sign.out would not be tested until one adds it to the SUPPORTED_COMPLEX_OPS list in test_mps.py, would it?

Thanks @malfet!

Re _efficientzerotensor, the sgn grad function seems to use it:

Tensor sgn_backward(const Tensor& x, const Tensor& gx, const Tensor& sgn) {

Re SUPPORTED_COMPLEX_OPS: yeah, but for some reason someone had already added it to the list.
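
To illustrate why the backward function needs a zero tensor at all: sgn is piecewise constant on real inputs, so its gradient is zero almost everywhere, and autograd has to materialize that zero. A small sketch on CPU (the MPS path should behave the same once _efficientzerotensor is available there):

import torch

# sgn is piecewise constant on real inputs, so its gradient is 0 a.e.
x = torch.tensor([0.7, -1.2, 2.3], requires_grad=True)
torch.sgn(x).sum().backward()
print(x.grad)  # expect tensor([0., 0., 0.])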

Also, sorry, but I just did some refactoring of the sgn MPSGraph block (I hadn't looked at Berzeg's implementation very closely before, beyond checking that it was doing the correct calculation). The new version is roughly 50% faster on my machine with large inputs, mostly because it doesn't split the input tensor into two separate tensors and repeat operations on both of them.

Comment on lines +462 to +465
Tensor output_copy = output.alias();
at::sign_out(output_copy, self);
output.copy_(output_copy);
return;
@malfet (Contributor) commented Oct 9, 2023

This should be much faster, shouldn't it? (I.e., it avoids dispatch. Also, it feels like sign_out is invoked with its arguments reversed right now, isn't it?)

Suggested change
Tensor output_copy = output.alias();
at::sign_out(output_copy, self);
output.copy_(output_copy);
return;
return sign_out_mps(self, output);

@igm503 (Contributor, Author) commented Oct 9, 2023

This causes an error, I think because sign_out_mps is declared with the TORCH_IMPL_FUNC macro. But maybe I'm missing something? Is there a way to directly call a function declared via the TORCH_IMPL_FUNC macro?

Re the reversed order of the arguments, I was surprised too! For some reason, at::sign_out takes the arguments in reverse. See here:

// aten::sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
inline at::Tensor & sign_out(at::Tensor & out, const at::Tensor & self) {
    return at::_ops::sign_out::call(self, out);
}
// aten::sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
inline at::Tensor & sign_outf(const at::Tensor & self, at::Tensor & out) {
    return at::_ops::sign_out::call(self, out);
}

I was also trying to get rid of the alias() and subsequent copy_() call. The obstacle, as I understand it, is that output is const, while the available sign_out functions take a non-const reference for the output tensor.

@igm503 (Contributor, Author) commented Oct 9, 2023

This should work, but I'm not sure if it's preferable:

output.copy_(at::sign(self));
return;

@malfet (Contributor) commented Oct 9, 2023

@igm503, no, this would be much slower, as it would incur an extraneous copy.
Anyway, let's land this as is and I'll submit a follow-up PR to fix it.
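
To illustrate the cost difference in Python terms (the code under review is C++, so this is only an analogy): copying the result of a fresh sign call allocates a temporary tensor and then copies it, whereas the out-variant writes directly into the destination.

import torch

x = torch.randn(1_000_000, device="mps")
out = torch.empty_like(x)

torch.sign(x, out=out)    # writes directly into out, no intermediate
out.copy_(torch.sign(x))  # allocates a temporary result, then copies it over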

@malfet (Contributor) commented Oct 9, 2023

@pytorchbot merge -f "MPS tests are green"

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@kkirby commented Oct 16, 2023

Why was this closed without merge when it seemed everyone was on the same page to merge it?

-- edit --

Never mind. That's one strange way to handle merging, but alright.

@igm503 (Contributor, Author) commented Oct 16, 2023

Why was this closed without merge when it seemed everyone was on the same page to merge it?

-- edit --

Never mind. That's one strange way to handle merging, but alright.

Yeah, it confused me at first, too

@qqaatw (Collaborator) commented Oct 18, 2023

Not a Meta employee, but long story short: they use a merge bot to commit the PR to main on your behalf, without touching GitHub's built-in merge mechanism, so they can control the merging process more flexibly and ensure code quality. After the bot commits your PR to main, the Pull Request resolved: xxx line in the commit message triggers GitHub to close the corresponding PR directly.

@kebwi commented Dec 21, 2023

As of Dec. 21, using the latest version, v2.1.2, I am still encountering this problem. Why do the discussion above and the closed status of this issue imply that the problem has been solved? It doesn't seem that way based on my current usage.

@qqaatw (Collaborator) commented Dec 21, 2023

Hey @kebwi, the 2.1 release branch was cut on Aug 28, 2023, and this PR was merged on Oct 9, 2023. Therefore, it's not in the 2.1.* releases, where patch versions do not introduce new features. You could try the nightly version of PyTorch and see if it works.

@pattplatt commented:

Hey @kebwi, the 2.1 release branch was cut on Aug 28, 2023, and this PR was merged on Oct 9, 2023. Therefore, it's not in the 2.1.* releases, where patch versions do not introduce new features. You could try the nightly version of PyTorch and see if it works.

Seems like it's also not included in the nightly yet; I get the same error as with stable.

@qqaatw (Collaborator) commented Dec 22, 2023

I'm confused.

@igm503 would you mind taking a look at it?

@igm503 (Contributor, Author) commented Dec 24, 2023

@qqaatw Sure.

@igm503 (Contributor, Author) commented Dec 24, 2023

What I've found:

  1. Following the PyTorch website's instructions for installing the latest nightly for Mac via pip doesn't work on my machine. For some reason, it installs an old 1.13.1 pre-release version, so to get the correct version via pip, you have to do some more work.

  2. Following the instructions for installing the latest Mac nightly via conda does work:

(torch_test) [~]$ pip show torch
Name: torch
Version: 2.3.0.dev20231224
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: [email protected]
License: BSD-3
Location: /opt/homebrew/Caskroom/miniforge/base/envs/torch_test/lib/python3.9/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: torchaudio, torchvision

  3. Using today's latest nightly, torch.sign does work on the MPS backend on my machine:
(torch_test) [~]$ python
Python 3.9.18 (main, Sep 11 2023, 08:25:10) 
[Clang 14.0.6 ] :: Anaconda, Inc. on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> a = torch.tensor([0.7, -1.2, 0., 2.3])
>>> a
tensor([ 0.7000, -1.2000,  0.0000,  2.3000])
>>> torch.sign(a)
tensor([ 1., -1.,  0.,  1.])
>>> a.device
device(type='cpu')
>>> a.to('mps')
tensor([ 0.7000, -1.2000,  0.0000,  2.3000], device='mps:0')
>>> torch.sign(a.to('mps'))
tensor([ 1., -1.,  0.,  1.], device='mps:0')
>>> a = a.to('mps')
>>> a
tensor([ 0.7000, -1.2000,  0.0000,  2.3000], device='mps:0')
>>> torch.sign(a)
tensor([ 1., -1.,  0.,  1.], device='mps:0')

@pattplatt If you're still running into issues, can you share more info, like which exact version of torch you're using and what code is producing the not supported error message?

@pattplatt commented:

@igm503 thanks, your solution with conda works

@x4080 commented Dec 28, 2023

@igm503 can it work using venv also? Or must it be conda for some reason?

@igm503 (Contributor, Author) commented Dec 29, 2023

@x4080 yes, you could use a venv. All I was saying regarding pip vs. conda in that last post was that the pip install command on pytorch.org was pointing to an out-of-date nightly. It seems to work fine now when I test it, though, and it may just have been a problem with my locally cached package downloads.

The takeaway here is just to manually confirm that you have the right version of PyTorch installed when you're having trouble with new features, since there's often a bit of lag between when features are merged into the library and when they become available in release versions. For this op (sgn.out on the MPS backend), just make sure your PyTorch version is >= 2.2.
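
A quick way to confirm what you're actually running before debugging further (a minimal check, assuming Apple silicon and a reasonably recent build):

import torch

print(torch.__version__)                  # want 2.2+ or a dev/nightly string like 2.3.0.devYYYYMMDD
print(torch.backends.mps.is_available())  # should be True on a supported setup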

@x4080 commented Dec 29, 2023

@igm503 Thanks
