Failure to compile gemm_postop_addmatrix_benchmark.py with #2378

Closed
etiotto opened this issue Sep 27, 2024 · 5 comments · Fixed by #2400

Comments

etiotto commented Sep 27, 2024

The gemm_postop_addmatrix_benchmark.py benchmark fails to compile with configuration:

        triton.Config(
            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
            num_stages=3, num_warps=32),

because the compiler allocates shared memory buffers that exceed the device's shared memory capacity. To reproduce:

USE_IPEX=0 python gemm_postop_addmatrix_benchmark.py

Traceback (most recent call last):
  File "/home/jovyan/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py", line 255, in <module>
    benchmark.run(show_plots=False, print_data=True)
  File "/home/jovyan/.conda/envs/triton-3.10/lib/python3.10/site-packages/triton_kernels_benchmark-0.0.0-py3.10.egg/triton_kernels_benchmark/benchmark_testing.py", line 373, in run
    result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
  File "/home/jovyan/.conda/envs/triton-3.10/lib/python3.10/site-packages/triton_kernels_benchmark-0.0.0-py3.10.egg/triton_kernels_benchmark/benchmark_testing.py", line 307, in _run
    ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
  File "/home/jovyan/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py", line 242, in benchmark
    benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
  File "/home/jovyan/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py", line 239, in <lambda>
    triton_fn = lambda: matmul(a, b, d, c)
  File "/home/jovyan/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py", line 188, in matmul
    matmul_kernel_with_block_pointers[grid](
  File "/home/jovyan/intel-xpu-backend-for-triton/python/triton/runtime/jit.py", line 330, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/home/jovyan/intel-xpu-backend-for-triton/python/triton/runtime/autotuner.py", line 168, in run
    ret = self.fn.run(
  File "/home/jovyan/intel-xpu-backend-for-triton/python/triton/runtime/jit.py", line 687, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/home/jovyan/intel-xpu-backend-for-triton/python/triton/compiler/compiler.py", line 417, in __getattribute__
    self._init_handles()
  File "/home/jovyan/intel-xpu-backend-for-triton/python/triton/compiler/compiler.py", line 410, in _init_handles
    raise OutOfResources(self.metadata.shared, max_shared, "shared memory")

Notes

  1. The benchmark has several auto-tuning configurations (including the one above, which fails to compile). The auto-tuner hides the functional failure and ends up choosing a configuration that does compile, but the chosen configuration has block sizes that are too small, so the achievable performance is significantly lower than that of a GEMM without any postOp.
  2. I believe that fixing the functional problem would "unlock" performance for this benchmark. With a small block size ('BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32) the performance is only slightly lower than that of a GEMM kernel with no postOp:

Performance of GEMM (no postOp)

     B       M       K       N  Triton-GB/s  Triton-GB/s-min  Triton-GB/s-max  Triton-TFlops  Triton-TFlops-min  Triton-TFlops-max  Triton-CV
0  1.0  4096.0  4096.0  4096.0    61.981733        61.977154        61.986312      63.469294          63.464605          63.473983   0.000104

Performance of GEMM + postOp (add matrix to result)

     B       M       K       N  Triton-GB/s  Triton-GB/s-min  Triton-GB/s-max  Triton-TFlops  Triton-TFlops-min  Triton-TFlops-max  Triton-CV
0  1.0  4096.0  4096.0  4096.0    54.684536        54.450264        54.920832      55.996965           55.75707          56.238932   0.006085
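
For context, the postOp this benchmark adds is a plain element-wise matrix addition in the GEMM epilogue. Below is a minimal sketch of such an epilogue written as a Triton device-side helper; all names here are illustrative, not the benchmark's own (the real kernel is matmul_kernel_with_block_pointers in the benchmark file):

    # Illustrative sketch only: a GEMM epilogue that loads a tile of an extra
    # matrix D, adds it to the accumulator from the K-loop, and stores the
    # resulting tile of C.
    import triton
    import triton.language as tl

    @triton.jit
    def add_matrix_epilogue(acc, d_ptr, c_ptr, M, N, stride_m, stride_n,
                            off_m, off_n,
                            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        d_blk = tl.make_block_ptr(base=d_ptr, shape=(M, N),
                                  strides=(stride_m, stride_n),
                                  offsets=(off_m, off_n),
                                  block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
        c = acc + tl.load(d_blk, boundary_check=(0, 1))
        c_blk = tl.make_block_ptr(base=c_ptr, shape=(M, N),
                                  strides=(stride_m, stride_n),
                                  offsets=(off_m, off_n),
                                  block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
        tl.store(c_blk, c, boundary_check=(0, 1))

The extra load of this D tile is the load analyzed in the comments below.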

etiotto commented Sep 27, 2024

Related to #1716

etiotto commented Sep 30, 2024

Shared memory is required in order to convert layouts (by the triton_gpu.convert_layout operation). Consider the following code:

    %15:3 = scf.for %arg4 = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%arg5 = %cst, %arg6 = %10, %arg7 = %12) -> (tensor<256x256xf32, #mma>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>)  : i32 {
      %c1_i32 = arith.constant 1 : i32 loc(#loc15)
      %56 = arith.muli %c32_i32, %c1_i32 : i32 loc(#loc15)
      %57 = arith.subi %c4096_i32, %56 : i32 loc(#loc15)
      %58 = arith.cmpi slt, %arg4, %57 : i32 loc(#loc15)
      %59 = tt.advance %arg6, [%c0_i32, %c32_i32] : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>> loc(#loc18)
      %60 = tt.advance %arg7, [%c32_i32, %c0_i32] : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>> loc(#loc19)
      triton_intel_gpu.prefetch %59 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>> loc(#loc16)
      triton_intel_gpu.prefetch %60 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>> loc(#loc17)
      %61 = tt.load %arg6 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>> loc(#loc16)
      %62 = tt.load %arg7 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>> loc(#loc17)
      %63 = tt.dot %61, %62, %arg5, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> loc(#loc20)
      scf.yield %63, %59, %60 : tensor<256x256xf32, #mma>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>> loc(#loc15)
    } loc(#loc15)
    %16 = arith.extsi %9 : i32 to i64 loc(#loc21)
    %17 = arith.extsi %11 : i32 to i64 loc(#loc21)
    %18 = tt.splat %arg3 : !tt.ptr<f32> -> tensor<256x256x!tt.ptr<f32>, #mma> loc(#loc22)
    %19 = tt.splat %16 : i64 -> tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #mma}>> loc(#loc22)
    %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> loc(#loc22)
    %21 = arith.extsi %20 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> to tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #mma}>> loc(#loc22)
    %22 = arith.addi %19, %21 : tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #mma}>> loc(#loc22)
    %23 = tt.expand_dims %22 {axis = 1 : i32} : tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<256x1xi64, #mma> loc(#loc22)
    %24 = tt.splat %c4096_i64 : i64 -> tensor<256x1xi64, #mma> loc(#loc22)
    %25 = arith.muli %23, %24 : tensor<256x1xi64, #mma> loc(#loc22)
    %26 = tt.broadcast %25 : tensor<256x1xi64, #mma> -> tensor<256x256xi64, #mma> loc(#loc22)
    %27 = tt.addptr %18, %26 : tensor<256x256x!tt.ptr<f32>, #mma>, tensor<256x256xi64, #mma> loc(#loc22)
    %28 = tt.splat %17 : i64 -> tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #mma}>> loc(#loc22)
    %29 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> loc(#loc22)
    %30 = arith.extsi %29 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> to tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #mma}>> loc(#loc22)
    %31 = arith.addi %28, %30 : tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #mma}>> loc(#loc22)
    %32 = tt.expand_dims %31 {axis = 0 : i32} : tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x256xi64, #mma> loc(#loc22)
    %33 = tt.splat %c1_i64 : i64 -> tensor<1x256xi64, #mma> loc(#loc22)
    %34 = arith.muli %32, %33 : tensor<1x256xi64, #mma> loc(#loc22)
    %35 = tt.broadcast %34 : tensor<1x256xi64, #mma> -> tensor<256x256xi64, #mma> loc(#loc22)
    %36 = tt.addptr %27, %35 : tensor<256x256x!tt.ptr<f32>, #mma>, tensor<256x256xi64, #mma> loc(#loc22)
    %c0_i64 = arith.constant 0 : i64 loc(#loc22)
    %37 = tt.splat %c0_i64 : i64 -> tensor<256x1xi64, #mma> loc(#loc22)
    %38 = arith.cmpi sge, %23, %37 : tensor<256x1xi64, #mma> loc(#loc22)
    %39 = tt.splat %c4096_i64 : i64 -> tensor<256x1xi64, #mma> loc(#loc22)
    %40 = arith.cmpi slt, %23, %39 : tensor<256x1xi64, #mma> loc(#loc22)
    %41 = arith.andi %38, %40 : tensor<256x1xi1, #mma> loc(#loc22)
    %42 = tt.broadcast %41 : tensor<256x1xi1, #mma> -> tensor<256x256xi1, #mma> loc(#loc22)
    %c0_i64_1 = arith.constant 0 : i64 loc(#loc22)
    %43 = tt.splat %c0_i64_1 : i64 -> tensor<1x256xi64, #mma> loc(#loc22)
    %44 = arith.cmpi sge, %32, %43 : tensor<1x256xi64, #mma> loc(#loc22)
    %45 = tt.splat %c4096_i64 : i64 -> tensor<1x256xi64, #mma> loc(#loc22)
    %46 = arith.cmpi slt, %32, %45 : tensor<1x256xi64, #mma> loc(#loc22)
    %47 = arith.andi %44, %46 : tensor<1x256xi1, #mma> loc(#loc22)
    %48 = tt.broadcast %47 : tensor<1x256xi1, #mma> -> tensor<256x256xi1, #mma> loc(#loc22)
    %49 = arith.andi %42, %48 : tensor<256x256xi1, #mma> loc(#loc22)
    %50 = triton_gpu.convert_layout %36 : tensor<256x256x!tt.ptr<f32>, #mma> -> tensor<256x256x!tt.ptr<f32>, #blocked> loc(#loc22)
    %51 = triton_gpu.convert_layout %49 : tensor<256x256xi1, #mma> -> tensor<256x256xi1, #blocked> loc(#loc22)
    %52 = tt.load %50, %51 : tensor<256x256x!tt.ptr<f32>, #blocked> loc(#loc22)
    %53 = triton_gpu.convert_layout %52 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma> loc(#loc22)
    %54 = arith.addf %15#0, %53 : tensor<256x256xf32, #mma> loc(#loc23)
    %55 = tt.make_tensor_ptr %arg2, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #mma>> loc(#loc24)
    tt.store %55, %54 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #mma>> loc(#loc25)

And let's focus on this snippet:

   %50 = triton_gpu.convert_layout %36 : tensor<256x256x!tt.ptr<f32>, #mma> -> tensor<256x256x!tt.ptr<f32>, #blocked> loc(#loc22)
   %51 = triton_gpu.convert_layout %49 : tensor<256x256xi1, #mma> -> tensor<256x256xi1, #blocked> loc(#loc22)
   %52 = tt.load %50, %51 : tensor<256x256x!tt.ptr<f32>, #blocked> loc(#loc22)
   %53 = triton_gpu.convert_layout %52 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma> loc(#loc22)
   %54 = arith.addf %15#0, %53 : tensor<256x256xf32, #mma> loc(#loc23)
   %55 = tt.make_tensor_ptr %arg2, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #mma>> loc(#loc24)
   tt.store %55, %54 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #mma>> loc(#loc25)

Here the tt.load requires its operands (%50, %51) to be in the blocked layout, so two convert_layout operations were injected. The result of the load is then converted back to the DPAS layout (#mma) in order to feed the arith.addf operation (which uses the #mma layout).

I think a possible solution is to back-propagate the #mma layout from the arith.addf operation to the tt.load operation yielding %52. If that load had the #mma layout, all the convert_layout operations would become no-ops and could be removed.

What Triton does instead is (eventually) convert the #mma layout produced by the GEMM loop into a blocked layout, and it then keeps that blocked layout all the way to the final store. This works fine if the available shared memory is sufficiently large. On our current GPU (PVC) the shared memory size is about half of what is needed for a block size of 256x256. If we instead maximized the use of the #mma layout as described above, no layout conversion would be necessary and therefore no shared memory would be required to run the kernel. At that point we should actually be able to generate 2D read/store operations for the tt.load and tt.store mentioned above.
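
As a rough sanity check of the "about half" statement: a single 256x256 f32 tile staged through shared memory for the layout conversion needs at least 256 KiB, while PVC exposes on the order of 128 KiB of shared local memory per work-group. The exact SLM figure and the assumption that the conversion stages the whole tile at once are assumptions of this sketch, not measured values:

    # Back-of-envelope estimate only; the real allocation depends on how
    # convert_layout tiles the transfer through shared memory.
    BLOCK_M, BLOCK_N = 256, 256
    ELEM_BYTES = 4                           # the converted tensor is f32

    needed = BLOCK_M * BLOCK_N * ELEM_BYTES  # 262144 bytes = 256 KiB
    slm_per_workgroup = 128 * 1024           # assumed PVC SLM capacity (~128 KiB)

    print(needed // 1024, "KiB needed vs", slm_per_workgroup // 1024, "KiB available")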

etiotto assigned etiotto and unassigned LiyangLingIntel on Sep 30, 2024

etiotto commented Oct 1, 2024

The issue is caused by the RewriteTensorPointer pass. Given:

    %13:3 = scf.for %arg4 = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%arg5 = %cst, %arg6 = %10, %arg7 = %12) -> (tensor<256x256xf32, #mma>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>)  : i32 {
      %18 = tt.load %arg6 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>> loc(#loc)
      %19 = tt.load %arg7 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>> loc(#loc)
      %20 = tt.dot %18, %19, %arg5, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> loc(#loc)
      %21 = tt.advance %arg6, [%c0_i32, %c32_i32] : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>> loc(#loc)
      %22 = tt.advance %arg7, [%c32_i32, %c0_i32] : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>> loc(#loc)
      scf.yield %20, %21, %22 : tensor<256x256xf32, #mma>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>> loc(#loc)
    } loc(#loc)
    %14 = tt.make_tensor_ptr %arg3, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #mma>> loc(#loc)
    %15 = tt.load %14 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x256xf32, #mma>> loc(#loc)
    %16 = arith.addf %13#0, %15 : tensor<256x256xf32, #mma> loc(#loc)
    %17 = tt.make_tensor_ptr %arg2, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #mma>> loc(#loc)
    tt.store %17, %16 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #mma>> loc(#loc)

RewriteTensorPointer changes:

    %14 = tt.make_tensor_ptr %arg3, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #mma>> loc(#loc)
    %15 = tt.load %14 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x256xf32, #mma>> loc(#loc)

into:

    %14 = arith.extsi %9 : i32 to i64 loc(#loc)
    %15 = arith.extsi %11 : i32 to i64 loc(#loc)
    %16 = tt.splat %arg3 : !tt.ptr<f32> -> tensor<256x256x!tt.ptr<f32>, #mma> loc(#loc)
    %17 = tt.splat %14 : i64 -> tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #mma}>> loc(#loc)
    %18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> loc(#loc)
    %19 = arith.extsi %18 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> to tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #mma}>> loc(#loc)
    %20 = arith.addi %17, %19 : tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #mma}>> loc(#loc)
    %21 = tt.expand_dims %20 {axis = 1 : i32} : tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<256x1xi64, #mma> loc(#loc)
    %22 = tt.splat %c4096_i64 : i64 -> tensor<256x1xi64, #mma> loc(#loc)
    %23 = arith.muli %21, %22 : tensor<256x1xi64, #mma> loc(#loc)
    %24 = tt.broadcast %23 : tensor<256x1xi64, #mma> -> tensor<256x256xi64, #mma> loc(#loc)
    %25 = tt.addptr %16, %24 : tensor<256x256x!tt.ptr<f32>, #mma>, tensor<256x256xi64, #mma> loc(#loc)
    %26 = tt.splat %15 : i64 -> tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #mma}>> loc(#loc)
    %27 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> loc(#loc)
    %28 = arith.extsi %27 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> to tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #mma}>> loc(#loc)
    %29 = arith.addi %26, %28 : tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #mma}>> loc(#loc)
    %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x256xi64, #mma> loc(#loc)
    %31 = tt.splat %c1_i64 : i64 -> tensor<1x256xi64, #mma> loc(#loc)
    %32 = arith.muli %30, %31 : tensor<1x256xi64, #mma> loc(#loc)
    %33 = tt.broadcast %32 : tensor<1x256xi64, #mma> -> tensor<256x256xi64, #mma> loc(#loc)
    %34 = tt.addptr %25, %33 : tensor<256x256x!tt.ptr<f32>, #mma>, tensor<256x256xi64, #mma> loc(#loc)
    %c0_i64 = arith.constant 0 : i64 loc(#loc)
    %35 = tt.splat %c0_i64 : i64 -> tensor<256x1xi64, #mma> loc(#loc)
    %36 = arith.cmpi sge, %21, %35 : tensor<256x1xi64, #mma> loc(#loc)
    %37 = tt.splat %c4096_i64 : i64 -> tensor<256x1xi64, #mma> loc(#loc)
    %38 = arith.cmpi slt, %21, %37 : tensor<256x1xi64, #mma> loc(#loc)
    %39 = arith.andi %36, %38 : tensor<256x1xi1, #mma> loc(#loc)
    %40 = tt.broadcast %39 : tensor<256x1xi1, #mma> -> tensor<256x256xi1, #mma> loc(#loc)
    %c0_i64_0 = arith.constant 0 : i64 loc(#loc)
    %41 = tt.splat %c0_i64_0 : i64 -> tensor<1x256xi64, #mma> loc(#loc)
    %42 = arith.cmpi sge, %30, %41 : tensor<1x256xi64, #mma> loc(#loc)
    %43 = tt.splat %c4096_i64 : i64 -> tensor<1x256xi64, #mma> loc(#loc)
    %44 = arith.cmpi slt, %30, %43 : tensor<1x256xi64, #mma> loc(#loc)
    %45 = arith.andi %42, %44 : tensor<1x256xi1, #mma> loc(#loc)
    %46 = tt.broadcast %45 : tensor<1x256xi1, #mma> -> tensor<256x256xi1, #mma> loc(#loc)
    %47 = arith.andi %40, %46 : tensor<256x256xi1, #mma> loc(#loc)
    %48 = tt.load %34, %47 : tensor<256x256x!tt.ptr<f32>, #mma> loc(#loc)

Note that the pass materializes the following load using a tensor of pointers:

   %48 = tt.load %34, %47 : tensor<256x256x!tt.ptr<f32>, #mma> loc(#loc)

This is the root of the problem. That load is subsequently transformed to have a blocked layout by the TritonGPUCoalesce pass (because the blocked layout allows generation of vector instructions). This is how we end up in a situation where layout conversion operations must be injected to convert from the #mma (DPAS) layout to the #blocked layout.

I think that ultimately we want to preserve the block pointers and prevent RewriteTensorPointer from changing the load.

A quick experiment reveals that removing RewriteTensorPointer allows the 256x256 block size to compile without running into the shared memory limit. The benchmark performance then "improves" from ~55 TFlops to ~130 TFlops for a 4Kx4Kx4K problem size.
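
At the Triton-language level, the effect of the rewrite is roughly the difference between the two forms below. This is an illustrative sketch with made-up names, not the pass's actual output; the pass itself operates on the IR:

    import triton
    import triton.language as tl

    @triton.jit
    def load_d_block_ptr(d_ptr, M, N, stride_dm, stride_dn, off_m, off_n,
                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        # Block-pointer form: a single 2D block load that can keep the DPAS
        # (#mma) layout and be lowered to a 2D block read, without shared memory.
        blk = tl.make_block_ptr(base=d_ptr, shape=(M, N),
                                strides=(stride_dm, stride_dn),
                                offsets=(off_m, off_n),
                                block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
        return tl.load(blk, boundary_check=(0, 1))

    @triton.jit
    def load_d_tensor_of_ptrs(d_ptr, M, N, stride_dm, stride_dn, off_m, off_n,
                              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        # Tensor-of-pointers form: what the rewritten IR above amounts to.
        # Coalescing later forces a #blocked layout on this load, which is
        # what injects the shared-memory convert_layout operations.
        offs_m = off_m + tl.arange(0, BLOCK_M)
        offs_n = off_n + tl.arange(0, BLOCK_N)
        ptrs = d_ptr + offs_m[:, None] * stride_dm + offs_n[None, :] * stride_dn
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        return tl.load(ptrs, mask=mask)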

etiotto commented Oct 1, 2024

The tentative fix is to avoid rewriting block pointers if they are used by tt.load operations that have the DPAS layout (similar to what is already done for block pointers used by tt.store operations with the DPAS layout). This is a conservative fix because the pass also checks that the load can be lowered to a 2D block read operation (has the correct pitch, etc.).

In the longer term (after #2374 is fixed) we should be in a position to remove the RewriteTensorPointer pass altogether.

etiotto commented Oct 3, 2024

Done, performance improved significantly:

(performance chart attached to the original comment)
