Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding aten::unsqueeze_ to PT Frontend #7231

Merged
merged 6 commits into from
Jan 13, 2021

Conversation

codeislife99
Copy link
Contributor

THis PR adds the above ops to PT Frontend to support customer models.

@codeislife99
Copy link
Contributor Author

cc: @anijain2305 @trevor-m

@masahi
Copy link
Member

masahi commented Jan 8, 2021

cc @t-vi @yongwww

This is probably not how we should support copy_. We discussed this issue before.

Later I found a promising way to support inplace copy/assignment: We first convert aten::copy_ to aten::index_put (see https://pytorch.org/docs/stable/tensors.html#torch.Tensor.index_put_) using PyTorch jit passes, and then convert aten::index_put to relay scatter_nd. This is also how PyTorch ONNX export supports inplace assignment, see pytorch/pytorch#26941

Relay scatter_nd op was added in #6854. You can convert aten::copy_ to aten::inplace_put by adding the following after

torch._C._jit_pass_onnx_function_substitution(graph)

torch._C._jit_pass_onnx_prepare_inplace_ops_for_onnx(graph)
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

@t-vi Does this sound a reasonable way to support inplace assignment?

@t-vi
Copy link
Contributor

t-vi commented Jan 8, 2021

@masahi can you elaborate a bit how you want to do that?

So to my mind there are two parts to this. For slice assignments

@torch.jit.script
def foo(x):
  x[:5] = 0
  return 2 * x

we get

graph(%x.1 : Tensor):
  %10 : bool = prim::Constant[value=0]()
  %6 : int = prim::Constant[value=1]()
  %1 : int = prim::Constant[value=0]()
  %4 : int = prim::Constant[value=5]()
  %14 : int = prim::Constant[value=2]()
  %7 : Tensor = aten::slice(%x.1, %1, %1, %4, %6)
  %8 : int = prim::dtype(%7)
  %9 : Device = prim::device(%7)
  %11 : Tensor = aten::tensor(%1, %8, %9, %10)
  %13 : Tensor = aten::copy_(%7, %11, %10)
  %16 : Tensor = aten::mul(%x.1, %14) # <string>:3:9
  return (%16)

Note how %x.1 and %7 change their values in the line defining %13. In this sense, TorchScript with inplace isn't a true SSA form.

So we need three things:

  • scatter (or some form of where) may provide us how to compute the tensor that is %x.1 after the copy,
  • we would still need to track views,
  • we would need to replace all tensors affected by the update with the new version.

Note that while for this simple case, using the pattern

  out = copy_(slice(inp, ...), something)

would be possible, but

@torch.jit.script
def foo(x):
  y = x[:2]
  y[:5] = 0
  return 2 * x

will mess that up for you.

P.S.: While you mention working on the TorchScript , I have a concrete plan (see my blog) to support more graph operations in TorchScript from python, so if you need something in particular, we might try to get that.

@masahi
Copy link
Member

masahi commented Jan 8, 2021

For the first case,

def foo(x):
  x[:5] = 0
  return 2 * x

Using the two passes I mentioned, I get this graph:

graph(%x : Float(10:1, requires_grad=0, device=cpu)):
  %1 : int = prim::Constant[value=0]() # test.py:8:0
  %2 : int = prim::Constant[value=0]() # test.py:8:0
  %3 : int = prim::Constant[value=5]() # test.py:8:0
  %4 : int = prim::Constant[value=1]() # test.py:8:0
  %5 : Float(5:1, requires_grad=0, device=cpu) = aten::slice(%x, %1, %2, %3, %4) # test.py:8:0
  %6 : Float(requires_grad=0, device=cpu) = prim::Constant[value={0}]() # test.py:8:0
  %7 : int[] = prim::ListConstruct()
  %8 : Float(requires_grad=0, device=cpu) = aten::view(%6, %7) # test.py:8:0
  %9 : int = prim::Constant[value=5]() # test.py:8:0
  %10 : int[] = prim::ListConstruct(%9)
  %11 : bool = prim::Constant[value=1]() # test.py:8:0
  %12 : Float(5:0, requires_grad=0, device=cpu) = aten::expand(%8, %10, %11) # test.py:8:0
  %13 : bool = prim::Constant[value=0]()
  %18 : Float(5:0, requires_grad=0, device=cpu) = aten::expand_as(%12, %5) # test.py:8:0
  %20 : int = prim::Constant[value=0]()
  %21 : int = aten::size(%x, %20)
  %22 : int = prim::Constant[value=4]()
  %23 : None = prim::Constant()
  %24 : None = prim::Constant()
  %25 : None = prim::Constant()
  %26 : Tensor = aten::arange(%21, %22, %23, %24, %25)
  %27 : int = prim::Constant[value=0]()
  %28 : Tensor = aten::slice(%26, %27, %2, %3, %4)
  %30 : int[] = prim::Constant[value=[-1]]()
  %31 : Tensor = aten::view(%28, %30)
  %32 : Tensor?[] = prim::ListConstruct(%31)
  %33 : Float(5:1, requires_grad=0, device=cpu) = aten::index_put(%x, %32, %18, %13)
  %15 : Long(requires_grad=0, device=cpu) = prim::Constant[value={2}]() # test.py:9:0
  %16 : Float(10:1, requires_grad=0, device=cpu) = aten::mul(%33, %15) # test.py:9:0
  return (%16)

i.e. It seems torch._C._jit_pass_onnx_prepare_inplace_ops_for_onnx does some of the requirement you mentioned (first and third, not sure about second)?
If we convert this graph to relay using scatter_nd, I think the output tensor would be correct but we cannot modify the input like torch does.

But right, your second example results in

graph(%x : Float(10:1, requires_grad=0, device=cpu)):
  %20 : Long(requires_grad=0, device=cpu) = prim::Constant[value={2}]() # test.py:15:0
  %21 : Float(10:1, requires_grad=0, device=cpu) = aten::mul(%x, %20) # test.py:15:0
  return (%21)

because torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects will consider side-effect only operation as dead code (y is not used to compute output).

@t-vi
Copy link
Contributor

t-vi commented Jan 8, 2021

Maybe they tried to take the same matching shortcut internally...

@masahi
Copy link
Member

masahi commented Jan 8, 2021

I experimented with this approach to support inplace update in torchvision faster rcnn / mask rcnn:

https://github.com/pytorch/vision/blob/6315358dd06e3a2bcbe9c1e8cdaa10898ac2b308/torchvision/ops/poolers.py#L269

This code path will not hit during tracing mode, so it is not an issue for us. But I tried anyway to experiment if aten::copy_ to aten::index_put conversion would work, and indeed for faster rcnn / maskrcnn it cleanly replaced aten::copy_.

I thought it might be promissing, but indeed there could be many limitation / corner cases ...

Copy link
Member

@masahi masahi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@codeislife99 As discussed, we still think adding half-baked conversion of copy_ is not a good idea. It might work in some cases, while it could generate a graph that is not consistent with Torch model. Please remove it.

@t-vi
Copy link
Contributor

t-vi commented Jan 8, 2021

I don't want to ruin the party, but does unsqueeze_ work as is?
We would want to update future use of the input to refer to the output (same for any of the "simple" inplace).
I can see that there is clearly interest in building support for inplace ops, but I think we would need to actually put in substantial effort if we want this to be a happy story.

@torch.jit.script
def foo(x):
  y = x.unsqueeze_(0)
  return x

As far as I know, we're only supporting "strided" tensors (i.e. blob of memory + sizes + strides define the tensor), so tracking views across these should be doable (one could, but doesn't have to) see the signature annotations in ATen's native_functions.yaml to see which of the ops we have need to be handled. One of the tricky ones would be reshape which is sometimes just a view and at other times (when viewing is impossible, eg a[:, :-1].reshape(-1)) it is a copy.

The other option could be to do this on the PyTorch side, improving the pass that @masahi highlighted.

@codeislife99
Copy link
Contributor Author

codeislife99 commented Jan 8, 2021

Thanks for the detailed discussion and uncovering some things which I didn't know about. I will remove copy_ from this PR since its not as trivial as I thought it would be.

@masahi
Copy link
Member

masahi commented Jan 8, 2021

I'm also not sure if inplace unsqueeze can be supported without any concern, or more generally how well we are doing with other inplace ops. @codeislife99 Does this change work for your real model (passing your contrived test is not enough) ?

@t-vi
Copy link
Contributor

t-vi commented Jan 8, 2021

@masahi , @codeislife99 : I mentioned this to one of the other PyTorch devs and he mentioned that there is a remove mutation pass in PyTorch that should take care of the cases where the PyTorch alias analysis can prove it is safe to remove it:
https://github.com/pytorch/pytorch/blob/aa18d174553e1c2ab4d9e09ae45c6ac323f04af5/torch/csrc/jit/passes/remove_mutation.h#L92
It doesn't do the slice bits, obviously, but I think running that before conversion might be an alternative that seems better than the implicit assumption it is safe.

@masahi
Copy link
Member

masahi commented Jan 9, 2021

Interesting, if that can convert unsqueeze_ to unsqueeze when Torch can prove it safe, that works great for us. @codeislife99 Do you want to try that?

@codeislife99
Copy link
Contributor Author

I am still waiting for the full scripted model from my customer. I will try the above remedy once that is ready and then I will make changes to this PR.

@codeislife99
Copy link
Contributor Author

Also thanks to @t-vi @masahi for all the detailed discussion and help. I really appreciate it

@codeislife99
Copy link
Contributor Author

codeislife99 commented Jan 13, 2021

I have removed copy_ from this PR and added a few ops after unsqueeze in the test to make sure it works.

@codeislife99 codeislife99 changed the title Adding aten::unsqueeze_ and aten::copy_ ops to PT Frontend Adding aten::unsqueeze_ to PT Frontend Jan 13, 2021
@codeislife99
Copy link
Contributor Author

The inplace op recommendation using torch._C._jit_pass_remove_mutation(graph) didn't work. So I resorted back to the previous option.

@masahi
Copy link
Member

masahi commented Jan 13, 2021

Later maybe we should look into more detail why remove mutation pass is not working.

@masahi masahi merged commit 259652b into apache:main Jan 13, 2021
@masahi
Copy link
Member

masahi commented Jan 13, 2021

Thanks @codeislife99 @t-vi

@t-vi
Copy link
Contributor

t-vi commented Jan 13, 2021

FWIW, I don't think the test is really effective in ensuring that it's good.

@torch.jit.script
def fn(x):
  _ = x.unsqueeze_(2)
  y = x *2
  return y
m,p = tvm.relay.frontend.from_pytorch(fn, [('input', [5, 5])])
m2 = tvm.relay.transform.InferType()(m)
print(m2['main'].body.checked_type)
print(fn(torch.randn(5,5)).shape)

gives

Tensor[(5, 5), float32]
torch.Size([5, 5, 1])

@masahi
Copy link
Member

masahi commented Jan 14, 2021

@codeislife99 can you follow up?

masahi pushed a commit to masahi/tvm that referenced this pull request Jan 14, 2021
* Added Ops

* Regular

* Remove copy

* Remove copy

* Tests

* Black

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
@codeislife99
Copy link
Contributor Author

@t-vi WHen I try to use your example as a test-case , it works. I am trying to understand what the underlying problem is ...

@t-vi
Copy link
Contributor

t-vi commented Jan 14, 2021

I don't know what the test case does, but I'm relatively sure it doesn't handle the update of the input by running the above example and getting a wrong shape in TVM.

masahi pushed a commit to masahi/tvm that referenced this pull request Jan 18, 2021
* Added Ops

* Regular

* Remove copy

* Remove copy

* Tests

* Black

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Jan 20, 2021
* Added Ops

* Regular

* Remove copy

* Remove copy

* Tests

* Black

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jan 21, 2021
* Added Ops

* Regular

* Remove copy

* Remove copy

* Tests

* Black

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
@codeislife99
Copy link
Contributor Author

    class Unsqueeze3(Module):
        def forward(self, *args):
            _ = args[0].unsqueeze_(2)
            y = args[0] * 2
            return y

This is the testcase that I added to the test_forward_unsqueeze() function and it seems to pass , and I checked the shape of y. It shows it to be [5, 5, 1].

@t-vi
Copy link
Contributor

t-vi commented Jan 21, 2021

The subtlety is that the JIT tracer actually resolves this during tracing, but it will be buggy if the inplace input being used shows in the graph, e.g. through scripted parts of the model.
I can't say I'm terribly happy with the outcome here.

@t-vi
Copy link
Contributor

t-vi commented Jan 21, 2021

To keep things transparent, I'm copying a terse bit of explanation from slack:

To see what goes wrong and why the test doesn't catch it, compare the graph for traced and scripted in the above function (or in the equivalent module). They crucially differ in the input to mul.

So for just unsqueeze_ you could get by with the same trick of just replacing the node you recorded for the input by the output.

To properly support inplace (say for copy_), you would need to keep track of aliases in the TVM PyTorch frontend. I think that it is possible, but you need to actually implement the bookkeeping it involves (i.e. keep track of which tensor is a view of what and which operations have been carried out on them, then when the tensor is accessed, you need to check whether you need to "build" the inplace-modified tensor). It is a bit of work, but it certainly is possible if you're determined.

electriclilies pushed a commit to electriclilies/tvm that referenced this pull request Feb 18, 2021
* Added Ops

* Regular

* Remove copy

* Remove copy

* Tests

* Black

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants