
error: One or more operations with large vector sizes (8192 bytes) were found: #18677

Open
pdhirajkumarprasad opened this issue Oct 3, 2024 · 13 comments

@pdhirajkumarprasad

What happened?

For the attached IR, I am seeing the following error:

error: One or more operations with large vector sizes (8192 bytes) were found:

    %255 = linalg.generic {indexing_maps = [#map16, #map17, #map16], iterator_types = ["parallel", "parallel"]} ins(%254, %cst_6 : tensor<?x1000xf32>, tensor<1000xf32>) outs(%252 : tensor<?x1000xf32>) {
           ^
tt.mlir:21:3: note: called from

command:

iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false linalg.mlir

and this linalg.mlir was generated with the following command:

torch-mlir-opt --convert-torch-onnx-to-torch --torch-decompose-complex-ops -torch-backend-to-linalg-on-tensors-backend-pipeline model.torch_onnx.mlir

linalg.mlir.txt
model.torch_onnx.mlir.txt

Steps to reproduce your issue

Mentioned above

What component(s) does this issue relate to?

Compiler

Version information

No response

Additional context

No response

@nirvedhmeshram (Contributor)

This is an issue in an unpack + elementwise dispatch: the elementwise op gets tiled but the unpack does not. Here is the IR dump for the dispatch.
@hanhanW any idea what needs to be done here?

@hanhanW (Contributor) commented Oct 3, 2024

There is a tensor.extract_slice created after distribution, which blocks further TileAndFuse, so it generates large vectors. Someone needs to check why the extract_slice is created and fix it.

If it is not fixable, we can try adding the FoldUnpackWithExtractSliceOp pattern. You can populate it from https://github.com/llvm/llvm-project/blob/428ae0f12e29eff1ddcaf59bdcce904ec056963e/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp#L484-L491
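For reference, here is a rough MLIR sketch of the rewrite that pattern performs (the SSA names and static shapes are illustrative, not IR taken from this dispatch); it matches the unpack + extract_slice pair visible in the second dump below:

```mlir
// Before: the extract_slice only trims rows that the unpack over-produced.
%unpack = tensor.unpack %src outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
    inner_tiles = [8, 4] into %big : tensor<1x250x8x4xf32> -> tensor<8x1000xf32>
%slice = tensor.extract_slice %unpack[0, 0] [1, 1000] [1, 1]
    : tensor<8x1000xf32> to tensor<1x1000xf32>

// After: a single unpack whose destination already carries the trimmed shape,
// relying on the extract_slice semantics built into tensor.unpack.
%dest = tensor.empty() : tensor<1x1000xf32>
%folded = tensor.unpack %src outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
    inner_tiles = [8, 4] into %dest : tensor<1x250x8x4xf32> -> tensor<1x1000xf32>
```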

// -----// IR Dump After LowerExecutableUsingTransformDialectPass (iree-codegen-lower-executable-using-transform-dialect) //----- //
module {
  func.func @torch_jit$async_dispatch_24_unpack_elementwise_1x1000_f32_dispatch_0_unpack_elementwise_1x1000_f32() attributes {translation_info = #iree_codegen.translation_info<CPUDoubleTilingExpert>} {
    %cst = arith.constant 0.000000e+00 : f32
    %c0 = arith.constant 0 : index
    %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x250x8x4xf32>>
    %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<1x1000xf32>>
    %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 250, 8, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x250x8x4xf32>> -> tensor<1x250x8x4xf32>
    %3 = tensor.empty() : tensor<1x1000xf32>
    %unpack = tensor.unpack %2 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %3 {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 1000], [8, 4], [0, 0], [0, 0]]>} : tensor<1x250x8x4xf32> -> tensor<1x1000xf32>
    %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<1x1000xf32>) outs(%3 : tensor<1x1000xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 1000], [8, 4], [0, 0], [0, 0]]>} {
    ^bb0(%in: f32, %out: f32):
      %5 = arith.addf %in, %cst : f32
      linalg.yield %5 : f32
    } -> tensor<1x1000xf32>
    flow.dispatch.tensor.store %4, %1, offsets = [0, 0], sizes = [1, 1000], strides = [1, 1] : tensor<1x1000xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x1000xf32>>
    return
  }
}

// -----// IR Dump After TileAndDistributeToWorkgroupsPass (iree-codegen-tile-and-distribute-to-workgroups) //----- //
func.func @torch_jit$async_dispatch_24_unpack_elementwise_1x1000_f32_dispatch_0_unpack_elementwise_1x1000_f32() attributes {translation_info = #iree_codegen.translation_info<CPUDoubleTilingExpert>} {
  %c250 = arith.constant 250 : index
  %c1000 = arith.constant 1000 : index
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x250x8x4xf32>>
  %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<1x1000xf32>>
  %2 = flow.dispatch.tensor.load %0, offsets = [0, %c0, 0, 0], sizes = [1, %c250, 8, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x250x8x4xf32>> -> tensor<1x?x8x4xf32>
  %3 = tensor.empty() : tensor<8x1000xf32>
  %unpack = tensor.unpack %2 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %3 {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 1000], [8, 4], [0, 0], [0, 0]]>} : tensor<1x?x8x4xf32> -> tensor<8x1000xf32>
  %4 = tensor.empty() : tensor<1x1000xf32>
  %extracted_slice = tensor.extract_slice %unpack[0, 0] [1, 1000] [1, 1] : tensor<8x1000xf32> to tensor<1x1000xf32>
  %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<1x1000xf32>) outs(%4 : tensor<1x1000xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 1000], [8, 4], [0, 0], [0, 0]]>} {
  ^bb0(%in: f32, %out: f32):
    %6 = arith.addf %in, %cst : f32
    linalg.yield %6 : f32
  } -> tensor<1x1000xf32>
  %cast = tensor.cast %5 : tensor<1x1000xf32> to tensor<1x?xf32>
  flow.dispatch.tensor.store %cast, %1, offsets = [0, %c0], sizes = [1, %c1000], strides = [1, 1] : tensor<1x?xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x1000xf32>>
  return
}

@nirvedhmeshram (Contributor)

Good point, which made me notice: isn't this unpack wrong?

      %unpack = tensor.unpack %0 
      outer_dims_perm = [0, 1] 
      inner_dims_pos = [0, 1] 
      inner_tiles = [8, 4] into %1 : tensor<1x250x8x4xf32> -> tensor<1x1000xf32>

it should be

      %unpack = tensor.unpack %0 
      outer_dims_perm = [0, 1] 
      inner_dims_pos = [0, 1] 
      inner_tiles = [8, 4] into %1 : tensor<1x250x8x4xf32> -> tensor<8x1000xf32>

@nirvedhmeshram (Contributor) commented Oct 3, 2024

This issue seems to be a friend of #18603. We have:

    %115 = linalg.generic 
                {indexing_maps = [#map1, #map1, #map1, #map1], iterator_types = []} 
                 ins(%112, %113, %114 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%27 : tensor<i64>) {
    ^bb0(%in: i1, %in_76: i64, %in_77: i64, %out: i64):
      %216 = arith.select %in, %in_76, %in_77 : i64
      linalg.yield %216 : i64
    } -> tensor<i64>
    ...
    %extracted_61 = tensor.extract %115[] : tensor<i64>
     ...
    %131 = arith.index_cast %extracted_61 : i64 to index
    %212 = iree_encoding.unset_encoding 
                 %211 : tensor<?x1000xf32, #iree_encoding.encoding<operand_index = 2 : index, 
                 op_type =  matmul, element_types = [f32, f32, f32], 
                 user_indexing_maps = [#map23, #map24, #map25], 
                round_dims_to = array<i64: 32, 32, 32>>> -> tensor<?x1000xf32>
    %extracted_slice_75 = tensor.extract_slice %212[0, 0] [%131, 1000] [1, 1] : tensor<?x1000xf32> to tensor<?x1000xf32>
    %213 = linalg.generic 
                 {indexing_maps = [#map18, #map21, #map18], 
                  iterator_types = ["parallel", "parallel"]} 
                  ins(%extracted_slice_75, %cst_16 : tensor<?x1000xf32>, tensor<1000xf32>) 
                  outs(%206 : tensor<?x1000xf32>) {
    ^bb0(%in: f32, %in_76: f32, %out: f32):
      %216 = arith.addf %in, %in_76 : f32
      linalg.yield %216 : f32
    } -> tensor<?x1000xf32>

cc @zjgarvey

@hanhanW (Contributor) commented Oct 3, 2024

This unpack is valid because unpack ops have extract_slice semantics. You can think of the unpack op as the inverse of the pack op: the pack op has padding semantics, and the unpack op has extract_slice semantics. It is valid to fold an unpack -> extract_slice pair into a single unpack op. One of the ideas behind having a destination tensor on the unpack op is that it describes the shape.

      %unpack = tensor.unpack %0 
      outer_dims_perm = [0, 1] 
      inner_dims_pos = [0, 1] 
      inner_tiles = [8, 4] into %1 : tensor<1x250x8x4xf32> -> tensor<1x1000xf32>
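To make the inverse relationship concrete, here is a hypothetical pack that this unpack would undo (a sketch, not IR from this model): packing a 1x1000 tensor with inner tiles [8, 4] pads the first dimension from 1 up to 8, and unpacking it again drops that padding via the implicit extract_slice.

```mlir
%packed = tensor.pack %src padding_value(%pad : f32)
    outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 4]
    into %dest : tensor<1x1000xf32> -> tensor<1x250x8x4xf32>
```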

@nirvedhmeshram (Contributor)

> This unpack is valid because unpack ops have extract_slice semantics. [...]

Thanks, that makes sense. However, I am not sure whether we intended to reach this unpack with the extract_slice semantics, or whether it is a bug caused by having shape computation encoded in tensor math.

@zjgarvey (Contributor) commented Oct 3, 2024

@nirvedhmeshram I'll focus on getting the where.self op to return scalar arithmetic when possible.

@nirvedhmeshram (Contributor)

> @nirvedhmeshram I'll focus on getting the where.self op to return scalar arithmetic when possible.

Sounds good. I will check whether we want to support unpack + extract_slice fusion as well, and either add that support or disable this fusion in such cases, depending on where that discussion goes.

@hanhanW (Contributor) commented Oct 3, 2024

Yes, it is intended. It is not a bug. I guess the behavior is triggered by the tiling implementation: https://github.com/llvm/llvm-project/blob/fc4b1a303b296d02f6243a083510c4ee7f290ab0/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp#L561-L588

Looking at the implementation, I think the issue is that it is not treated as a perfect-tiling case. We can try to enhance the logic. My guess is that the value of tileSize[0] is 1 (which is the output shape), but it is not a multiple of the inner tile size (which is 8 in this case), so it goes down the non-perfect-tiling path.

One way to enhance the logic is to pass the size of the destination tensor to the getUnpackTileDimInfo function. If the sizes match, it can return the perfect-tiling config.
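Concretely (a sketch assuming the enhancement above; the lowering_config attribute is elided), the distributed IR earlier in this thread could then unpack straight into the 1x1000 destination tile with no trailing extract_slice on the result, which is the same folded form the FoldUnpackWithExtractSliceOp pattern would produce:

```mlir
%dest_tile = tensor.empty() : tensor<1x1000xf32>
%unpacked = tensor.unpack %2 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
    inner_tiles = [8, 4] into %dest_tile
    : tensor<1x?x8x4xf32> -> tensor<1x1000xf32>
```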

zjgarvey added a commit to llvm/torch-mlir that referenced this issue Oct 4, 2024
… generic ops (#3762)

This is motivated by the fact that shapes are stored as tensors in ONNX,
and IREE tries to perform tensor arithmetic on the device. This causes
unnecessary dispatches, and makes it harder for the compiler to reason
about shapes.

Here is a small snippet of torch-IR that is typically seen coming from ONNX models:
ONNX models:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,768],f32>, %arg1: !torch.vtensor<[?,?,768],f32>) -> !torch.vtensor<[],si64> {
    %int0 = torch.constant.int 0
    %0 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
    %1 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,768],f32> -> !torch.vtensor<[3],si64>
    %2 = torch.aten.index_select %1, %int0, %0 : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
    %3 = torch.aten.squeeze.dim %2, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %4 = torch.aten.item %3 : !torch.vtensor<[],si64> -> !torch.int
    %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool
    %6 = torch.aten.Int.bool %5 : !torch.bool -> !torch.int
    %7 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,768],f32>, !torch.int -> !torch.int
    %8 = torch.prim.NumToTensor.Scalar %6 : !torch.int -> !torch.vtensor<[],i1>
    %9 = torch.prim.NumToTensor.Scalar %7 : !torch.int -> !torch.vtensor<[],si64>
    %10 = torch.prim.NumToTensor.Scalar %4 : !torch.int -> !torch.vtensor<[],si64>
    %11 = torch.aten.where.self %8, %9, %10 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
    return %11 : !torch.vtensor<[],si64>
  }
}
```

Without the change in this PR, the result would be:

```mlir
#map = affine_map<() -> ()>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = tensor.empty() : tensor<i1>
    %6 = linalg.fill ins(%3 : i1) outs(%5 : tensor<i1>) -> tensor<i1>
    %7 = tensor.empty() : tensor<i64>
    %8 = linalg.fill ins(%4 : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %9 = linalg.fill ins(%extracted : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %10 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%6, %8, %9 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%7 : tensor<i64>) {
    ^bb0(%in: i1, %in_1: i64, %in_2: i64, %out: i64):
      %11 = arith.select %in, %in_1, %in_2 : i64
      linalg.yield %11 : i64
    } -> tensor<i64>
    return %10 : tensor<i64>
  }
}
```

With the change in this PR, we would instead get:

```mlir
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = arith.select %3, %4, %extracted : i64
    %6 = tensor.empty() : tensor<i64>
    %7 = linalg.fill ins(%5 : i64) outs(%6 : tensor<i64>) -> tensor<i64>
    return %7 : tensor<i64>
  }
}
```

Some related issues for context:
1. <iree-org/iree#18677>
2. <iree-org/iree#18631>
@pashu123 (Contributor) commented Oct 6, 2024

> There is a tensor.extract_slice created after distribution, which blocks further TileAndFuse, so it generates large vectors. [...]

This is merely based on observation of this IR, but since we are adding zeros, can't we DCE the linalg.generic?

@hanhanW (Contributor) commented Oct 7, 2024

> This is merely based on observation of this IR, but since we are adding zeros, can't we DCE the linalg.generic?

Good question. I honestly don't know where it should happen. It is not easy to identify these cases (e.g., transpose, etc.) at the Linalg level, so we typically rely on ConstEval. It's easier if we can do it at a higher level (like arith, or the input dialects).

@nirvedhmeshram (Contributor) commented Oct 7, 2024

> > This is merely based on observation of this IR, but since we are adding zeros, can't we DCE the linalg.generic?
>
> Good question. I honestly don't know where it should happen. [...]

Actually, this came up in a voice meeting we had last week, and we still want to support this dispatch because the constant is often zero due to fake weights but might not be zero in real use cases.

@pashu123 (Contributor) commented Oct 9, 2024

This is fixed with the latest pull. PTAL.
