Missing dimensionality information in torch #3651

Closed
MaheshRavishankar opened this issue Aug 19, 2024 · 7 comments

@MaheshRavishankar
Contributor

As a follow-up to iree-org/iree#18229, it seems that some dimension information is not being captured correctly in the IR, and recovering it later in the program is fairly involved. This is the IR after torch-finalizing-backend-type-conversion:

util.func public @torch_jit$async(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
  %c3 = arith.constant 3 : index
  %c2 = arith.constant 2 : index
  %c64_i64 = arith.constant 64 : i64
  %c12_i64 = arith.constant 12 : i64
  %c768_i64 = arith.constant 768 : i64
  %c512_i64 = arith.constant 512 : i64
  %c30522_i64 = arith.constant 30522 : i64
  %cst = arith.constant dense<9.99999996E-13> : tensor<f32>
  %cst_0 = arith.constant dense<2.000000e+00> : tensor<f32>
  %c2_i64 = arith.constant 2 : i64
  %cst_1 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___8> : tensor<1x512xi64>
  %c0_i64 = arith.constant 0 : i64
  %cst_2 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided__> : tensor<1x512xi64>
  %cst_3 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___1> : tensor<30522x768xf32>
  %cst_4 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___2> : tensor<512x768xf32>
  %cst_5 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___3> : tensor<2x768xf32>
  %cst_6 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___4> : tensor<768xf32>
  %cst_7 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___5> : tensor<768xf32>
  %cst_8 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___6> : tensor<768xf32>
  %cst_9 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___7> : tensor<768x768xf32>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c-1 = arith.constant -1 : index
  %c512 = arith.constant 512 : index
  %cst_10 = arith.constant 0.000000e+00 : f32
  %c1_i64 = arith.constant 1 : i64
  %cst_11 = arith.constant 7.680000e+02 : f32
  %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
  %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
  %2 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<?x?xi64>{%0, %1}
  %dim = tensor.dim %2, %c1 : tensor<?x?xi64>
  %3 = arith.cmpi slt, %dim, %c0 : index
  %4 = arith.addi %dim, %c512 : index
  %5 = arith.select %3, %4, %dim : index
  %6 = arith.cmpi slt, %5, %c0 : index
  %7 = arith.select %6, %c-1, %5 : index
  %8 = arith.cmpi sgt, %7, %c512 : index
  %9 = arith.select %8, %c512, %7 : index
  %10 = arith.cmpi slt, %9, %c0 : index
  %11 = arith.select %10, %c0, %9 : index
  %extracted_slice = tensor.extract_slice %cst_1[0, 0] [1, %11] [1, 1] : tensor<1x512xi64> to tensor<1x?xi64>
  %extracted_slice_12 = tensor.extract_slice %cst_2[0, 0] [1, %11] [1, 1] : tensor<1x512xi64> to tensor<1x?xi64>
  %dim_13 = tensor.dim %2, %c0 : tensor<?x?xi64>
  %12 = tensor.empty(%dim_13, %dim) : tensor<?x?xi1>
  %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<?x?xi64>) outs(%12 : tensor<?x?xi1>) {
  ^bb0(%in: i64, %out: i1):
    %63 = arith.cmpi slt, %in, %c0_i64 : i64
    linalg.yield %63 : i1
  } -> tensor<?x?xi1>
  %14 = tensor.empty(%dim_13, %dim) : tensor<?x?xi64>
  %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<?x?xi64>) outs(%14 : tensor<?x?xi64>) {
  ^bb0(%in: i64, %out: i64):
    %63 = arith.addi %in, %c30522_i64 : i64
    linalg.yield %63 : i64
  } -> tensor<?x?xi64>
  %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%13, %15, %2 : tensor<?x?xi1>, tensor<?x?xi64>, tensor<?x?xi64>) outs(%14 : tensor<?x?xi64>) {
  ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
    %63 = arith.select %in, %in_26, %in_27 : i64
    linalg.yield %63 : i64
  } -> tensor<?x?xi64>
  %17 = arith.index_cast %dim_13 : index to i64
  %18 = arith.index_cast %dim : index to i64
  %collapsed = tensor.collapse_shape %16 [[0, 1]] : tensor<?x?xi64> into tensor<?xi64>
  %19 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%dim_13, %dim]
  %20 = tensor.empty(%19) : tensor<?x768xf32>
  %21 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<?xi64>) outs(%20 : tensor<?x768xf32>) {
  ^bb0(%in: i64, %out: f32):
    %63 = arith.index_cast %in : i64 to index
    %64 = linalg.index 1 : index
    %extracted = tensor.extract %cst_3[%63, %64] : tensor<30522x768xf32>
    linalg.yield %extracted : f32
  } -> tensor<?x768xf32>
  %from_elements = tensor.from_elements %17, %18, %c768_i64 : tensor<3xi64>
  %reshape = tensor.reshape %21(%from_elements) : (tensor<?x768xf32>, tensor<3xi64>) -> tensor<?x?x768xf32>
  %22 = tensor.empty(%11) : tensor<1x?xi1>
  %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<1x?xi64>) outs(%22 : tensor<1x?xi1>) {
  ^bb0(%in: i64, %out: i1):
    %63 = arith.cmpi slt, %in, %c0_i64 : i64
    linalg.yield %63 : i1
  } -> tensor<1x?xi1>
  %24 = tensor.empty(%11) : tensor<1x?xi64>
  %25 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<1x?xi64>) outs(%24 : tensor<1x?xi64>) {
  ^bb0(%in: i64, %out: i64):
    %63 = arith.addi %in, %c2_i64 : i64
    linalg.yield %63 : i64
  } -> tensor<1x?xi64>
  %26 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%23, %25, %extracted_slice : tensor<1x?xi1>, tensor<1x?xi64>, tensor<1x?xi64>) outs(%24 : tensor<1x?xi64>) {
  ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
    %63 = arith.select %in, %in_26, %in_27 : i64
    linalg.yield %63 : i64
  } -> tensor<1x?xi64>
  %27 = arith.index_cast %11 : index to i64
  %collapsed_14 = tensor.collapse_shape %26 [[0, 1]] : tensor<1x?xi64> into tensor<?xi64>
  %28 = tensor.empty(%11) : tensor<?x768xf32>
  %29 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed_14 : tensor<?xi64>) outs(%28 : tensor<?x768xf32>) {
  ^bb0(%in: i64, %out: f32):
    %63 = arith.index_cast %in : i64 to index
    %64 = linalg.index 1 : index
    %extracted = tensor.extract %cst_5[%63, %64] : tensor<2x768xf32>
    linalg.yield %extracted : f32
  } -> tensor<?x768xf32>
  %from_elements_15 = tensor.from_elements %c1_i64, %27, %c768_i64 : tensor<3xi64>
  %reshape_16 = tensor.reshape %29(%from_elements_15) : (tensor<?x768xf32>, tensor<3xi64>) -> tensor<?x?x768xf32>
  %30 = arith.index_cast %17 : i64 to index
  %31 = arith.index_cast %18 : i64 to index
  %32 = tensor.empty(%30, %31) : tensor<?x?x768xf32>
  %33 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%reshape, %reshape_16 : tensor<?x?x768xf32>, tensor<?x?x768xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.addf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  %34 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice_12 : tensor<1x?xi64>) outs(%22 : tensor<1x?xi1>) {
  ^bb0(%in: i64, %out: i1):
    %63 = arith.cmpi slt, %in, %c0_i64 : i64
    linalg.yield %63 : i1
  } -> tensor<1x?xi1>
  %35 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice_12 : tensor<1x?xi64>) outs(%24 : tensor<1x?xi64>) {
  ^bb0(%in: i64, %out: i64):
    %63 = arith.addi %in, %c512_i64 : i64
    linalg.yield %63 : i64
  } -> tensor<1x?xi64>
  %36 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%34, %35, %extracted_slice_12 : tensor<1x?xi1>, tensor<1x?xi64>, tensor<1x?xi64>) outs(%24 : tensor<1x?xi64>) {
  ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
    %63 = arith.select %in, %in_26, %in_27 : i64
    linalg.yield %63 : i64
  } -> tensor<1x?xi64>
  %collapsed_17 = tensor.collapse_shape %36 [[0, 1]] : tensor<1x?xi64> into tensor<?xi64>
  %37 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed_17 : tensor<?xi64>) outs(%28 : tensor<?x768xf32>) {
  ^bb0(%in: i64, %out: f32):
    %63 = arith.index_cast %in : i64 to index
    %64 = linalg.index 1 : index
    %extracted = tensor.extract %cst_4[%63, %64] : tensor<512x768xf32>
    linalg.yield %extracted : f32
  } -> tensor<?x768xf32>
  %reshape_18 = tensor.reshape %37(%from_elements_15) : (tensor<?x768xf32>, tensor<3xi64>) -> tensor<?x?x768xf32>
  %38 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%33, %reshape_18 : tensor<?x?x768xf32>, tensor<?x?x768xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.addf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  %39 = tensor.empty(%30, %31) : tensor<?x?x1xf32>
  %40 = linalg.fill ins(%cst_10 : f32) outs(%39 : tensor<?x?x1xf32>) -> tensor<?x?x1xf32>
  %41 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%38 : tensor<?x?x768xf32>) outs(%40 : tensor<?x?x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    %63 = arith.addf %in, %out : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x1xf32>
  %42 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%41 : tensor<?x?x1xf32>) outs(%39 : tensor<?x?x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    %63 = arith.divf %in, %cst_11 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x1xf32>
  %43 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%38, %42 : tensor<?x?x768xf32>, tensor<?x?x1xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.subf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  %44 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%43, %cst_0 : tensor<?x?x768xf32>, tensor<f32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = math.powf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  %45 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%44 : tensor<?x?x768xf32>) outs(%40 : tensor<?x?x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    %63 = arith.addf %in, %out : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x1xf32>
  %46 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%45 : tensor<?x?x1xf32>) outs(%39 : tensor<?x?x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    %63 = arith.divf %in, %cst_11 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x1xf32>
  %47 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%46, %cst : tensor<?x?x1xf32>, tensor<f32>) outs(%39 : tensor<?x?x1xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.addf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x1xf32>
  %48 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%47 : tensor<?x?x1xf32>) outs(%39 : tensor<?x?x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    %63 = math.sqrt %in : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x1xf32>
  %49 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%43, %48 : tensor<?x?x768xf32>, tensor<?x?x1xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.divf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  %50 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%49, %cst_6 : tensor<?x?x768xf32>, tensor<768xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.mulf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  %51 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%50, %cst_7 : tensor<?x?x768xf32>, tensor<768xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.addf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  %52 = tensor.empty(%30) : tensor<?x768x768xf32>
  %53 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_9 : tensor<768x768xf32>) outs(%52 : tensor<?x768x768xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<?x768x768xf32>
  %54 = linalg.fill ins(%cst_10 : f32) outs(%32 : tensor<?x?x768xf32>) -> tensor<?x?x768xf32>
  %55 = linalg.batch_matmul ins(%51, %53 : tensor<?x?x768xf32>, tensor<?x768x768xf32>) outs(%54 : tensor<?x?x768xf32>) -> tensor<?x?x768xf32>
  %56 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_8, %55 : tensor<768xf32>, tensor<?x?x768xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.addf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  %from_elements_19 = tensor.from_elements %17, %18, %c12_i64, %c64_i64 : tensor<4xi64>
  %reshape_20 = tensor.reshape %56(%from_elements_19) : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
  %57 = arith.index_cast %17 : i64 to index
  %58 = arith.index_cast %18 : i64 to index
  %59 = tensor.empty(%57, %58) : tensor<?x12x?x64xf32>
  %cast = tensor.cast %reshape_20 : tensor<?x?x?x?xf32> to tensor<?x?x12x64xf32>
  %60 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast : tensor<?x?x12x64xf32>) outs(%59 : tensor<?x12x?x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<?x12x?x64xf32>
  %cast_21 = tensor.cast %60 : tensor<?x12x?x64xf32> to tensor<?x?x?x?xf32>
  %61 = hal.tensor.barrier join(%cast_21 : tensor<?x?x?x?xf32>) => %arg2 : !hal.fence
  %dim_22 = tensor.dim %61, %c0 : tensor<?x?x?x?xf32>
  %dim_23 = tensor.dim %61, %c1 : tensor<?x?x?x?xf32>
  %dim_24 = tensor.dim %61, %c2 : tensor<?x?x?x?xf32>
  %dim_25 = tensor.dim %61, %c3 : tensor<?x?x?x?xf32>
  %62 = hal.tensor.export %61 : tensor<?x?x?x?xf32>{%dim_22, %dim_23, %dim_24, %dim_25} -> !hal.buffer_view
  util.return %62 : !hal.buffer_view
}

In particular, this sequence of instructions:

  %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
  %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
  %2 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<?x?xi64>{%0, %1}
  %dim = tensor.dim %2, %c1 : tensor<?x?xi64>
  %dim_13 = tensor.dim %2, %c0 : tensor<?x?xi64>
  %17 = arith.index_cast %dim_13 : index to i64
  %18 = arith.index_cast %dim : index to i64
  %from_elements = tensor.from_elements %17, %18, %c768_i64 : tensor<3xi64>
  %reshape = tensor.reshape %21(%from_elements) : (tensor<?x768xf32>, tensor<3xi64>) -> tensor<?x?x768xf32>
  %from_elements_15 = tensor.from_elements %c1_i64, %27, %c768_i64 : tensor<3xi64>
  %reshape_16 = tensor.reshape %29(%from_elements_15) : (tensor<?x768xf32>, tensor<3xi64>) -> tensor<?x?x768xf32>
  %30 = arith.index_cast %17 : i64 to index
  %31 = arith.index_cast %18 : i64 to index
  %32 = tensor.empty(%30, %31) : tensor<?x?x768xf32>
  %33 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%reshape, %reshape_16 : tensor<?x?x768xf32>, tensor<?x?x768xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.addf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>

The second tensor.reshape operation carries the information that the outermost dim of its result (%reshape_16) has size 1. For consistency, %32 should also have an outer dimension of size 1. Looking through the IR, that basically means %0 has the value 1. IIUC, torch should already know this, and it should have been materialized in the IR. The same dimension being known statically in one place and dynamic in another causes issues downstream during compilation. The downstream compilation is being fixed, but this seems worth fixing in torch as well. If that is not possible, then we will have to build a way, downstream of torch, to constraint-solve the dynamic dimensions in such cases (but I think Torch already has this).
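
The consistency argument above can be sketched as a toy unification over shapes. This is a hypothetical Python model (`unify_elementwise` is not an IREE or torch-mlir API). It assumes, as the argument does, that %33 indexes every operand with identity maps and does no broadcasting. Under that assumption, the static 1 on %reshape_16 pins dim 0 for %reshape and %32 as well.

```python
from typing import List, Optional, Sequence

# None models a dynamic ("?") dimension; an int models a static size.
Dim = Optional[int]


def unify_elementwise(shapes: Sequence[Sequence[Dim]]) -> List[Dim]:
    """Unify the shapes of operands that share identity indexing maps.

    Assumes every operand is indexed by (d0, d1, d2) -> (d0, d1, d2), so all
    operands must agree dim by dim; a static size on any one operand pins
    that dim for all of them.
    """
    rank = len(shapes[0])
    if any(len(s) != rank for s in shapes):
        raise ValueError("rank mismatch")
    result: List[Dim] = []
    for i in range(rank):
        static = {s[i] for s in shapes if s[i] is not None}
        if len(static) > 1:
            raise ValueError(f"conflicting static sizes in dim {i}: {sorted(static)}")
        result.append(static.pop() if static else None)
    return result


# Shapes of %reshape, %reshape_16 and %32 from the excerpt above.
print(unify_elementwise([[None, None, 768], [1, None, 768], [None, None, 768]]))
# [1, None, 768]
```

Note that this models the "no broadcasting" reading only; if the add is meant to broadcast dim 0, the static 1 does not constrain the other operands.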

@zjgarvey
Collaborator

This is pretty interesting. ONNX shape inference can't figure out that a slice op along dim 1 should preserve dim 0, so it leaves dim 0 dynamic. When lowering in torch, we correct the slice shape to [1,?]; then dim 0 of that slice op is extracted and passed to a list construct for an aten.unflatten.int op, and for some reason this aten.unflatten.int op cannot figure out that dim 0 of its result should be statically 1. In the conversion of unflatten.int, we only convert to tensor.expand_shape if there are fewer than two dynamic dims (otherwise we convert to tensor.reshape). I'm going to see if there is something I can do to get this unflatten.int op to recognize the correct output shape before lowering to linalg on tensors.
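
A minimal model of the conversion rule described above, assuming "dynamic dims" counts the `?` entries in the unflatten result type. `choose_unflatten_lowering` is a hypothetical helper, not torch-mlir code. It shows why recovering the static 1 matters: a [?, ?, 768] result falls back to tensor.reshape, while [1, ?, 768] would qualify for tensor.expand_shape.

```python
from typing import Optional, Sequence


def choose_unflatten_lowering(result_dims: Sequence[Optional[int]]) -> str:
    """Pick a lowering for aten.unflatten.int.

    `result_dims` lists the result's dim sizes, with None for a dynamic
    dim. Mirrors the rule described above: expand_shape only when there
    are fewer than two dynamic dims, reshape otherwise.
    """
    num_dynamic = sum(1 for d in result_dims if d is None)
    return "tensor.expand_shape" if num_dynamic < 2 else "tensor.reshape"


print(choose_unflatten_lowering([None, None, 768]))  # tensor.reshape
print(choose_unflatten_lowering([1, None, 768]))     # tensor.expand_shape
```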

@zjgarvey
Collaborator

I believe that %arg0 is indeed shaped [?, ?], so %32 probably shouldn't have outer dim statically equal to 1.

Here is something particularly strange:

Let the shape of %arg0 be [dim0, dim1]. From reading the IR above, %reshape_16 (and %reshape_18) should have shape [1, min(dim1, 512), 768], but %reshape should have shape [dim0, dim1, 768]. In %33 above (which computes %reshape + %reshape_16), this is particularly problematic for two reasons:

  1. It should broadcast dim 0 of %reshape_16 in the "add" generic. That is, the indexing maps don't treat dim 0 of %reshape_16 properly: the map for %reshape_16 should index dim 0 with the constant 0.
  2. If a user provides dim1 > 512, then the loop over dim 1 of %reshape will index past the end of %reshape_16, whose dim 1 is at most 512.
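
Both problems can be reproduced with plain integer arithmetic. In this hypothetical Python model, `slice_len` mirrors the %3 to %11 select chain that computes %11 from %dim, and `broadcast_shape` applies NumPy-style broadcasting, which is assumed here to be the intended semantics of the add. Neither function is torch-mlir code.

```python
from typing import Sequence, Tuple


def slice_len(dim1: int, limit: int = 512) -> int:
    """Mirror the %3..%11 index arithmetic that derives %11 from %dim."""
    v = dim1 + limit if dim1 < 0 else dim1  # %3-%5: wrap a negative end
    v = -1 if v < 0 else v                  # %6-%7
    v = limit if v > limit else v           # %8-%9: clamp to 512
    v = 0 if v < 0 else v                   # %10-%11: clamp at 0
    return v


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """NumPy-style broadcast of two equal-rank static shapes."""
    if len(a) != len(b):
        raise ValueError("rank mismatch")
    out = []
    for x, y in zip(a, b):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(f"cannot broadcast {x} with {y}")
    return tuple(out)


dim0, dim1 = 4, 128
reshape = (dim0, dim1, 768)
reshape_16 = (1, slice_len(dim1), 768)
print(broadcast_shape(reshape, reshape_16))  # (4, 128, 768): needs a dim-0 broadcast

dim1 = 600  # exceeds the 512 limit in the IR
try:
    broadcast_shape((dim0, dim1, 768), (1, slice_len(dim1), 768))
except ValueError as e:
    print(e)  # cannot broadcast 600 with 512
```

The first case corresponds to issue (1): the add is only well defined if dim 0 of %reshape_16 is broadcast. The second corresponds to issue (2): no broadcast rule reconciles 600 with 512.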

I'll work on addressing issue (1) for now. If we can get unflatten.int to properly recognize the static shape in dim 0, hopefully that issue will be resolved by the lowering of aten.add.tensor.

@zjgarvey
Collaborator

This is a small reproducer for the failure to infer shapes:

module {
  func.func @torch_jit(%arg0: !torch.vtensor<[?,768],f32>) -> !torch.vtensor<[?,?,768],f32> {
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,768],f32>, !torch.int -> !torch.int
    %1 = torch.prim.ListConstruct %int1, %0 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.aten.unflatten.int %arg0, %int0, %1 : !torch.vtensor<[?,768],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,768],f32>
    return %2 : !torch.vtensor<[?,?,768],f32>
  }
}

Running

torch-mlir-opt --torch-shape-refinement-pipeline

does not successfully infer the output shape [1, ?, 768].
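
One way to picture the failure: a shape function that gives up whenever an entry of `sizes` is not a constant int leaves the declared [?, ?, 768] type alone, while one that keeps the constant entries recovers [1, ?, 768]. This is a hypothetical Python model, not torch-mlir's shape library. It assumes `sizes` contains no -1 entries.

```python
from typing import List, Optional, Sequence

Dim = Optional[int]  # None models "?"


def refine_unflatten(in_shape: Sequence[Dim], dim: int, sizes: Sequence[Dim],
                     bail_on_dynamic: bool) -> Optional[List[Dim]]:
    """Toy shape function for aten.unflatten.int.

    `sizes` entries are None when they are not constant ints (for example
    the result of aten.size.int on a dynamic tensor). Returns None when
    inference gives up, i.e. the declared type is kept unchanged.
    """
    if bail_on_dynamic and any(s is None for s in sizes):
        return None
    return list(in_shape[:dim]) + list(sizes) + list(in_shape[dim + 1:])


# The reproducer: %arg0 is [?, 768], and dim 0 is unflattened into [1, %0].
print(refine_unflatten([None, 768], 0, [1, None], bail_on_dynamic=True))   # None
print(refine_unflatten([None, 768], 0, [1, None], bail_on_dynamic=False))  # [1, None, 768]
```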

@zjgarvey
Collaborator

I'm guessing that since %0 isn't a constant int, shape inference bails. I'm going to look into folding this particular pattern into an aten.unsqueeze op; the aten.unsqueeze op will then likely be able to infer the correct output shape.
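
The proposed fold can be modeled in a few lines. This is a hypothetical Python sketch: the `Unflatten` and `Unsqueeze` classes are stand-ins, not torch-mlir types, and `dim` is assumed to be non-negative. The point is that unsqueeze's shape function only needs the static 1, never the dynamic size.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union


@dataclass(frozen=True)
class Unflatten:
    dim: int                           # non-negative dim being split
    sizes: Tuple[Optional[int], ...]   # None = not a constant int


@dataclass(frozen=True)
class Unsqueeze:
    dim: int


def canonicalize(op: Unflatten) -> Union[Unflatten, Unsqueeze]:
    """Rewrite a one-to-two unflatten with a static unit size as unsqueeze.

    The sizes must multiply to the original extent, so if one of two sizes
    is 1 the other equals the original extent and the op only inserts a
    unit dim.
    """
    if len(op.sizes) != 2:
        return op
    if op.sizes[0] == 1:
        return Unsqueeze(op.dim)       # [1, n]: unit dim goes before `dim`
    if op.sizes[1] == 1:
        return Unsqueeze(op.dim + 1)   # [n, 1]: unit dim goes after `dim`
    return op


def unsqueeze_shape(shape: Sequence[Optional[int]], dim: int) -> List[Optional[int]]:
    """Shape inference for unsqueeze: always a static 1 at `dim`."""
    return list(shape[:dim]) + [1] + list(shape[dim:])


op = canonicalize(Unflatten(dim=0, sizes=(1, None)))
print(op)                                    # Unsqueeze(dim=0)
print(unsqueeze_shape([None, 768], op.dim))  # [1, None, 768]
```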

@zjgarvey
Collaborator

I've added a canonicalization pattern for unflatten -> unsqueeze and tried it out on the original ONNX IR in https://gist.github.com/nirvedhmeshram/f350fa447fdf5cdcbff45ced0dd77e6c. What is truly unfortunate is that the canonicalization pattern can only take effect after shape inference determines that a certain slice op has 1 as its outermost dim. Therefore, shape inference would need to be run twice on this IR, since the unsqueeze op only appears after the first round of shape inference completes.

I'm going to look into other options. If you want to check out the work I'm doing on this, see the branch https://github.com/zjgarvey/torch-mlir/tree/unflatten_fold

@zjgarvey
Collaborator

Amazing. Inserting a static info cast op from the correct (static) unsqueeze type to the original dynamic type means the correct (static) type is used after canonicalization. I'm adding this change to the branch and submitting a PR.

rsuderman pushed a commit that referenced this issue Sep 3, 2024
Addresses an issue in <#3651>
where some unflatten ops generated from onnx models weren't propagating
static shape information. It may be necessary to add further
optimizations for the more general case when some static information is
present in the unflatten (or possibly reshape/view) op's `sizes` list,
but not reflected in the output shape. These ops will only successfully
infer shapes if the `sizes` list is obtained from a list of constant ints
(with possibly one -1). A common example where this fails is when some
of the `sizes` are determined from `aten.size.int` ops on dynamic
tensors, and other `sizes` are known statically.

This PR includes:
- a canonicalizer for `aten.unflatten.int` which converts to
`aten.unsqueeze` when it is expanding one dim to two, and one of the new
dims is statically 1.
- an improvement to the folder for `aten.__or__.bool` which does not
rely on *both* operands being static.
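
The second bullet can be illustrated with three-valued logic: a single operand that is statically true is enough to fold `or`. This is a hypothetical Python model, not the torch-mlir folder itself; whether the real folder also forwards the other operand for `false or x` is not stated here.

```python
from typing import Optional


def fold_or_bool(a: Optional[bool], b: Optional[bool]) -> Optional[bool]:
    """Toy folder for `a or b`; None means the operand is not a constant.

    Returns the folded constant, or None when no constant fold applies.
    """
    if a is True or b is True:
        return True    # x or True is True whatever x is
    if a is False and b is False:
        return False
    return None        # e.g. False or x is just x, not a constant


print(fold_or_bool(True, None))   # True
print(fold_or_bool(None, None))   # None
```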

@zjgarvey
Collaborator

I'm going to close this for now. If we find that some ops are still missing dimensionality information, we can reopen.
