Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

failed to legalize operation 'stream.cmd.dispatch' that was explicitly marked illegal #18631

Open
pdhirajkumarprasad opened this issue Sep 30, 2024 · 12 comments
Assignees
Labels
bug 🐞 Something isn't working

Comments

@pdhirajkumarprasad
Copy link

What happened?

For the given IR

#map = affine_map<() -> ()>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> ()>
#map3 = affine_map<()[s0] -> (s0 floordiv 12)>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?xi64>, %arg1: tensor<?x?xi64>, %arg2: tensor<?x?xi1>, %arg3: tensor<?x12x64x?xf32>, %arg4: tensor<?x?x768xf32>, %arg5: tensor<4xi64>, %arg6: tensor<?x?x12x64xf32>) -> tensor<?x12x?x?xf32> {
    %cst = arith.constant dense<8.000000e+00> : tensor<f32>
    %c1_i64 = arith.constant 1 : i64
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c12 = arith.constant 12 : index
    %c3 = arith.constant 3 : index
    %cst_0 = arith.constant 0.000000e+00 : f32
    %c768_i64 = arith.constant 768 : i64
    %cst_1 = arith.constant dense<-3.40282347E+38> : tensor<f32>
    %cst_2 = arith.constant dense<768> : tensor<i64>
    %extracted_slice = tensor.extract_slice %arg5[0] [1] [1] : tensor<4xi64> to tensor<1xi64>
    %extracted = tensor.extract %extracted_slice[%c0] : tensor<1xi64>
    %0 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim = tensor.dim %arg4, %c0 : tensor<?x?x768xf32>
    %1 = arith.index_cast %dim : index to i64
    %2 = tensor.empty() : tensor<i1>
    %3 = linalg.fill ins(%0 : i1) outs(%2 : tensor<i1>) -> tensor<i1>
    %4 = tensor.empty() : tensor<i64>
    %5 = linalg.fill ins(%1 : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %6 = linalg.fill ins(%extracted : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %7 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%3, %5, %6 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%4 : tensor<i64>) {
    ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
      %108 = arith.select %in, %in_26, %in_27 : i64
      linalg.yield %108 : i64
    } -> tensor<i64>
    %extracted_3 = tensor.extract %7[] : tensor<i64>
    %extracted_slice_4 = tensor.extract_slice %arg5[1] [1] [1] : tensor<4xi64> to tensor<1xi64>
    %extracted_5 = tensor.extract %extracted_slice_4[%c0] : tensor<1xi64>
    %8 = arith.cmpi eq, %extracted_5, %c0_i64 : i64
    %dim_6 = tensor.dim %arg4, %c1 : tensor<?x?x768xf32>
    %9 = arith.index_cast %dim_6 : index to i64
    %10 = linalg.fill ins(%8 : i1) outs(%2 : tensor<i1>) -> tensor<i1>
    %11 = linalg.fill ins(%9 : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %12 = linalg.fill ins(%extracted_5 : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %13 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%10, %11, %12 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%4 : tensor<i64>) {
    ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
      %108 = arith.select %in, %in_26, %in_27 : i64
      linalg.yield %108 : i64
    } -> tensor<i64>
    %extracted_7 = tensor.extract %13[] : tensor<i64>
    %extracted_slice_8 = tensor.extract_slice %arg5[2] [1] [1] : tensor<4xi64> to tensor<1xi64>
    %extracted_9 = tensor.extract %extracted_slice_8[%c0] : tensor<1xi64>
    %14 = arith.cmpi eq, %extracted_9, %c0_i64 : i64
    %15 = linalg.fill ins(%14 : i1) outs(%2 : tensor<i1>) -> tensor<i1>
    %16 = linalg.fill ins(%extracted_9 : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %17 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%15, %cst_2, %16 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%4 : tensor<i64>) {
    ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
      %108 = arith.select %in, %in_26, %in_27 : i64
      linalg.yield %108 : i64
    } -> tensor<i64>
    %extracted_10 = tensor.extract %17[] : tensor<i64>
    %extracted_slice_11 = tensor.extract_slice %arg5[3] [1] [1] : tensor<4xi64> to tensor<1xi64>
    %extracted_12 = tensor.extract %extracted_slice_11[%c0] : tensor<1xi64>
    %18 = arith.cmpi slt, %extracted_3, %c0_i64 : i64
    %19 = arith.select %18, %c1_i64, %extracted_3 : i64
    %20 = arith.extui %18 : i1 to i64
    %21 = arith.muli %19, %extracted_7 : i64
    %22 = arith.addi %20, %c1_i64 : i64
    %23 = arith.cmpi slt, %extracted_7, %c0_i64 : i64
    %24 = arith.select %23, %19, %21 : i64
    %25 = arith.select %23, %22, %20 : i64
    %26 = arith.muli %24, %extracted_10 : i64
    %27 = arith.addi %25, %c1_i64 : i64
    %28 = arith.cmpi slt, %extracted_10, %c0_i64 : i64
    %29 = arith.select %28, %24, %26 : i64
    %30 = arith.select %28, %27, %25 : i64
    %31 = arith.muli %29, %extracted_12 : i64
    %32 = arith.addi %30, %c1_i64 : i64
    %33 = arith.cmpi slt, %extracted_12, %c0_i64 : i64
    %34 = arith.select %33, %29, %31 : i64
    %35 = arith.select %33, %32, %30 : i64
    %36 = arith.cmpi sle, %35, %c1_i64 : i64
    cf.assert %36, "must have at most one inferred (negative) dimension"
    %37 = arith.muli %1, %9 : i64
    %38 = arith.muli %37, %c768_i64 : i64
    %39 = arith.divsi %38, %34 : i64
    %40 = arith.select %18, %39, %extracted_3 : i64
    %41 = arith.select %23, %39, %extracted_7 : i64
    %42 = arith.select %28, %39, %extracted_10 : i64
    %43 = arith.select %33, %39, %extracted_12 : i64
    %from_elements = tensor.from_elements %40, %41, %42, %43 : tensor<4xi64>
    %reshape = tensor.reshape %arg4(%from_elements) : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
    %44 = arith.index_cast %40 : i64 to index
    %45 = arith.index_cast %41 : i64 to index
    %46 = tensor.empty(%44, %45) : tensor<?x12x?x64xf32>
    %transposed = linalg.transpose ins(%reshape : tensor<?x?x12x64xf32>) outs(%46 : tensor<?x12x?x64xf32>) permutation = [0, 2, 1, 3] 
    %47 = linalg.generic {indexing_maps = [#map1, #map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%transposed, %cst : tensor<?x12x?x64xf32>, tensor<f32>) outs(%46 : tensor<?x12x?x64xf32>) {
    ^bb0(%in: f32, %in_26: f32, %out: f32):
      %108 = arith.divf %in, %in_26 : f32
      linalg.yield %108 : f32
    } -> tensor<?x12x?x64xf32>
    %dim_13 = tensor.dim %arg3, %c0 : tensor<?x12x64x?xf32>
    %48 = arith.maxui %44, %dim_13 : index
    %dim_14 = tensor.dim %arg3, %c3 : tensor<?x12x64x?xf32>
    %collapsed = tensor.collapse_shape %47 [[0, 1], [2], [3]] : tensor<?x12x?x64xf32> into tensor<?x?x64xf32>
    %collapsed_15 = tensor.collapse_shape %arg3 [[0, 1], [2], [3]] : tensor<?x12x64x?xf32> into tensor<?x64x?xf32>
    %49 = arith.muli %48, %c12 : index
    %50 = tensor.empty(%49, %45, %dim_14) : tensor<?x?x?xf32>
    %51 = linalg.fill ins(%cst_0 : f32) outs(%50 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    %52 = linalg.batch_matmul ins(%collapsed, %collapsed_15 : tensor<?x?x64xf32>, tensor<?x64x?xf32>) outs(%51 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    %53 = arith.divui %49, %c12 : index
    %expanded = tensor.expand_shape %52 [[0, 1], [2], [3]] output_shape [%53, 12, %45, %dim_14] : tensor<?x?x?xf32> into tensor<?x12x?x?xf32>
    %dim_16 = tensor.dim %arg2, %c0 : tensor<?x?xi1>
    %54 = arith.index_cast %dim_16 : index to i64
    %55 = linalg.fill ins(%54 : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %56 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%3, %55, %6 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%4 : tensor<i64>) {
    ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
      %108 = arith.select %in, %in_26, %in_27 : i64
      linalg.yield %108 : i64
    } -> tensor<i64>
    %extracted_17 = tensor.extract %56[] : tensor<i64>
    %dim_18 = tensor.dim %arg2, %c1 : tensor<?x?xi1>
    %57 = arith.index_cast %dim_18 : index to i64
    %58 = linalg.fill ins(%57 : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %59 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%10, %58, %12 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%4 : tensor<i64>) {
    ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
      %108 = arith.select %in, %in_26, %in_27 : i64
      linalg.yield %108 : i64
    } -> tensor<i64>
    %extracted_19 = tensor.extract %59[] : tensor<i64>
    %60 = arith.cmpi slt, %extracted_17, %c0_i64 : i64
    %61 = arith.select %60, %c1_i64, %extracted_17 : i64
    %62 = arith.extui %60 : i1 to i64
    %63 = arith.muli %61, %extracted_19 : i64
    %64 = arith.addi %62, %c1_i64 : i64
    %65 = arith.cmpi slt, %extracted_19, %c0_i64 : i64
    %66 = arith.select %65, %61, %63 : i64
    %67 = arith.select %65, %64, %62 : i64
    %68 = arith.muli %66, %extracted_9 : i64
    %69 = arith.addi %67, %c1_i64 : i64
    %70 = arith.cmpi slt, %extracted_9, %c0_i64 : i64
    %71 = arith.select %70, %66, %68 : i64
    %72 = arith.select %70, %69, %67 : i64
    %73 = arith.muli %71, %extracted_12 : i64
    %74 = arith.addi %72, %c1_i64 : i64
    %75 = arith.select %33, %71, %73 : i64
    %76 = arith.select %33, %74, %72 : i64
    %77 = arith.cmpi sle, %76, %c1_i64 : i64
    cf.assert %77, "must have at most one inferred (negative) dimension"
    %78 = arith.muli %54, %57 : i64
    %79 = arith.divsi %78, %75 : i64
    %80 = arith.select %60, %79, %extracted_17 : i64
    %81 = arith.select %65, %79, %extracted_19 : i64
    %82 = arith.select %70, %79, %extracted_9 : i64
    %83 = arith.select %33, %79, %extracted_12 : i64
    %from_elements_20 = tensor.from_elements %80, %81, %82, %83 : tensor<4xi64>
    %reshape_21 = tensor.reshape %arg2(%from_elements_20) : (tensor<?x?xi1>, tensor<4xi64>) -> tensor<?x1x1x?xi1>
    %84 = affine.apply #map3()[%49]
    %85 = arith.index_cast %84 : index to i64
    %86 = arith.index_cast %dim_14 : index to i64
    %87 = tensor.empty() : tensor<1xi64>
    %collapsed_22 = tensor.collapse_shape %87 [] : tensor<1xi64> into tensor<i64>
    %88 = linalg.fill ins(%85 : i64) outs(%collapsed_22 : tensor<i64>) -> tensor<i64>
    %extracted_23 = tensor.extract %88[] : tensor<i64>
    %89 = arith.maxsi %extracted_23, %80 : i64
    %90 = linalg.fill ins(%41 : i64) outs(%collapsed_22 : tensor<i64>) -> tensor<i64>
    %extracted_24 = tensor.extract %90[] : tensor<i64>
    %91 = arith.maxsi %extracted_24, %c1_i64 : i64
    %92 = linalg.fill ins(%86 : i64) outs(%collapsed_22 : tensor<i64>) -> tensor<i64>
    %extracted_25 = tensor.extract %92[] : tensor<i64>
    %93 = arith.maxsi %extracted_25, %83 : i64
    %94 = arith.index_cast %89 : i64 to index
    %95 = arith.cmpi sge, %89, %c0_i64 : i64
    cf.assert %95, "unimplemented: dynamic negative broadcast sizes"
    %96 = arith.cmpi slt, %91, %c0_i64 : i64
    %97 = arith.index_cast %91 : i64 to index
    %98 = arith.select %96, %c1, %97 : index
    %99 = arith.index_cast %93 : i64 to index
    %100 = arith.cmpi sge, %93, %c0_i64 : i64
    cf.assert %100, "unimplemented: dynamic negative broadcast sizes"
    %101 = tensor.empty(%94, %98, %99) : tensor<?x12x?x?xi1>
    %102 = linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%101 : tensor<?x12x?x?xi1>) {
    ^bb0(%out: i1):
      %108 = linalg.index 0 : index
      %109 = linalg.index 3 : index
      %110 = arith.index_cast %80 : i64 to index
      %111 = arith.cmpi eq, %110, %c1 : index
      %112 = arith.select %111, %c0, %108 : index
      %113 = arith.index_cast %83 : i64 to index
      %114 = arith.cmpi eq, %113, %c1 : index
      %115 = arith.select %114, %c0, %109 : index
      %extracted_26 = tensor.extract %reshape_21[%112, %c0, %c0, %115] : tensor<?x1x1x?xi1>
      linalg.yield %extracted_26 : i1
    } -> tensor<?x12x?x?xi1>
    %103 = arith.cmpi eq, %94, %84 : index
    cf.assert %103, "mismatched size for broadcast"
    %104 = arith.cmpi eq, %98, %45 : index
    cf.assert %104, "mismatched size for broadcast"
    %105 = arith.cmpi eq, %99, %dim_14 : index
    cf.assert %105, "mismatched size for broadcast"
    %106 = tensor.empty(%94, %98, %99) : tensor<?x12x?x?xf32>
    %107 = linalg.generic {indexing_maps = [#map1, #map2, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%102, %cst_1, %expanded : tensor<?x12x?x?xi1>, tensor<f32>, tensor<?x12x?x?xf32>) outs(%106 : tensor<?x12x?x?xf32>) {
    ^bb0(%in: i1, %in_26: f32, %in_27: f32, %out: f32):
      %108 = arith.select %in, %in_26, %in_27 : f32
      linalg.yield %108 : f32
    } -> tensor<?x12x?x?xf32>
    return %107 : tensor<?x12x?x?xf32>
  }
}

getting error as

tt.mlir:200:12: error: failed to legalize operation 'stream.cmd.dispatch' that was explicitly marked illegal
    %107 = linalg.generic {indexing_maps = [#map1, #map2, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%102, %cst_1, %expanded : tensor<?x12x?x?xi1>, tensor<f32>, tensor<?x12x?x?xf32>) outs(%106 : tensor<?x12x?x?xf32>) {

this linalg IR was generated with following ONNX IR

module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],i1>, %arg3: !torch.vtensor<[?,12,64,?],f32>, %arg4: !torch.vtensor<[?,?,768],f32>, %arg5: !torch.vtensor<[4],si64> , %arg6: !torch.vtensor<[?,?,12,64],f32>) -> !torch.vtensor<[?,12,?,?],f32>  attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
    %1 = torch.operator "onnx.Reshape"(%arg4, %arg5) : (!torch.vtensor<[?,?,768],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,12,64],f32> 
    %2 = torch.operator "onnx.Transpose"(%1) {torch.onnx.perm = [0 : si64, 2 : si64, 1 : si64, 3 : si64]} : (!torch.vtensor<[?,?,12,64],f32>) -> !torch.vtensor<[?,12,?,64],f32> 
    %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__18> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %4 = torch.operator "onnx.Div"(%2, %3) : (!torch.vtensor<[?,12,?,64],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,12,?,64],f32> 
    %5 = torch.operator "onnx.MatMul"(%4, %arg3) : (!torch.vtensor<[?,12,?,64],f32>, !torch.vtensor<[?,12,64,?],f32>) -> !torch.vtensor<[?,12,?,?],f32> 
    %6 = torch.operator "onnx.Reshape"(%arg2, %arg5) : (!torch.vtensor<[?,?],i1>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,1,1,?],i1> 
    %7 = torch.operator "onnx.Shape"(%5) : (!torch.vtensor<[?,12,?,?],f32>) -> !torch.vtensor<[4],si64> 
    %8 = torch.operator "onnx.Expand"(%6, %7) : (!torch.vtensor<[?,1,1,?],i1>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,12,?,?],i1> 
    %9 = torch.operator "onnx.Cast"(%8) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[?,12,?,?],i1>) -> !torch.vtensor<[?,12,?,?],i1> 
    %10 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__22> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %11 = torch.operator "onnx.Where"(%9, %10, %5) : (!torch.vtensor<[?,12,?,?],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[?,12,?,?],f32>) -> !torch.vtensor<[?,12,?,?],f32> 
    return %11 : !torch.vtensor<[?,12,?,?],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      __18: "0x0800000000000041",
      __22: "0x08000000FFFF7FFF"
    }
  }
#-}

If I pass the ONNX IR directly, IREE is compiling fine, but when passing the linalg IR (after lowering it through torch-mlir), it's failing with the above error.

Steps to reproduce your issue

command:

iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false temp.mlir

What component(s) does this issue relate to?

Compiler

Version information

No response

Additional context

No response

@pdhirajkumarprasad pdhirajkumarprasad added the bug 🐞 Something isn't working label Sep 30, 2024
@pashu123
Copy link
Contributor

pashu123 commented Sep 30, 2024

The problematic part is

 %102 = linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%101 : tensor<?x12x?x?xi1>) {
    ^bb0(%out: i1):
      %108 = linalg.index 0 : index
      %109 = linalg.index 3 : index
      %110 = arith.index_cast %80 : i64 to index
      %111 = arith.cmpi eq, %110, %c1 : index
      %112 = arith.select %111, %c0, %108 : index
      %113 = arith.index_cast %83 : i64 to index
      %114 = arith.cmpi eq, %113, %c1 : index
      %115 = arith.select %114, %c0, %109 : index
      %extracted_26 = tensor.extract %reshape_21[%112, %c0, %c0, %115] : tensor<?x1x1x?xi1>
      linalg.yield %extracted_26 : i1
    } -> tensor<?x12x?x?xi1> 

This looks like a broadcast to me. We can directly pass %reshape_21 op to the generic rather than accessing it from outside.
Could someone look at the onnx.Expand op lowering?

@zjgarvey zjgarvey self-assigned this Sep 30, 2024
@pdhirajkumarprasad
Copy link
Author

complete list of models failing due to this

dispatch_list.txt

@zjgarvey
Copy link
Contributor

I think I've figured out a problematic component of the Expand lowering. Going to make sure this is an appropriate fix and will post updates.

@zjgarvey
Copy link
Contributor

So removing some logic to take the max between the provided dim size and the input dim for broadcasting seems to unblock this issue, however, this logic is necessary for producing correct results in other cases.

onnx.Expand allows having provided shapes less than the input shape at a given dim (in which case it doesn't broadcast there, and would cause an issue if we didn't take the max).

Going to keep digging a bit.

@zjgarvey
Copy link
Contributor

zjgarvey commented Oct 2, 2024

I got it to work with some extra shape help in torch-mlir.

Will update soon.

@zjgarvey
Copy link
Contributor

zjgarvey commented Oct 2, 2024

Small reproducers are passing, but full models still fail on a further node.

@nirvedhmeshram
Copy link
Contributor

nirvedhmeshram commented Oct 2, 2024

If this can be fixed in the front-end then that's great, but I think we should be able to support this in the compiler. In that regard, I think the problem starts in iree-codegen-tile-and-distribute-to-workgroups. I am sanity-checking the input we have at this point; the output we get is not something that can be legalized. Both are provided in this gist.

Here is the command I used

iree-opt tile_and_distribute_repro.mlir \
-pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-codegen-tile-and-distribute-to-workgroups, canonicalize)), cse)))' \
&> output.mlir

CC @MaheshRavishankar to take a look at the IR as well.

@zjgarvey
Copy link
Contributor

zjgarvey commented Oct 2, 2024

llvm/torch-mlir#3756 addresses the compile failure by simplifying the IR getting generated for broadcast substantially, but slightly reduces some rare case coverage for onnx.Expand.

If you guys think that the old IR should just be supported in IREE anyway, then it might not be worth landing that PR.

@zjgarvey
Copy link
Contributor

zjgarvey commented Oct 2, 2024

With this PR to clean up some of the gross shape computations that aren't simplifying at the torch level, llvm/torch-mlir#3757, I was able to compile the failing models when returning on the first "Where" node, but interestingly was then failing to compile again when returning on the next, nearly identical "Where" node.

I think I've looked into ways to try and simplify the broadcast shapes as best as possible with the two PR's I've posted so far. I'm not sure what else we could do from the front-end.

@MaheshRavishankar
Copy link
Contributor

@zjgarvey you are probably already looking at it, but this kind of IR is really strange

 %dim = tensor.dim %arg4, %c0 : tensor<?x?x768xf32>
    %1 = arith.index_cast %dim : index to i64
    %2 = tensor.empty() : tensor<i1>
    %3 = linalg.fill ins(%0 : i1) outs(%2 : tensor<i1>) -> tensor<i1>
    %4 = tensor.empty() : tensor<i64>
    %5 = linalg.fill ins(%1 : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %6 = linalg.fill ins(%extracted : i64) outs(%4 : tensor<i64>) -> tensor<i64>
    %7 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%3, %5, %6 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%4 : tensor<i64>) {
    ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
      %108 = arith.select %in, %in_26, %in_27 : i64
      linalg.yield %108 : i64
    } -> tensor<i64>
    %extracted_3 = tensor.extract %7[] : tensor<i64>

This is taking a dim of a tensor, creating a new tensor and inserting into it, then performing a linalg.generic operation to do a select? and then extracting it out..

It's basically shape computation that is artificially put into tensor math. IREE tries to do all tensor math on the device. So all of this computation — which is really shape computation that should be done on the host — is being transferred to the device, and then things go haywire because it artificially looks like an indirect-dispatch problem where the shape computation is dependent on previous computation on the device. The easiest fix is that the front end needs to not do shape computation as tensor math.

zjgarvey added a commit to llvm/torch-mlir that referenced this issue Oct 4, 2024
Addresses ~200 onnx model compile failures in
<https://github.com/nod-ai/SHARK-TestSuite> related to
<iree-org/iree#18631>.

This change simplifies the result of the generated broadcast op
substantially, but reduces the case coverage slightly.

The case which will become unsupported: 
- trying to actually broadcast a dynamic dim that is secretly 1. 

When does this case appear in practical scenarios?
- for a model where onnx shape inference cannot figure out that a dim
should be 1.

Why do I think we should not support this case for now?
1. For all models with dynamic dim expand ops, the previous path
uniformly generates uglier linalg IR (making it harder for IREE to fuse
properly with other ops).
2. For models failing shape inference catastrophically enough to fail
to see a dim is statically 1, we can try to apply constant folding in
the onnx model before importing.

Leaving this as a draft PR, since it may be more appropriate to fix the
compilation failure in IREE rather than torch-mlir.

### Example of broadcast required in previous path:

```mlir
    %300 = linalg.generic {indexing_maps = [#map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%299 : tensor<?x12x?x?xi1>) {
    ^bb0(%out: i1):
      %306 = linalg.index 0 : index
      %307 = linalg.index 3 : index
      %308 = arith.index_cast %285 : i64 to index
      %309 = arith.cmpi eq, %308, %c1 : index
      %310 = arith.select %309, %c0, %306 : index
      %311 = arith.index_cast %286 : i64 to index
      %312 = arith.cmpi eq, %311, %c1 : index
      %313 = arith.select %312, %c0, %307 : index
      %extracted_79 = tensor.extract %reshape_78[%310, %c0, %c0, %313] : tensor<?x1x1x?xi1>
      linalg.yield %extracted_79 : i1
    } -> tensor<?x12x?x?xi1>
```

### Example of broadcast with simplified shape list:

```mlir
    %409 = linalg.generic {indexing_maps = [#map15, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%reshape_135 : tensor<?x1x1x?xi1>) outs(%408 : tensor<?x12x?x?xi1>) {
    ^bb0(%in: i1, %out: i1):
      linalg.yield %in : i1
    } -> tensor<?x12x?x?xi1>
```
zjgarvey added a commit to llvm/torch-mlir that referenced this issue Oct 4, 2024
… generic ops (#3762)

This is motivated by the fact that shapes are stored as tensors in ONNX,
and IREE tries to perform tensor arithmetic on the device. This causes
unnecessary dispatches, and makes it harder for the compiler to reason
about shapes.

Here is a small snippet of torch-IR that is typical seen coming from
ONNX models:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,768],f32>, %arg1: !torch.vtensor<[?,?,768],f32>) -> !torch.vtensor<[],si64> {
    %int0 = torch.constant.int 0
    %0 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
    %1 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,768],f32> -> !torch.vtensor<[3],si64>
    %2 = torch.aten.index_select %1, %int0, %0 : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
    %3 = torch.aten.squeeze.dim %2, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %4 = torch.aten.item %3 : !torch.vtensor<[],si64> -> !torch.int
    %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool
    %6 = torch.aten.Int.bool %5 : !torch.bool -> !torch.int
    %7 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,768],f32>, !torch.int -> !torch.int
    %8 = torch.prim.NumToTensor.Scalar %6 : !torch.int -> !torch.vtensor<[],i1>
    %9 = torch.prim.NumToTensor.Scalar %7 : !torch.int -> !torch.vtensor<[],si64>
    %10 = torch.prim.NumToTensor.Scalar %4 : !torch.int -> !torch.vtensor<[],si64>
    %11 = torch.aten.where.self %8, %9, %10 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
    return %11 : !torch.vtensor<[],si64>
  }
}
```

Without the change in this PR, the result would be:

```mlir
#map = affine_map<() -> ()>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = tensor.empty() : tensor<i1>
    %6 = linalg.fill ins(%3 : i1) outs(%5 : tensor<i1>) -> tensor<i1>
    %7 = tensor.empty() : tensor<i64>
    %8 = linalg.fill ins(%4 : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %9 = linalg.fill ins(%extracted : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %10 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%6, %8, %9 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%7 : tensor<i64>) {
    ^bb0(%in: i1, %in_1: i64, %in_2: i64, %out: i64):
      %11 = arith.select %in, %in_1, %in_2 : i64
      linalg.yield %11 : i64
    } -> tensor<i64>
    return %10 : tensor<i64>
  }
}
```

With the change in this PR, we would instead get:

```mlir
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = arith.select %3, %4, %extracted : i64
    %6 = tensor.empty() : tensor<i64>
    %7 = linalg.fill ins(%5 : i64) outs(%6 : tensor<i64>) -> tensor<i64>
    return %7 : tensor<i64>
  }
}
```

Some related issues for context:
1. <iree-org/iree#18677>
2. <iree-org/iree#18631>
@nirvedhmeshram
Copy link
Contributor

We need to generate new linalg IR for this issue with llvm/torch-mlir#3762, since the ONNX IR has already been working when run directly through IREE.

@zjgarvey
Copy link
Contributor

zjgarvey commented Oct 4, 2024

This is also resolved with the change llvm/torch-mlir#3756 so it might not reproduce an issue in IREE.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug 🐞 Something isn't working
Projects
Status: No status
Development

No branches or pull requests

6 participants