Support all cases of lowering of ConvertAtenViewOp #2567

Open
newling opened this issue Nov 10, 2023 · 9 comments

@newling
Collaborator

newling commented Nov 10, 2023

There are a few cases where the aten.view op fails to lower to tensor.expand_shape and tensor.collapse_shape.

There are 2 cases which need investigating:

Case 1

Where it is not possible to lower to ONLY tensor.expand_shape and tensor.collapse_shape, because of the constraint that you cannot expand one dimension to multiple dynamic dimensions. If you try this, you will hit:

https://github.com/llvm/llvm-project/blob/2bac7201018dcab549895c30c0eb26bee45d842f/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp#L239

For this case we will need to include a tensor.reshape op in the mix -- though only for the contiguous groups of dimensions which need it (regions away from the shape edges with multiple dynamic dims); we can still use expand_shape and collapse_shape for the dimension groups which have at most 1 dynamic dimension.
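
As a rough, hand-written sketch (not output of the pass; the function and operand names are made up for illustration): splitting one source dimension into two dynamic result dimensions is exactly what tensor.expand_shape rejects, whereas tensor.reshape with a runtime shape operand can express it:

func.func @split_into_two_dynamic_dims(%src: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?x?xf32> {
  // Not expressible with tensor.expand_shape alone: the single source dim
  // would have to expand into two dynamic dims.
  %0 = tensor.reshape %src(%shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}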

Case 2

Where it is possible to lower to only tensor.expand_shape and tensor.collapse_shape, but the case isn't currently supported. There aren't many of these, but 2 examples (illustrated below) are:

  • Cases where the input or output is rank-0.
  • Cases where multiple -1's in the output shape can be inferred to be +1.
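
For instance, the rank-0 case in the first bullet maps onto a collapse with an empty reassociation list (a minimal hand-written sketch, not taken from the pass):

func.func @view_to_rank0(%arg0: tensor<1x1xf32>) -> tensor<f32> {
  // Collapse all unit dims away; the reassociation list is empty.
  %0 = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
  return %0 : tensor<f32>
}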

There is inaccurate documentation in the pass: it is claimed in 2 places that [2,3] -> [3,2] is not supported, when in fact it is supported (there is a lit test for this exact case).

Related to this pass, I think that the core logic can be significantly simplified while still supporting all the cases it currently supports (I'm still working through this, so please don't quote me on it :-) -- happy to put up my WIP PR).

So I'd like to propose the following 2 steps:

  1. Simplify the code, while still supporting all current cases (will add additional lit tests)
  2. Implement case 1
@newling
Collaborator Author

newling commented Nov 10, 2023

I am currently trying to put this in the broader context of
#2515 and
llvm/llvm-project#69267

@dan-garvey
Collaborator

So with the downstream changes you're expecting to be able to lower ambiguous reshapes and let the logic there handle it?

@newling
Collaborator Author

newling commented Nov 13, 2023

That wasn't my thinking initially, although that may be possible (not sure though). There is a tensor.reshape op that I was thinking of using as the fallback.

@newling
Collaborator Author

newling commented Nov 14, 2023

@qedawkins and @ramiro050 -- can I get a second opinion on this, please? It relates to the lowering of aten.view to linalg. I might be misunderstanding something, but the generated IR doesn't look correct to me.

Input:

func.func @torch.aten.view$castingView (%arg0 : !torch.vtensor<[?,?,?], f32>) -> !torch.vtensor<[3,4,5], f32> {
  %int3 = torch.constant.int 3
  %int4 = torch.constant.int 4
  %int5 = torch.constant.int 5
  %0 = torch.prim.ListConstruct %int3, %int4, %int5 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?,?,?], f32>, !torch.list<int> -> !torch.vtensor<[3,4,5], f32>
  return %1 : !torch.vtensor<[3,4,5], f32>
}

Lowering to linalg (on main branch) with

./torch-mlir-opt --convert-torch-to-linalg --canonicalize above.mlir

produces the following IR:

func.func @torch.aten.view$castingView(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[3,4,5],f32> {
  %c4_i64 = arith.constant 4 : i64
  %c3_i64 = arith.constant 3 : i64
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
  %dim = tensor.dim %0, %c0 : tensor<?x?x?xf32>
  %dim_0 = tensor.dim %0, %c1 : tensor<?x?x?xf32>
  %1 = arith.index_cast %dim : index to i64
  %2 = arith.cmpi eq, %1, %c3_i64 : i64
  cf.assert %2, "mismatching contracting dimension"
  %3 = arith.index_cast %dim_0 : index to i64
  %4 = arith.cmpi eq, %3, %c4_i64 : i64
  cf.assert %4, "mismatching contracting dimension"
  %cast = tensor.cast %0 : tensor<?x?x?xf32> to tensor<3x4x5xf32>
  %5 = torch_c.from_builtin_tensor %cast : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32>
  return %5 : !torch.vtensor<[3,4,5],f32>
}

This doesn't look right to me, because I would expect an input of shape [1,1,120] in the (first) torch dialect function to be valid, yet it seems like it results in an assertion failure in the lowered linalg (second) function. Am I missing something, or is this a potential bug?

This came up when running #2573 through CI.

@qedawkins
Collaborator

This doesn't look right to me, because I would expect an input of shape [1,1,120] in the (first) torch dialect function to be valid, yet it seems like it results in an assertion failure in the lowered linalg (second) function. Am I missing something, or is this a potential bug?

[1, 1, 60] -> [3, 4, 5] sounds valid to me, that said, the generated IR certainly looks wrong. This might be a side effect of some optimistic assumptions made in the view op lowering to handle "real world" dynamic cases. The only way to handle that view is with a full collapse (+ assert dynamic size == 60) into a full expand to the desired shape.
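
Roughly (a hand-written sketch of that collapse + assert + expand pattern, not the pass's actual output; expand_shape syntax may differ slightly between MLIR versions):

func.func @full_collapse_then_expand(%arg0: tensor<?x?x?xf32>) -> tensor<3x4x5xf32> {
  %c0 = arith.constant 0 : index
  %c60 = arith.constant 60 : index
  // Collapse everything into one dynamic dimension.
  %flat = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
  // Assert the runtime element count matches 3 * 4 * 5 = 60.
  %numel = tensor.dim %flat, %c0 : tensor<?xf32>
  %ok = arith.cmpi eq, %numel, %c60 : index
  cf.assert %ok, "view input must have 60 elements"
  // Expand back out to the fully static target shape.
  %static = tensor.cast %flat : tensor<?xf32> to tensor<60xf32>
  %expanded = tensor.expand_shape %static [[0, 1, 2]] : tensor<60xf32> into tensor<3x4x5xf32>
  return %expanded : tensor<3x4x5xf32>
}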

@qedawkins
Collaborator

Cases where multiple -1's in the output shape can be inferred to be +1.

Also I don't quite follow what this means. My understanding of view/reshape required a single inferred (-1) dimension: https://pytorch.org/docs/stable/generated/torch.reshape.html#torch.reshape

@newling
Collaborator Author

newling commented Nov 14, 2023

[1, 1, 60] -> [3, 4, 5] sounds valid to me, that said, the generated IR certainly looks wrong. This might be a side effect of some optimistic assumptions made in the view op lowering to handle "real world" dynamic cases. The only way to handle that view is with a full collapse (+ assert dynamic size == 60) into a full expand to the desired shape.

Thanks for verifying @qedawkins. I will try and make a fix for this. Ah yes, 3x4x5 = 60!

@newling
Collaborator Author

newling commented Nov 14, 2023

Cases where multiple -1's in the output shape can be inferred to be +1.

Also I don't quite follow what this means. My understanding of view/reshape required a single inferred (-1) dimension: https://pytorch.org/docs/stable/generated/torch.reshape.html#torch.reshape

This is at the MLIR level, where an argument to a function can have multiple dimensions of unknown size. I should have used 'kDynamic' instead of '-1'.
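
For example (a hypothetical torch-level case, hand-written): with a static [6] input, both kDynamic output dims below can only ever be 1, so the lowering could infer them as unit dims:

func.func @view_with_inferable_unit_dims(%arg0: !torch.vtensor<[6],f32>, %d0: !torch.int, %d1: !torch.int) -> !torch.vtensor<[?,?,6],f32> {
  %int6 = torch.constant.int 6
  %0 = torch.prim.ListConstruct %d0, %d1, %int6 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[6],f32>, !torch.list<int> -> !torch.vtensor<[?,?,6],f32>
  return %1 : !torch.vtensor<[?,?,6],f32>
}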

@dan-garvey
Collaborator

So there has been a lot of discussion around lowering view to reshape as a fallback, see: #2515
