diff --git a/.github/actions/setup-pytorch/action.yml b/.github/actions/setup-pytorch/action.yml
index 726df36f11..16fba7afa3 100644
--- a/.github/actions/setup-pytorch/action.yml
+++ b/.github/actions/setup-pytorch/action.yml
@@ -16,6 +16,9 @@ inputs:
   ref:
     description: Branch, tag, commit id
    default: ""
+  mode:
+    description: Source or wheels
+    default: source
 runs:
   using: "composite"
   steps:
@@ -71,7 +74,7 @@ runs:
    - name: Generate PyTorch cache key
      shell: bash
      run: |
-        PYTORCH_CACHE_KEY=$(echo $PYTHON_VERSION $PYTORCH_COMMIT_ID ${{ hashFiles('scripts/patch-pytorch.sh') }} | sha256sum - | cut -d\  -f1)
+        PYTORCH_CACHE_KEY=$(echo $PYTHON_VERSION $PYTORCH_COMMIT_ID ${{ hashFiles('scripts/patch-pytorch.sh') }} ${{ inputs.mode }} | sha256sum - | cut -d\  -f1)
        echo "PYTORCH_CACHE_KEY=$PYTORCH_CACHE_KEY" | tee -a "$GITHUB_ENV"

    - name: Load PyTorch from a cache
@@ -90,11 +93,12 @@ runs:
      with:
        repository: ${{ env.PYTORCH_REPO }}
        ref: ${{ env.PYTORCH_COMMIT_ID }}
-        submodules: recursive
+        # To build PyTorch from source we need all submodules; they are not required for benchmarks.
+        submodules: ${{ inputs.mode == 'source' && 'recursive' || 'false' }}
        path: pytorch

    - name: Apply additional PR patches
-      if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.repository == 'pytorch/pytorch' }}
+      if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.repository == 'pytorch/pytorch' && inputs.mode == 'source' }}
      shell: bash
      run: |
        cd pytorch
@@ -108,7 +112,7 @@ runs:
        pip install 'numpy<2.0.0'

    - name: Build PyTorch
-      if: ${{ steps.pytorch-cache.outputs.status == 'miss' }}
+      if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.mode == 'source' }}
      shell: bash
      run: |
        source ${{ inputs.oneapi }}/setvars.sh
@@ -117,11 +121,24 @@ runs:
        pip install -r requirements.txt
        python setup.py bdist_wheel

-    - name: Install PyTorch
+    - name: Install PyTorch (built from source)
+      if: ${{ inputs.mode == 'source' }}
      shell: bash
      run: |
        source ${{ inputs.oneapi }}/setvars.sh
        pip install pytorch/dist/*.whl
+
+    - name: Install PyTorch (from wheels)
+      if: ${{ inputs.mode == 'wheels' }}
+      shell: bash
+      run: |
+        source ${{ inputs.oneapi }}/setvars.sh
+        pip install torch --index-url https://download.pytorch.org/whl/nightly/xpu
+
+    - name: Get PyTorch version
+      shell: bash
+      run: |
+        source ${{ inputs.oneapi }}/setvars.sh
        PYTORCH_VERSION="$(python -c 'import torch;print(torch.__version__)')"
        echo "PYTORCH_VERSION=$PYTORCH_VERSION" | tee -a "$GITHUB_ENV"
diff --git a/.github/workflows/build-test-gpu.yml b/.github/workflows/build-test-gpu.yml
index 84442d0283..7fe8511b03 100644
--- a/.github/workflows/build-test-gpu.yml
+++ b/.github/workflows/build-test-gpu.yml
@@ -12,6 +12,13 @@ on:
      description: PyTorch ref, keep empty for default
      type: string
      default: ""
+    pytorch_mode:
+      description: PyTorch mode, source or wheels
+      type: choice
+      options:
+        - source
+        - wheels
+      default: source
    upload_test_reports:
      description: Upload test reports
      type: boolean
@@ -46,6 +53,7 @@ jobs:
      device: ${{ inputs.runner_label }}
      runner_label: ${{ inputs.runner_label }}
      pytorch_ref: ${{ inputs.pytorch_ref }}
+      pytorch_mode: ${{ inputs.pytorch_mode || 'source' }}
      python_version: ${{ matrix.python }}
      upload_test_reports: ${{ inputs.upload_test_reports }}
      ignore_errors: ${{ inputs.ignore_errors }}
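Since `mode` is now folded into the hashed string, source builds and wheel installs of the same PyTorch commit occupy distinct cache entries. A minimal Python sketch of that derivation (field order mirrors the bash step above; the concrete values are illustrative, in CI they come from earlier steps):

    import hashlib

    def pytorch_cache_key(python_version, commit_id, patch_script_hash, mode):
        # Mirrors: echo $PYTHON_VERSION $PYTORCH_COMMIT_ID <patch-hash> <mode> | sha256sum
        payload = f'{python_version} {commit_id} {patch_script_hash} {mode}\n'
        return hashlib.sha256(payload.encode()).hexdigest()

    # The two modes no longer collide in the cache:
    assert pytorch_cache_key('3.10', 'abc123', 'f00d', 'source') != \
           pytorch_cache_key('3.10', 'abc123', 'f00d', 'wheels')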
diff --git a/.github/workflows/build-test-reusable.yml b/.github/workflows/build-test-reusable.yml
index c5aa02778e..cc623a8100 100644
--- a/.github/workflows/build-test-reusable.yml
+++ b/.github/workflows/build-test-reusable.yml
@@ -20,6 +20,10 @@ on:
      description: PyTorch ref, keep empty for default
      type: string
      default: ""
+    pytorch_mode:
+      description: PyTorch mode, source or wheels
+      type: string
+      default: "source"
    python_version:
      description: Python version
      type: string
@@ -96,6 +100,7 @@ jobs:
        with:
          repository: pytorch/pytorch
          ref: ${{ inputs.pytorch_ref }}
+          mode: ${{ inputs.pytorch_mode }}

      - name: Install pass_rate dependencies
        run: |
diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml
index dfad5d7136..d3d9b29b56 100644
--- a/.github/workflows/build-test.yml
+++ b/.github/workflows/build-test.yml
@@ -12,6 +12,13 @@ on:
      description: PyTorch ref, keep empty for default
      type: string
      default: ""
+    pytorch_mode:
+      description: PyTorch mode, source or wheels
+      type: choice
+      options:
+        - source
+        - wheels
+      default: source
    upload_test_reports:
      description: Upload test reports
      type: boolean
@@ -120,6 +127,7 @@ jobs:
      driver_version: ${{ matrix.driver }}
      runner_label: ${{ inputs.runner_label }}
      pytorch_ref: ${{ inputs.pytorch_ref }}
+      pytorch_mode: ${{ inputs.pytorch_mode || 'source' }}
      python_version: ${{ matrix.python }}
      upload_test_reports: ${{ inputs.upload_test_reports || false }}
      ignore_errors: ${{ inputs.ignore_errors || false }}
diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml
index f650d96845..0b0566589d 100644
--- a/.github/workflows/triton-benchmarks.yml
+++ b/.github/workflows/triton-benchmarks.yml
@@ -159,6 +159,17 @@ jobs:
          TAG=${{ inputs.tag || 'ci' }}-adv
          python ../../scripts/build_report.py $REPORTS/matmul-performance-adv-path.csv $REPORTS/gemm-triton-advanced-report.csv --benchmark gemm --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG

+      - name: Run Triton GEMM (A@B^t) kernel benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
+        run: |
+          cd benchmarks/triton_kernels_benchmark
+          python gemm_bt_benchmark.py --reports $REPORTS
+          mv $REPORTS/matmul-performance.csv $REPORTS/matmul-performance-bt.csv
+          source ../../scripts/capture-hw-details.sh
+
+          TAG=${{ inputs.tag || 'ci' }}
+          python ../../scripts/build_report.py $REPORTS/matmul-performance-bt.csv $REPORTS/gemm-bt-triton-report.csv --benchmark gemm-bt --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+
      - name: Run Triton GEMM (stream-k) kernel benchmark
        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
        run: |
diff --git a/benchmarks/triton_kernels_benchmark/gemm_bt_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_bt_benchmark.py
new file mode 100644
index 0000000000..b23786a790
--- /dev/null
+++ b/benchmarks/triton_kernels_benchmark/gemm_bt_benchmark.py
@@ -0,0 +1,292 @@
+"""
+GEMM with A@B^t benchmark
+====================================
+
+This benchmark is a modified version of gemm_benchmark.py with an added transpose of the B operand.
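+
+Roughly: for row-major A of shape (M, K) and B of shape (N, K) it measures
+C = A @ B^t without materializing the transpose; the Triton path simply hands
+the kernel B's strides in swapped order. A minimal usage sketch (shapes and
+the 'xpu' device are illustrative):
+
+    a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
+    b = torch.rand((N, K), device='xpu', dtype=torch.bfloat16)
+    c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+    matmul(a, b, c)  # compare against torch.matmul(a, torch.transpose(b, -1, -2))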
+""" + +import torch +import triton +import triton.language as tl + +import triton_kernels_benchmark as benchmark_suit +import xetla_kernel + +if benchmark_suit.USE_IPEX_OPTION: + import intel_extension_for_pytorch # type: ignore # noqa: F401 + + +@triton.autotune( + configs=[ + triton.Config( + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, + num_stages=s, num_warps=32) for s in [1, 2, 3] + ] + [ + triton.Config( + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, + num_stages=s, num_warps=32) for s in [2, 3] + ] + [ + triton.Config( + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, + num_stages=s, num_warps=32) for s in [2] + ] + [ + triton.Config( + {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'}, + num_stages=s, num_warps=32) for s in [2, 3] + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel_with_block_pointers( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + # Stride variables + stride_am: tl.constexpr, stride_ak: tl.constexpr, # + stride_bk: tl.constexpr, stride_bn: tl.constexpr, # + stride_cm: tl.constexpr, stride_cn: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), + order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(1, 0)) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for _ in range(0, K, BLOCK_SIZE_K): + a = tl.load(a_block_ptr, boundary_check=(0, 1)) + b = tl.load(b_block_ptr, boundary_check=(0, 1)) + accumulator += tl.dot(a, b) + a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K)) + b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0)) + c = accumulator.to(tl.float32) + + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), + offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0)) + tl.store(c_block_ptr, c, boundary_check=(0, 1)) + + +# pylint: disable=unused-argument +@triton.autotune( + configs=[ + triton.Config( + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, + num_stages=s, num_warps=32) for s in [2, 3] + ] + [ + triton.Config( + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, + num_stages=s, num_warps=32) for s in [2] + ] + [ + triton.Config( + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, + num_stages=s, num_warps=32) for s in [2] + ] + [ + triton.Config( + {'BLOCK_SIZE_M': 8, 
+
+
+# pylint: disable=unused-argument
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            num_stages=s, num_warps=32) for s in [2, 3]
+    ] + [
+        triton.Config(
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            num_stages=s, num_warps=32) for s in [2]
+    ] + [
+        triton.Config(
+            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            num_stages=s, num_warps=32) for s in [2]
+    ] + [
+        triton.Config(
+            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
+            num_stages=s, num_warps=32) for s in [2]
+    ] + [
+        triton.Config(
+            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
+            num_stages=s, num_warps=4) for s in [2]
+    ],
+    key=['M', 'N', 'K'],
+)
+@triton.jit
+def matmul_kernel_with_block_pointers_batched(
+        # Pointers to matrices
+        a_ptr, b_ptr, c_ptr,
+        # Matrix dimensions
+        B: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+        # Stride variables
+        stride_az: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
+        stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
+        stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr,
+        # Meta-parameters
+        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
+    bid = tl.program_id(axis=0)
+    pid = tl.program_id(axis=1)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    offset_a = bid.to(tl.int64) * stride_az
+    offset_b = bid.to(tl.int64) * stride_bz
+
+    a_block_ptr = tl.make_block_ptr(base=a_ptr + offset_a, shape=(M, K), strides=(stride_am, stride_ak),
+                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
+                                    order=(1, 0))
+    b_block_ptr = tl.make_block_ptr(base=b_ptr + offset_b, shape=(K, N), strides=(stride_bk, stride_bn),
+                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
+                                    order=(1, 0))
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for _ in range(0, K, BLOCK_SIZE_K):
+        a = tl.load(a_block_ptr, boundary_check=(0, 1))
+        b = tl.load(b_block_ptr, boundary_check=(0, 1))
+        accumulator += tl.dot(a, b)
+        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
+        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
+    c = accumulator.to(tl.float32)
+
+    offset_c = bid.to(tl.int64) * stride_cz
+    c_block_ptr = tl.make_block_ptr(base=c_ptr + offset_c, shape=(M, N), strides=(stride_cm, stride_cn),
+                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
+                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
+    tl.store(c_block_ptr, c, boundary_check=(0, 1))
+
+
+# We can now create a convenience wrapper function that takes the input and output
+# tensors, (1) checks any shape constraint and (2) launches the above kernels.
+def matmul(a, b, c):
+    # Check constraints.
+    if len(a.shape) == 3 and len(b.shape) == 3:
+        assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension'
+        assert a.shape[2] == b.shape[2], 'Incompatible dimensions'
+        assert a.is_contiguous(), 'Matrix A must be contiguous'
+        assert b.is_contiguous(), 'Matrix B must be contiguous'
+        B, M, K = a.shape
+        B, N, K = b.shape
+        # Launch grid: batch on axis 0, one program per output block on axis 1.
+        grid = lambda META: (
+            B,
+            triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
+        )
+        matmul_kernel_with_block_pointers_batched[grid](
+            a, b, c,  #
+            B, M, N, K,  #
+            a.stride(0), a.stride(1), a.stride(2),  #
+            b.stride(0), b.stride(2), b.stride(1),  # K/N strides swapped: reads B as B^t
+            c.stride(0), c.stride(1), c.stride(2))
+    elif len(a.shape) == 2 and len(b.shape) == 2:
+        assert a.shape[1] == b.shape[1], 'Incompatible dimensions'
+        assert a.is_contiguous(), 'Matrix A must be contiguous'
+        assert b.is_contiguous(), 'Matrix B must be contiguous'
+        M, K = a.shape
+        N, K = b.shape
+        grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
+        matmul_kernel_with_block_pointers[grid](
+            a, b, c,  #
+            M, N, K,  #
+            a.stride(0), a.stride(1),  #
+            b.stride(1), b.stride(0),  # strides swapped: reads row-major B as B^t
+            c.stride(0), c.stride(1))
+    else:
+        assert False, 'Input matrix dimensions mismatch'
+    return c
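+
+# Note on the transpose trick used above: matmul() never materializes B^t. It
+# passes B's strides in swapped order, so the kernel's (K, N) block pointer walks
+# the row-major (N, K) tensor along columns. The same idea in eager PyTorch
+# (an illustrative sanity check, any device):
+#
+#     x = torch.rand(4, 3)
+#     xt = torch.as_strided(x, size=(3, 4), stride=(x.stride(1), x.stride(0)))
+#     assert torch.equal(xt, x.t())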
+
+
+# Benchmark Performance
+@benchmark_suit.perf_report(
+    benchmark_suit.Benchmark(
+        # argument names to use as an x-axis for the plot
+        x_names=['B', 'M', 'K', 'N'],
+        # different possible values for `x_name`
+        x_vals=[[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] +  #
+        [  #
+            [1, 1, 5120, 13824],  #
+            [1, 4, 4096, 12288],  #
+            [1, 512, 8192, 8192],  #
+            [1, 512, 8192, 32768],  #
+            [1, 512, 32768, 8192],  #
+            [1, 1024, 16384, 8192],  #
+            [1, 1024, 28672, 8192],  #
+            [1, 3072, 4096, 3072],  # FIXME: Remove this case when gemm_streamk_benchmark works
+            [1, 4096, 16384, 8192],  #
+            [1, 8192, 16384, 1024],  #
+            [1, 8192, 16384, 4096],  #
+            [1, 16384, 1024, 8192],  #
+            [1, 16384, 4096, 8192],  #
+            [1, 16384, 8192, 1024],  #
+            [1, 16384, 8192, 4096],  #
+            [4, 32768, 128, 4096],  #
+            [4, 32768, 4096, 128],  #
+            [32, 4096, 4096, 128],  #
+            [4096, 8, 128, 16384],  #
+            [4096, 8, 16384, 128]
+        ],
+        line_arg='provider',
+        # argument name whose value corresponds to a different line in the plot
+        # possible values for `line_arg`
+        line_vals=['triton'],
+        # label name for the lines
+        line_names=['Triton'],
+        # line styles
+        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
+        ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
+        plot_name='matmul-performance',
+        # name for the plot. Used also as a file name for saving the plot.
+        args={},
+    ))
+def benchmark(B, M, N, K, provider):
+    if B == 1:
+        a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
+        b = torch.rand((N, K), device='xpu', dtype=torch.bfloat16)
+    else:
+        a = torch.rand((B, M, K), device='xpu', dtype=torch.bfloat16)
+        b = torch.rand((B, N, K), device='xpu', dtype=torch.bfloat16)
+
+    quantiles = [0.5, 0.0, 1.0]
+
+    if provider == 'onednn':
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, torch.transpose(b, -1, -2)),
+                                                                 warmup=10, rep=10, quantiles=quantiles,
+                                                                 fast_flush=False)
+    elif provider == 'triton':
+        assert len(a.shape) == len(b.shape), 'Incompatible sizes'
+        if len(a.shape) == 3:
+            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+        else:
+            assert len(a.shape) == 2, 'Expecting shape of length 2'
+            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        triton_fn = lambda: matmul(a, b, c)
+        torch_fn = lambda: torch.matmul(a, torch.transpose(b, -1, -2)).to(torch.float32)
+        rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
+        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                                 fast_flush=False)
+    elif provider == 'xetla':
+        if B == 1:
+            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+            acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
+            cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
+        else:
+            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+            acc = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+            cnt = torch.empty((B, M, N), device='xpu', dtype=torch.int32)
+        name = f'gemm_shape_{B}_{M}_{K}_{N}'
+        func = getattr(xetla_kernel, name)
+        xetla_fn = lambda: func(a, torch.transpose(b, -1, -2), c, acc, cnt)
+        torch_fn = lambda: torch.matmul(a, torch.transpose(b, -1, -2)).to(torch.float32)
+        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                                 fast_flush=False)
+    else:
+        raise NotImplementedError(f'Unsupported provider {provider}')
+
+    # 2*B*M*N*K flops for the multiply-accumulates; bf16 inputs are 2 bytes per
+    # element, the fp32 output 4 bytes per element.
+    tflops = lambda ms: 2 * B * M * N * K * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: B * (2 * (M * K + K * N) + 4.0 * (M * N)) * (1e-9) / (ms * 1e-3)
+
+    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
+
+
+if __name__ == '__main__':
+    benchmark.run(show_plots=False, print_data=True)
diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index 9e3eff155b..ae05e20498 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -203,6 +203,8 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);

 bool atomicNeedsSharedMemory(Value result);

+bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
+
 bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

 bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
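The predicate declared above is implemented in lib/Analysis/Utility.cpp below. As a rough illustration of the two conditions it checks, here is a hand-wavy Python sketch (layouts reduced to plain per-dimension lists; CTA-level checks are omitted and the rank-3 batch dimension is treated like any other non-K dimension, which is a simplification of the C++ code):

    def blocked_matches_dot_operand(blocked, parent, op_idx, shape):
        # blocked/parent: {'size_per_thread': [...], 'threads_per_warp': [...],
        #                  'warps_per_cta': [...], 'order': [...]} - illustrative
        # stand-ins for the MLIR layout attributes.
        rank = len(shape)
        k = rank - 1 if op_idx == 0 else rank - 2
        non_k = [d for d in range(rank) if d != k]
        # 1) each thread already holds the entire K dimension in registers
        if blocked['size_per_thread'][k] != shape[k]:
            return False
        # 2) all non-K dimensions are distributed identically in both layouts
        return blocked['order'] == parent['order'] and all(
            blocked[key][d] == parent[key][d]
            for key in ('size_per_thread', 'threads_per_warp', 'warps_per_cta')
            for d in non_k)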
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 20d552b15c..3b5b5c3f91 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -543,6 +543,75 @@ bool supportMMA(Value value, int version) {
          (elemTy.isInteger(8) && version >= 2);
 }

+bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
+  auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
+  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
+  if (blockedLayout == nullptr || dotOperandLayout == nullptr)
+    return false;
+  auto parentLayout =
+      dyn_cast<BlockedEncodingAttr>(dotOperandLayout.getParent());
+  if (parentLayout == nullptr)
+    return false;
+  auto opShape = srcTy.getShape();
+  auto rank = opShape.size();
+
+  int kDim = dotOperandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  int nonKDim = dotOperandLayout.getOpIdx() == 0 ? rank - 2 : rank - 1;
+  auto ctaLayout = blockedLayout.getCTALayout();
+
+  // The following logic checks that a source blocked layout matches a
+  // destination dot operand layout. This means that a tensor in the source
+  // layout could be converted into the destination layout without any data
+  // movement between registers or threads.
+  //
+  // It is considered a match if
+  // 1) Each thread in the source layout holds a whole copy of all elements
+  //    along the K dimension of the tensor
+  // 2) The distribution of data along all other non-K dimensions (Batch/M/N)
+  //    matches between the source and destination parent layouts.
+  //
+  // The first condition comes from a property of a dot operand layout with a
+  // blocked parent: the size per thread along the K dimension equals the size
+  // of the tensor along K. The second condition comes from another property:
+  // a dot operand layout inherits its non-K dimensions from its parent layout.
+  //
+  // clang-format off
+  //
+  // For example, the following conversion is a no-op:
+  //   tensor<128x32xf16, #blocked<{sizePerThread = [2, 32], threadsPerWarp = [32, 1]}>>
+  //     ->
+  //   tensor<128x32xf16, #dot_op<{opIdx=0, parent=#blocked<{sizePerThread = [2, 8], threadsPerWarp = [32, 1]}>>>
+  //
+  // clang-format on
+  bool ctaLayoutCompatible =
+      ctaLayout.getCTASplitNum()[kDim] == 1 &&
+      blockedLayout.getCTALayout() == parentLayout.getCTALayout();
+  bool threadHoldsWholeKDim =
+      blockedLayout.getSizePerThread()[kDim] == opShape[kDim];
+  bool nonKDimCompatible =
+      blockedLayout.getOrder() == parentLayout.getOrder() &&
+      blockedLayout.getSizePerThread()[nonKDim] ==
+          parentLayout.getSizePerThread()[nonKDim] &&
+      blockedLayout.getThreadsPerWarp()[nonKDim] ==
+          parentLayout.getThreadsPerWarp()[nonKDim] &&
+      blockedLayout.getWarpsPerCTA()[nonKDim] ==
+          parentLayout.getWarpsPerCTA()[nonKDim];
+  bool matrixDimsCompatible =
+      ctaLayoutCompatible && threadHoldsWholeKDim && nonKDimCompatible;
+  if (rank == 2)
+    return matrixDimsCompatible;
+
+  // additional check for the batch dimension if it is present
+  assert(rank == 3);
+  bool bDimCompatible =
+      blockedLayout.getSizePerThread()[0] ==
+          parentLayout.getSizePerThread()[0] &&
+      blockedLayout.getThreadsPerWarp()[0] ==
+          parentLayout.getThreadsPerWarp()[0] &&
+      blockedLayout.getWarpsPerCTA()[0] == parentLayout.getWarpsPerCTA()[0];
+  return matrixDimsCompatible && bDimCompatible;
+}
+
 bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
   auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
@@ -632,13 +701,14 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
 }

 bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
-  // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut` and
-  // `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
-  // checks.
+  // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
+  // `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully
+  // subsumed by the linear-layout checks.
   // TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
   // supported yet in Triton's backend.
   return !cvtReordersRegisters(srcTy, dstTy) &&
          !triton::gpu::intel::isDpasToDotShortcut(srcTy, dstTy) &&
+         !isBlockedToDotShortcut(srcTy, dstTy) &&
          !isMmaToDotShortcut(srcTy, dstTy) && !isMfmaToDotShortcut(srcTy, dstTy);
 }
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 30cb792763..893afc6590 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -232,6 +232,36 @@ struct ConvertLayoutOpConversion
   const TargetInfoBase &targetInfo;
 };

+struct ConvertLayoutOpBlockedToDotOpShortcutConversion
+    : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
+  const TargetInfoBase &targetInfo;
+  explicit ConvertLayoutOpBlockedToDotOpShortcutConversion(
+      LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
+      PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
+  }
+
+  LogicalResult
+  matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *ctx = op.getContext();
+
+    const auto &shape = op.getType().getShape();
+    auto srcTy = op.getSrc().getType();
+    auto dstTy = op.getType();
+    auto dstDotEncoding = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
+    if (!dstDotEncoding)
+      return failure();
+    if (!isa<BlockedEncodingAttr>(srcTy.getEncoding()) ||
+        !isa<BlockedEncodingAttr>(dstDotEncoding.getParent()))
+      return failure();
+    if (cvtNeedsSharedMemory(srcTy, dstTy))
+      return failure();
+    rewriter.replaceOp(op, adaptor.getSrc());
+    return success();
+  }
+};
+
 struct ConvertLayoutOpUsingLinearLayoutsConversion
     : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
   const TargetInfoBase &targetInfo;
@@ -657,5 +687,7 @@ void mlir::triton::populateConvertLayoutOpToLLVMPatterns(
   // one left.
   mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
       typeConverter, targetInfo, patterns, benefit.getBenefit() + 1);
+  patterns.add<ConvertLayoutOpBlockedToDotOpShortcutConversion>(
+      typeConverter, targetInfo, benefit);
   patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo, benefit);
 }
diff --git a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp
index e61d5588ad..b6386626fb 100644
--- a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp
@@ -83,6 +83,8 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
     OpBuilder builder(cvtOp);
     auto srcType = cast<RankedTensorType>(cvtOp.getSrc().getType());
     auto dstType = cast<RankedTensorType>(cvtOp.getType());
+    if (!cvtNeedsSharedMemory(srcType, dstType))
+      return;
     auto srcBlocked = dyn_cast<BlockedEncodingAttr>(srcType.getEncoding());
     auto dstDotOp =
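The decomposition passes now consult the same predicate as the lowering, so a convert that the new no-op pattern can handle is left intact rather than being routed through shared memory. Schematically, in Python pseudocode (names are illustrative, not the compiler API):

    def lower_convert_layout(cvt):
        if not cvt_needs_shared_memory(cvt.src_type, cvt.dst_type):
            return forward_registers(cvt)        # register-to-register shortcut
        buf = local_alloc(cvt.src)               # blocked -> shared memory
        return local_load(buf, cvt.dst_type)     # shared memory -> dot operand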
diff --git a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
index 4d40e0f317..bb22489eac 100644
--- a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
+++ b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
@@ -413,6 +413,7 @@ class RewriteTensorPointerPass
     auto newForOp = builder.create<scf::ForOp>(op.getLoc(), op.getLowerBound(),
                                                op.getUpperBound(), op.getStep(),
                                                newIterOperands);
+    newForOp->setAttrs(op->getAttrs());

     // Create value mapping. Note that for tensor pointers, we use identity
     // mapping. It may refer to a value in the old loop, but we will rewrite it
diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index c35b186fbf..7d508f234f 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -1,5 +1,6 @@
 #include <vector>

+#include "intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -827,6 +828,9 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
     return dotOperandMfmaToLinearLayout(*this, shape);
   }
+  if (auto dpasLayout = llvm::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
+    return dotOperandDpasToLinearLayout(*this, shape);
+  }
   return std::nullopt;
 }
diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
index 8c1f18e459..b1e296c1bb 100644
--- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
@@ -42,22 +42,8 @@ class TritonGPUReduceDataDuplicationPass
         dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
     if (!dstDotOp)
       return;
-    if (auto srcMmaEncoding =
-            dyn_cast<triton::gpu::NvidiaMmaEncodingAttr>(srcEncoding)) {
-
-      if (srcMmaEncoding.getVersionMajor() != 2 ||
-          (srcMmaEncoding.getWarpsPerCTA()[1] == 1 &&
-           dstDotOp.getParent() == srcMmaEncoding))
-        return;
-    }
-    if (auto srcMfmaEncoding =
-            dyn_cast<triton::gpu::AMDMfmaEncodingAttr>(srcEncoding)) {
-
-      if (srcMfmaEncoding.getWarpsPerCTA()[1] == 1 &&
-          srcMfmaEncoding.getIsTransposed() &&
-          dstDotOp.getParent() == srcMfmaEncoding)
-        return;
-    }
+    if (!cvtNeedsSharedMemory(srcType, dstType))
+      return;
     auto srcOrder = triton::gpu::getOrder(srcEncoding);
     auto rank = srcOrder.size();
     SmallVector<unsigned> sharedOrder;
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 683ff5dfa4..7607572653 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -3375,7 +3375,8 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_
     if is_hip():
         # hip does not support tf32 precision, so use ieee for all tests
         input_precision = "ieee"
-        if "gfx11" in triton.runtime.driver.active.get_current_target().arch:
+        arch = triton.runtime.driver.active.get_current_target().arch
+        if "gfx11" in arch or "gfx12" in arch:
             if in_dtype_str == "float32":
                 pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d")
             if out_dtype_str == "float16":
diff --git a/scripts/skiplist/xe2/.gitkeep b/scripts/skiplist/xe2/.gitkeep
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/scripts/skiplist/xe2/language.txt b/scripts/skiplist/xe2/language.txt
new file mode 100644
index 0000000000..cdac848de1
--- /dev/null
+++ b/scripts/skiplist/xe2/language.txt
@@ -0,0 +1,163 @@
+# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
+test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
+test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
+test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
+test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
+test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-int8-int8]
+test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-int8-int8] 
+test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-int8-int8] 
+test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[4-4-128-128-64-64-64-float16-float16] +test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float32-float32] 
+test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-int8-int8] +test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float16] +test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float32] +test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float32-float32] +test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-int8-int8] diff --git a/test/Conversion/amd/decompose-unsupported-conversions.mlir b/test/Conversion/amd/decompose-unsupported-conversions.mlir index 0d6220c80d..1bd288449f 100644 --- a/test/Conversion/amd/decompose-unsupported-conversions.mlir +++ b/test/Conversion/amd/decompose-unsupported-conversions.mlir @@ -1,15 +1,15 @@ -// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx1130 | FileCheck %s +// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions | FileCheck %s -// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> -// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}> -// CHECK: wmma_to_wmma_dot_op +// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> +// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}> +// CHECK-LABEL: wmma_to_wmma_dot_op #mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} { tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) { - // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[WMMA]]> -> tensor<16x16xf16, #[[BLOCKED]]> - // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : 
{{.*}} -> !tt.memdesc<16x16xf16, #[[SHARED]], #triton_gpu.shared_memory> - // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>> + // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[$WMMA]]> -> tensor<16x16xf16, #[[$BLOCKED]]> + // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory> + // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> tt.return } @@ -17,17 +17,89 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> -// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}> -// CHECK: wmma_to_wmma_dot3d_op +// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> +// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}> +// CHECK-LABEL: wmma_to_wmma_dot3d_op #mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func @wmma_to_wmma_dot3d_op(%arg0: tensor<2x16x16xf16, #mma>) { - // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[WMMA]]> -> tensor<2x16x16xf16, #[[BLOCKED]]> - // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[SHARED]], #triton_gpu.shared_memory> - // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>> + // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[$WMMA]]> -> tensor<2x16x16xf16, #[[$BLOCKED]]> + // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory> + // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> %0 = triton_gpu.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> tt.return } } + +// ----- + +// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx1130 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @blocked_to_dot_op_shortcut_gfx1130(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: triton_gpu.local_alloc + // CHECK: triton_gpu.convert_layout + // CHECK-NOT: triton_gpu.local_alloc + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx940 +#blocked = 
#triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @blocked_to_dot_op_shortcut_gfx940(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: triton_gpu.local_alloc + // CHECK: triton_gpu.convert_layout + // CHECK-NOT: triton_gpu.local_alloc + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_elems_gfx940 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @neg_blocked_to_dot_op_incompatible_elems_gfx940(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: triton_gpu.local_alloc + // CHECK: triton_gpu.local_load + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_threads_gfx940 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @neg_blocked_to_dot_op_incompatible_threads_gfx940(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: triton_gpu.local_alloc + // CHECK: triton_gpu.local_load + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_warp_gfx940 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: triton_gpu.local_alloc + // CHECK: triton_gpu.local_load + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + tt.return + } +} diff --git a/test/Conversion/intel/dot_layout_offset.mlir b/test/Conversion/intel/dot_layout_offset.mlir index 26e9d4d603..92129848d0 100644 --- a/test/Conversion/intel/dot_layout_offset.mlir +++ b/test/Conversion/intel/dot_layout_offset.mlir @@ -344,317 +344,307 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : // CHECK: %[[THREAD_ID_I64:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[VAL_142]]) // CHECK: 
%[[THREAD_ID_I32:.*]] = llvm.trunc %[[THREAD_ID_I64]] : i64 to i32
  // CHECK: %[[VAL_145:.*]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_I32]], %[[VAL_145]] : i32
  // CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREAD_ID_I32]], %[[VAL_145]] : i32
- // CHECK: %[[VAL_147:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_I32]], %[[VAL_147]] : i32
+ // CHECK-COUNT-3: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[VAL_149:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[WARP_ID_N:.*]] = llvm.urem %[[WARP_ID]], %[[VAL_149]] : i32
- // CHECK: %[[VAL_151:.*]] = llvm.udiv %[[WARP_ID]], %[[VAL_149]] : i32
+ // CHECK: %[[VAL_150:.*]] = llvm.and %[[LANE_ID]], %[[VAL_149]] : i32
+ // CHECK: %[[VAL_151:.*]] = llvm.icmp "eq" %[[VAL_150]], %[[CST_0]] : i32
  // CHECK: %[[VAL_152:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[WARP_ID_M:.*]] = llvm.urem %[[VAL_151]], %[[VAL_152]] : i32
- // CHECK: %[[VAL_154:.*]] = llvm.udiv %[[VAL_151]], %[[VAL_152]] : i32
+ // CHECK: %[[VAL_153:.*]] = llvm.select %[[VAL_151]], %[[CST_0]], %[[VAL_152]] : i1, i32
+ // CHECK: %[[VAL_154:.*]] = llvm.xor %[[CST_0]], %[[VAL_153]] : i32
  // CHECK: %[[VAL_155:.*]] = llvm.mlir.constant(2 : i32) : i32
- // CHECK: %[[ROUNDED_WARP_ID_N:.*]] = llvm.urem %[[WARP_ID_N]], %[[VAL_155]] : i32
- // CHECK: %[[warpShape_N:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[warpOffset:.*]] = llvm.mul %[[ROUNDED_WARP_ID_N]], %[[warpShape_N]] : i32
- // CHECK: %[[VAL_159:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_160:.*]] = llvm.udiv %[[LANE_ID]], %[[VAL_159]] : i32
- // CHECK: %[[VAL_161:.*]] = llvm.mlir.constant(2 : i32) : i32
- // CHECK: %[[laneRowIndex:.*]] = llvm.mul %[[VAL_160]], %[[VAL_161]] : i32
- // CHECK: %[[VAL_163:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[laneColIndex:.*]] = llvm.urem %[[LANE_ID]], %[[VAL_163]] : i32
- // CHECK: %[[multiDimBase_N:.*]] = llvm.add %[[laneColIndex]], %[[warpOffset]] : i32
- // CHECK: %[[VAL_166:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_167:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[VAL_168:.*]] = llvm.urem %[[VAL_166]], %[[VAL_167]] : i32
- // CHECK: %[[VAL_169:.*]] = llvm.udiv %[[VAL_166]], %[[VAL_167]] : i32
- // CHECK: %[[VAL_170:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[VAL_171:.*]] = llvm.urem %[[VAL_169]], %[[VAL_170]] : i32
- // CHECK: %[[VAL_172:.*]] = llvm.udiv %[[VAL_169]], %[[VAL_170]] : i32
- // CHECK: %[[VAL_173:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[VAL_174:.*]] = llvm.urem %[[VAL_171]], %[[VAL_173]] : i32
- // CHECK: %[[VAL_175:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[VAL_176:.*]] = llvm.urem %[[VAL_168]], %[[VAL_175]] : i32
- // CHECK: %[[VAL_177:.*]] = llvm.mlir.constant(32 : i32) : i32
- // CHECK: %[[CTAOffset_M:.*]] = llvm.mul %[[VAL_174]], %[[VAL_177]] : i32
- // CHECK: %[[VAL_179:.*]] = llvm.mlir.constant(32 : i32) : i32
- // CHECK: %[[CTAOffset_N:.*]] = llvm.mul %[[VAL_176]], %[[VAL_179]] : i32
- // CHECK: %[[VAL_181:.*]] = llvm.add %[[laneRowIndex]], %[[CTAOffset_M]] : i32
- // CHECK: %[[VAL_182:.*]] = llvm.add %[[multiDimBase_N]], %[[CTAOffset_N]] : i32
+ // CHECK: %[[VAL_156:.*]] = llvm.and %[[LANE_ID]], %[[VAL_155]] : i32
+ // CHECK: %[[VAL_157:.*]] = llvm.icmp "eq" %[[VAL_156]], %[[CST_0]] : i32
+ // CHECK: %[[VAL_158:.*]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: %[[VAL_159:.*]] = llvm.select %[[VAL_157]], %[[CST_0]], %[[VAL_158]] : i1, i32
+ // CHECK: %[[VAL_160:.*]] = llvm.xor %[[VAL_154]], %[[VAL_159]] : i32
+ // CHECK: %[[VAL_161:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: %[[VAL_162:.*]] = llvm.and %[[LANE_ID]], %[[VAL_161]] : i32
+ // CHECK: %[[VAL_163:.*]] = llvm.icmp "eq" %[[VAL_162]], %[[CST_0]] : i32
+ // CHECK: %[[VAL_164:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: %[[VAL_165:.*]] = llvm.select %[[VAL_163]], %[[CST_0]], %[[VAL_164]] : i1, i32
+ // CHECK: %[[VAL_182:.*]] = llvm.xor %[[VAL_160]], %[[VAL_165]] : i32
+ // CHECK: %[[VAL_167:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[VAL_168:.*]] = llvm.and %[[LANE_ID]], %[[VAL_167]] : i32
+ // CHECK: %[[VAL_169:.*]] = llvm.icmp "eq" %[[VAL_168]], %[[CST_0]] : i32
+ // CHECK: %[[VAL_170:.*]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: %[[VAL_171:.*]] = llvm.select %[[VAL_169]], %[[CST_0]], %[[VAL_170]] : i1, i32
+ // CHECK: %[[VAL_181:.*]] = llvm.xor %[[CST_0]], %[[VAL_171]] : i32
  // COM: There are [2, 4] repetitions in total of the tensor shape [32, 32] per warp for B.
  // COM: The repetitions are clustered as [1, 2] for the B operand. The repetition order is [0, 0], [0, 1], [1, 0], [1, 1], [0, 2], [0, 3], [1, 2], [1, 3].
  // COM: Offsets of rep [0, 0].
  // CHECK: %[[VAL_183:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_184:.*]] = llvm.add %[[VAL_181]], %[[VAL_183]] : i32
+ // CHECK: %[[VAL_184:.*]] = llvm.xor %[[VAL_181]], %[[VAL_183]] : i32
  // CHECK: %[[VAL_185:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_186:.*]] = llvm.add %[[VAL_182]], %[[VAL_185]] : i32
+ // CHECK: %[[VAL_186:.*]] = llvm.xor %[[VAL_182]], %[[VAL_185]] : i32
  // CHECK: %[[VAL_187:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[VAL_188:.*]] = llvm.add %[[VAL_181]], %[[VAL_187]] : i32
+ // CHECK: %[[VAL_188:.*]] = llvm.xor %[[VAL_181]], %[[VAL_187]] : i32
  // CHECK: %[[VAL_189:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_190:.*]] = llvm.add %[[VAL_182]], %[[VAL_189]] : i32
+ // CHECK: %[[VAL_190:.*]] = llvm.xor %[[VAL_182]], %[[VAL_189]] : i32
  // CHECK: %[[VAL_191:.*]] = llvm.mlir.constant(4 : i32) : i32
- // CHECK: %[[VAL_192:.*]] = llvm.add %[[VAL_181]], %[[VAL_191]] : i32
+ // CHECK: %[[VAL_192:.*]] = llvm.xor %[[VAL_181]], %[[VAL_191]] : i32
  // CHECK: %[[VAL_193:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_194:.*]] = llvm.add %[[VAL_182]], %[[VAL_193]] : i32
+ // CHECK: %[[VAL_194:.*]] = llvm.xor %[[VAL_182]], %[[VAL_193]] : i32
  // CHECK: %[[VAL_195:.*]] = llvm.mlir.constant(5 : i32) : i32
- // CHECK: %[[VAL_196:.*]] = llvm.add %[[VAL_181]], %[[VAL_195]] : i32
+ // CHECK: %[[VAL_196:.*]] = llvm.xor %[[VAL_181]], %[[VAL_195]] : i32
  // CHECK: %[[VAL_197:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_198:.*]] = llvm.add %[[VAL_182]], %[[VAL_197]] : i32
+ // CHECK: %[[VAL_198:.*]] = llvm.xor %[[VAL_182]], %[[VAL_197]] : i32
  // CHECK: %[[VAL_199:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_200:.*]] = llvm.add %[[VAL_181]], %[[VAL_199]] : i32
+ // CHECK: %[[VAL_200:.*]] = llvm.xor %[[VAL_181]], %[[VAL_199]] : i32
  // CHECK: %[[VAL_201:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_202:.*]] = llvm.add %[[VAL_182]], %[[VAL_201]] : i32
+ // CHECK: %[[VAL_202:.*]] = llvm.xor %[[VAL_182]], %[[VAL_201]] : i32
  // CHECK: %[[VAL_203:.*]] = llvm.mlir.constant(9 : i32) : i32
- // CHECK: %[[VAL_204:.*]] = llvm.add %[[VAL_181]], %[[VAL_203]] : i32
+ // CHECK: %[[VAL_204:.*]] = llvm.xor %[[VAL_181]], %[[VAL_203]] : i32
  // CHECK: %[[VAL_205:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_206:.*]] = llvm.add %[[VAL_182]], %[[VAL_205]] : i32
+ // CHECK: %[[VAL_206:.*]] = llvm.xor %[[VAL_182]], %[[VAL_205]] : i32
  // CHECK: %[[VAL_207:.*]] = llvm.mlir.constant(12 : i32) : i32
- // CHECK: %[[VAL_208:.*]] = llvm.add %[[VAL_181]], %[[VAL_207]] : i32
+ // CHECK: %[[VAL_208:.*]] = llvm.xor %[[VAL_181]], %[[VAL_207]] : i32
  // CHECK: %[[VAL_209:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_210:.*]] = llvm.add %[[VAL_182]], %[[VAL_209]] : i32
+ // CHECK: %[[VAL_210:.*]] = llvm.xor %[[VAL_182]], %[[VAL_209]] : i32
  // CHECK: %[[VAL_211:.*]] = llvm.mlir.constant(13 : i32) : i32
- // CHECK: %[[VAL_212:.*]] = llvm.add %[[VAL_181]], %[[VAL_211]] : i32
+ // CHECK: %[[VAL_212:.*]] = llvm.xor %[[VAL_181]], %[[VAL_211]] : i32
  // CHECK: %[[VAL_213:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_214:.*]] = llvm.add %[[VAL_182]], %[[VAL_213]] : i32
+ // CHECK: %[[VAL_214:.*]] = llvm.xor %[[VAL_182]], %[[VAL_213]] : i32
  // COM: Offsets of rep [0, 1].
  // CHECK: %[[VAL_215:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_216:.*]] = llvm.add %[[VAL_181]], %[[VAL_215]] : i32
+ // CHECK: %[[VAL_216:.*]] = llvm.xor %[[VAL_181]], %[[VAL_215]] : i32
  // CHECK: %[[VAL_217:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_218:.*]] = llvm.add %[[VAL_182]], %[[VAL_217]] : i32
+ // CHECK: %[[VAL_218:.*]] = llvm.xor %[[VAL_182]], %[[VAL_217]] : i32
  // CHECK: %[[VAL_219:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[VAL_220:.*]] = llvm.add %[[VAL_181]], %[[VAL_219]] : i32
+ // CHECK: %[[VAL_220:.*]] = llvm.xor %[[VAL_181]], %[[VAL_219]] : i32
  // CHECK: %[[VAL_221:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_222:.*]] = llvm.add %[[VAL_182]], %[[VAL_221]] : i32
+ // CHECK: %[[VAL_222:.*]] = llvm.xor %[[VAL_182]], %[[VAL_221]] : i32
  // CHECK: %[[VAL_223:.*]] = llvm.mlir.constant(4 : i32) : i32
- // CHECK: %[[VAL_224:.*]] = llvm.add %[[VAL_181]], %[[VAL_223]] : i32
+ // CHECK: %[[VAL_224:.*]] = llvm.xor %[[VAL_181]], %[[VAL_223]] : i32
  // CHECK: %[[VAL_225:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_226:.*]] = llvm.add %[[VAL_182]], %[[VAL_225]] : i32
+ // CHECK: %[[VAL_226:.*]] = llvm.xor %[[VAL_182]], %[[VAL_225]] : i32
  // CHECK: %[[VAL_227:.*]] = llvm.mlir.constant(5 : i32) : i32
- // CHECK: %[[VAL_228:.*]] = llvm.add %[[VAL_181]], %[[VAL_227]] : i32
+ // CHECK: %[[VAL_228:.*]] = llvm.xor %[[VAL_181]], %[[VAL_227]] : i32
  // CHECK: %[[VAL_229:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_230:.*]] = llvm.add %[[VAL_182]], %[[VAL_229]] : i32
+ // CHECK: %[[VAL_230:.*]] = llvm.xor %[[VAL_182]], %[[VAL_229]] : i32
  // CHECK: %[[VAL_231:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_232:.*]] = llvm.add %[[VAL_181]], %[[VAL_231]] : i32
+ // CHECK: %[[VAL_232:.*]] = llvm.xor %[[VAL_181]], %[[VAL_231]] : i32
  // CHECK: %[[VAL_233:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_234:.*]] = llvm.add %[[VAL_182]], %[[VAL_233]] : i32
+ // CHECK: %[[VAL_234:.*]] = llvm.xor %[[VAL_182]], %[[VAL_233]] : i32
  // CHECK: %[[VAL_235:.*]] = llvm.mlir.constant(9 : i32) : i32
- // CHECK: %[[VAL_236:.*]] = llvm.add %[[VAL_181]], %[[VAL_235]] : i32
+ // CHECK: %[[VAL_236:.*]] = llvm.xor %[[VAL_181]], %[[VAL_235]] : i32
  // CHECK: %[[VAL_237:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_238:.*]] = llvm.add %[[VAL_182]], %[[VAL_237]] : i32
+ // CHECK: %[[VAL_238:.*]] = llvm.xor %[[VAL_182]], %[[VAL_237]] : i32
  // CHECK: %[[VAL_239:.*]] = llvm.mlir.constant(12 : i32) : i32
- // CHECK: %[[VAL_240:.*]] = llvm.add %[[VAL_181]], %[[VAL_239]] : i32
+ // CHECK: %[[VAL_240:.*]] = llvm.xor %[[VAL_181]], %[[VAL_239]] : i32
  // CHECK: %[[VAL_241:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_242:.*]] = llvm.add %[[VAL_182]], %[[VAL_241]] : i32
+ // CHECK: %[[VAL_242:.*]] = llvm.xor %[[VAL_182]], %[[VAL_241]] : i32
  // CHECK: %[[VAL_243:.*]] = llvm.mlir.constant(13 : i32) : i32
- // CHECK: %[[VAL_244:.*]] = llvm.add %[[VAL_181]], %[[VAL_243]] : i32
+ // CHECK: %[[VAL_244:.*]] = llvm.xor %[[VAL_181]], %[[VAL_243]] : i32
  // CHECK: %[[VAL_245:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_246:.*]] = llvm.add %[[VAL_182]], %[[VAL_245]] : i32
+ // CHECK: %[[VAL_246:.*]] = llvm.xor %[[VAL_182]], %[[VAL_245]] : i32
  // COM: Offsets of rep [1, 0].
  // CHECK: %[[VAL_247:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_248:.*]] = llvm.add %[[VAL_181]], %[[VAL_247]] : i32
+ // CHECK: %[[VAL_248:.*]] = llvm.xor %[[VAL_181]], %[[VAL_247]] : i32
  // CHECK: %[[VAL_249:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_250:.*]] = llvm.add %[[VAL_182]], %[[VAL_249]] : i32
+ // CHECK: %[[VAL_250:.*]] = llvm.xor %[[VAL_182]], %[[VAL_249]] : i32
  // CHECK: %[[VAL_251:.*]] = llvm.mlir.constant(17 : i32) : i32
- // CHECK: %[[VAL_252:.*]] = llvm.add %[[VAL_181]], %[[VAL_251]] : i32
+ // CHECK: %[[VAL_252:.*]] = llvm.xor %[[VAL_181]], %[[VAL_251]] : i32
  // CHECK: %[[VAL_253:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_254:.*]] = llvm.add %[[VAL_182]], %[[VAL_253]] : i32
+ // CHECK: %[[VAL_254:.*]] = llvm.xor %[[VAL_182]], %[[VAL_253]] : i32
  // CHECK: %[[VAL_255:.*]] = llvm.mlir.constant(20 : i32) : i32
- // CHECK: %[[VAL_256:.*]] = llvm.add %[[VAL_181]], %[[VAL_255]] : i32
+ // CHECK: %[[VAL_256:.*]] = llvm.xor %[[VAL_181]], %[[VAL_255]] : i32
  // CHECK: %[[VAL_257:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_258:.*]] = llvm.add %[[VAL_182]], %[[VAL_257]] : i32
+ // CHECK: %[[VAL_258:.*]] = llvm.xor %[[VAL_182]], %[[VAL_257]] : i32
  // CHECK: %[[VAL_259:.*]] = llvm.mlir.constant(21 : i32) : i32
- // CHECK: %[[VAL_260:.*]] = llvm.add %[[VAL_181]], %[[VAL_259]] : i32
+ // CHECK: %[[VAL_260:.*]] = llvm.xor %[[VAL_181]], %[[VAL_259]] : i32
  // CHECK: %[[VAL_261:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_262:.*]] = llvm.add %[[VAL_182]], %[[VAL_261]] : i32
+ // CHECK: %[[VAL_262:.*]] = llvm.xor %[[VAL_182]], %[[VAL_261]] : i32
  // CHECK: %[[VAL_263:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_264:.*]] = llvm.add %[[VAL_181]], %[[VAL_263]] : i32
+ // CHECK: %[[VAL_264:.*]] = llvm.xor %[[VAL_181]], %[[VAL_263]] : i32
  // CHECK: %[[VAL_265:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_266:.*]] = llvm.add %[[VAL_182]], %[[VAL_265]] : i32
+ // CHECK: %[[VAL_266:.*]] = llvm.xor %[[VAL_182]], %[[VAL_265]] : i32
  // CHECK: %[[VAL_267:.*]] = llvm.mlir.constant(25 : i32) : i32
- // CHECK: %[[VAL_268:.*]] = llvm.add %[[VAL_181]], %[[VAL_267]] : i32
+ // CHECK: %[[VAL_268:.*]] = llvm.xor %[[VAL_181]], %[[VAL_267]] : i32
  // CHECK: %[[VAL_269:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_270:.*]] = llvm.add %[[VAL_182]], %[[VAL_269]] : i32
+ // CHECK: %[[VAL_270:.*]] = llvm.xor %[[VAL_182]], %[[VAL_269]] : i32
  // CHECK: %[[VAL_271:.*]] = llvm.mlir.constant(28 : i32) : i32
- // CHECK: %[[VAL_272:.*]] = llvm.add %[[VAL_181]], %[[VAL_271]] : i32
+ // CHECK: %[[VAL_272:.*]] = llvm.xor %[[VAL_181]], %[[VAL_271]] : i32
  // CHECK: %[[VAL_273:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_274:.*]] = llvm.add %[[VAL_182]], %[[VAL_273]] : i32
+ // CHECK: %[[VAL_274:.*]] = llvm.xor %[[VAL_182]], %[[VAL_273]] : i32
  // CHECK: %[[VAL_275:.*]] = llvm.mlir.constant(29 : i32) : i32
- // CHECK: %[[VAL_276:.*]] = llvm.add %[[VAL_181]], %[[VAL_275]] : i32
+ // CHECK: %[[VAL_276:.*]] = llvm.xor %[[VAL_181]], %[[VAL_275]] : i32
  // CHECK: %[[VAL_277:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_278:.*]] = llvm.add %[[VAL_182]], %[[VAL_277]] : i32
+ // CHECK: %[[VAL_278:.*]] = llvm.xor %[[VAL_182]], %[[VAL_277]] : i32
  // COM: Offsets of rep [1, 1].
  // CHECK: %[[VAL_279:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_280:.*]] = llvm.add %[[VAL_181]], %[[VAL_279]] : i32
+ // CHECK: %[[VAL_280:.*]] = llvm.xor %[[VAL_181]], %[[VAL_279]] : i32
  // CHECK: %[[VAL_281:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_282:.*]] = llvm.add %[[VAL_182]], %[[VAL_281]] : i32
+ // CHECK: %[[VAL_282:.*]] = llvm.xor %[[VAL_182]], %[[VAL_281]] : i32
  // CHECK: %[[VAL_283:.*]] = llvm.mlir.constant(17 : i32) : i32
- // CHECK: %[[VAL_284:.*]] = llvm.add %[[VAL_181]], %[[VAL_283]] : i32
+ // CHECK: %[[VAL_284:.*]] = llvm.xor %[[VAL_181]], %[[VAL_283]] : i32
  // CHECK: %[[VAL_285:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_286:.*]] = llvm.add %[[VAL_182]], %[[VAL_285]] : i32
+ // CHECK: %[[VAL_286:.*]] = llvm.xor %[[VAL_182]], %[[VAL_285]] : i32
  // CHECK: %[[VAL_287:.*]] = llvm.mlir.constant(20 : i32) : i32
- // CHECK: %[[VAL_288:.*]] = llvm.add %[[VAL_181]], %[[VAL_287]] : i32
+ // CHECK: %[[VAL_288:.*]] = llvm.xor %[[VAL_181]], %[[VAL_287]] : i32
  // CHECK: %[[VAL_289:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_290:.*]] = llvm.add %[[VAL_182]], %[[VAL_289]] : i32
+ // CHECK: %[[VAL_290:.*]] = llvm.xor %[[VAL_182]], %[[VAL_289]] : i32
  // CHECK: %[[VAL_291:.*]] = llvm.mlir.constant(21 : i32) : i32
- // CHECK: %[[VAL_292:.*]] = llvm.add %[[VAL_181]], %[[VAL_291]] : i32
+ // CHECK: %[[VAL_292:.*]] = llvm.xor %[[VAL_181]], %[[VAL_291]] : i32
  // CHECK: %[[VAL_293:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_294:.*]] = llvm.add %[[VAL_182]], %[[VAL_293]] : i32
+ // CHECK: %[[VAL_294:.*]] = llvm.xor %[[VAL_182]], %[[VAL_293]] : i32
  // CHECK: %[[VAL_295:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_296:.*]] = llvm.add %[[VAL_181]], %[[VAL_295]] : i32
+ // CHECK: %[[VAL_296:.*]] = llvm.xor %[[VAL_181]], %[[VAL_295]] : i32
  // CHECK: %[[VAL_297:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_298:.*]] = llvm.add %[[VAL_182]], %[[VAL_297]] : i32
+ // CHECK: %[[VAL_298:.*]] = llvm.xor %[[VAL_182]], %[[VAL_297]] : i32
  // CHECK: %[[VAL_299:.*]] = llvm.mlir.constant(25 : i32) : i32
- // CHECK: %[[VAL_300:.*]] = llvm.add %[[VAL_181]], %[[VAL_299]] : i32
+ // CHECK: %[[VAL_300:.*]] = llvm.xor %[[VAL_181]], %[[VAL_299]] : i32
  // CHECK: %[[VAL_301:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_302:.*]] = llvm.add %[[VAL_182]], %[[VAL_301]] : i32
+ // CHECK: %[[VAL_302:.*]] = llvm.xor %[[VAL_182]], %[[VAL_301]] : i32
  // CHECK: %[[VAL_303:.*]] = llvm.mlir.constant(28 : i32) : i32
- // CHECK: %[[VAL_304:.*]] = llvm.add %[[VAL_181]], %[[VAL_303]] : i32
+ // CHECK: %[[VAL_304:.*]] = llvm.xor %[[VAL_181]], %[[VAL_303]] : i32
  // CHECK: %[[VAL_305:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_306:.*]] = llvm.add %[[VAL_182]], %[[VAL_305]] : i32
+ // CHECK: %[[VAL_306:.*]] = llvm.xor %[[VAL_182]], %[[VAL_305]] : i32
  // CHECK: %[[VAL_307:.*]] = llvm.mlir.constant(29 : i32) : i32
- // CHECK: %[[VAL_308:.*]] = llvm.add %[[VAL_181]], %[[VAL_307]] : i32
+ // CHECK: %[[VAL_308:.*]] = llvm.xor %[[VAL_181]], %[[VAL_307]] : i32
  // CHECK: %[[VAL_309:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_310:.*]] = llvm.add %[[VAL_182]], %[[VAL_309]] : i32
+ // CHECK: %[[VAL_310:.*]] = llvm.xor %[[VAL_182]], %[[VAL_309]] : i32
  // COM: Offsets of rep [0, 2].
  // CHECK: %[[VAL_311:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_312:.*]] = llvm.add %[[VAL_181]], %[[VAL_311]] : i32
+ // CHECK: %[[VAL_312:.*]] = llvm.xor %[[VAL_181]], %[[VAL_311]] : i32
  // CHECK: %[[VAL_313:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_314:.*]] = llvm.add %[[VAL_182]], %[[VAL_313]] : i32
+ // CHECK: %[[VAL_314:.*]] = llvm.xor %[[VAL_182]], %[[VAL_313]] : i32
  // CHECK: %[[VAL_315:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[VAL_316:.*]] = llvm.add %[[VAL_181]], %[[VAL_315]] : i32
+ // CHECK: %[[VAL_316:.*]] = llvm.xor %[[VAL_181]], %[[VAL_315]] : i32
  // CHECK: %[[VAL_317:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_318:.*]] = llvm.add %[[VAL_182]], %[[VAL_317]] : i32
+ // CHECK: %[[VAL_318:.*]] = llvm.xor %[[VAL_182]], %[[VAL_317]] : i32
  // CHECK: %[[VAL_319:.*]] = llvm.mlir.constant(4 : i32) : i32
- // CHECK: %[[VAL_320:.*]] = llvm.add %[[VAL_181]], %[[VAL_319]] : i32
+ // CHECK: %[[VAL_320:.*]] = llvm.xor %[[VAL_181]], %[[VAL_319]] : i32
  // CHECK: %[[VAL_321:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_322:.*]] = llvm.add %[[VAL_182]], %[[VAL_321]] : i32
+ // CHECK: %[[VAL_322:.*]] = llvm.xor %[[VAL_182]], %[[VAL_321]] : i32
  // CHECK: %[[VAL_323:.*]] = llvm.mlir.constant(5 : i32) : i32
- // CHECK: %[[VAL_324:.*]] = llvm.add %[[VAL_181]], %[[VAL_323]] : i32
+ // CHECK: %[[VAL_324:.*]] = llvm.xor %[[VAL_181]], %[[VAL_323]] : i32
  // CHECK: %[[VAL_325:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_326:.*]] = llvm.add %[[VAL_182]], %[[VAL_325]] : i32
+ // CHECK: %[[VAL_326:.*]] = llvm.xor %[[VAL_182]], %[[VAL_325]] : i32
  // CHECK: %[[VAL_327:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_328:.*]] = llvm.add %[[VAL_181]], %[[VAL_327]] : i32
+ // CHECK: %[[VAL_328:.*]] = llvm.xor %[[VAL_181]], %[[VAL_327]] : i32
  // CHECK: %[[VAL_329:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_330:.*]] = llvm.add %[[VAL_182]], %[[VAL_329]] : i32
+ // CHECK: %[[VAL_330:.*]] = llvm.xor %[[VAL_182]], %[[VAL_329]] : i32
  // CHECK: %[[VAL_331:.*]] = llvm.mlir.constant(9 : i32) : i32
- // CHECK: %[[VAL_332:.*]] = llvm.add %[[VAL_181]], %[[VAL_331]] : i32
+ // CHECK: %[[VAL_332:.*]] = llvm.xor %[[VAL_181]], %[[VAL_331]] : i32
  // CHECK: %[[VAL_333:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_334:.*]] = llvm.add %[[VAL_182]], %[[VAL_333]] : i32
+ // CHECK: %[[VAL_334:.*]] = llvm.xor %[[VAL_182]], %[[VAL_333]] : i32
  // CHECK: %[[VAL_335:.*]] = llvm.mlir.constant(12 : i32) : i32
- // CHECK: %[[VAL_336:.*]] = llvm.add %[[VAL_181]], %[[VAL_335]] : i32
+ // CHECK: %[[VAL_336:.*]] = llvm.xor %[[VAL_181]], %[[VAL_335]] : i32
  // CHECK: %[[VAL_337:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_338:.*]] = llvm.add %[[VAL_182]], %[[VAL_337]] : i32
+ // CHECK: %[[VAL_338:.*]] = llvm.xor %[[VAL_182]], %[[VAL_337]] : i32
  // CHECK: %[[VAL_339:.*]] = llvm.mlir.constant(13 : i32) : i32
- // CHECK: %[[VAL_340:.*]] = llvm.add %[[VAL_181]], %[[VAL_339]] : i32
+ // CHECK: %[[VAL_340:.*]] = llvm.xor %[[VAL_181]], %[[VAL_339]] : i32
  // CHECK: %[[VAL_341:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_342:.*]] = llvm.add %[[VAL_182]], %[[VAL_341]] : i32
+ // CHECK: %[[VAL_342:.*]] = llvm.xor %[[VAL_182]], %[[VAL_341]] : i32
  // COM: Offsets of rep [0, 3].
  // CHECK: %[[VAL_343:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_344:.*]] = llvm.add %[[VAL_181]], %[[VAL_343]] : i32
+ // CHECK: %[[VAL_344:.*]] = llvm.xor %[[VAL_181]], %[[VAL_343]] : i32
  // CHECK: %[[VAL_345:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_346:.*]] = llvm.add %[[VAL_182]], %[[VAL_345]] : i32
+ // CHECK: %[[VAL_346:.*]] = llvm.xor %[[VAL_182]], %[[VAL_345]] : i32
  // CHECK: %[[VAL_347:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[VAL_348:.*]] = llvm.add %[[VAL_181]], %[[VAL_347]] : i32
+ // CHECK: %[[VAL_348:.*]] = llvm.xor %[[VAL_181]], %[[VAL_347]] : i32
  // CHECK: %[[VAL_349:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_350:.*]] = llvm.add %[[VAL_182]], %[[VAL_349]] : i32
+ // CHECK: %[[VAL_350:.*]] = llvm.xor %[[VAL_182]], %[[VAL_349]] : i32
  // CHECK: %[[VAL_351:.*]] = llvm.mlir.constant(4 : i32) : i32
- // CHECK: %[[VAL_352:.*]] = llvm.add %[[VAL_181]], %[[VAL_351]] : i32
+ // CHECK: %[[VAL_352:.*]] = llvm.xor %[[VAL_181]], %[[VAL_351]] : i32
  // CHECK: %[[VAL_353:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_354:.*]] = llvm.add %[[VAL_182]], %[[VAL_353]] : i32
+ // CHECK: %[[VAL_354:.*]] = llvm.xor %[[VAL_182]], %[[VAL_353]] : i32
  // CHECK: %[[VAL_355:.*]] = llvm.mlir.constant(5 : i32) : i32
- // CHECK: %[[VAL_356:.*]] = llvm.add %[[VAL_181]], %[[VAL_355]] : i32
+ // CHECK: %[[VAL_356:.*]] = llvm.xor %[[VAL_181]], %[[VAL_355]] : i32
  // CHECK: %[[VAL_357:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_358:.*]] = llvm.add %[[VAL_182]], %[[VAL_357]] : i32
+ // CHECK: %[[VAL_358:.*]] = llvm.xor %[[VAL_182]], %[[VAL_357]] : i32
  // CHECK: %[[VAL_359:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_360:.*]] = llvm.add %[[VAL_181]], %[[VAL_359]] : i32
+ // CHECK: %[[VAL_360:.*]] = llvm.xor %[[VAL_181]], %[[VAL_359]] : i32
  // CHECK: %[[VAL_361:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_362:.*]] = llvm.add %[[VAL_182]], %[[VAL_361]] : i32
+ // CHECK: %[[VAL_362:.*]] = llvm.xor %[[VAL_182]], %[[VAL_361]] : i32
  // CHECK: %[[VAL_363:.*]] = llvm.mlir.constant(9 : i32) : i32
- // CHECK: %[[VAL_364:.*]] = llvm.add %[[VAL_181]], %[[VAL_363]] : i32
+ // CHECK: %[[VAL_364:.*]] = llvm.xor %[[VAL_181]], %[[VAL_363]] : i32
  // CHECK: %[[VAL_365:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_366:.*]] = llvm.add %[[VAL_182]], %[[VAL_365]] : i32
+ // CHECK: %[[VAL_366:.*]] = llvm.xor %[[VAL_182]], %[[VAL_365]] : i32
  // CHECK: %[[VAL_367:.*]] = llvm.mlir.constant(12 : i32) : i32
- // CHECK: %[[VAL_368:.*]] = llvm.add %[[VAL_181]], %[[VAL_367]] : i32
+ // CHECK: %[[VAL_368:.*]] = llvm.xor %[[VAL_181]], %[[VAL_367]] : i32
  // CHECK: %[[VAL_369:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_370:.*]] = llvm.add %[[VAL_182]], %[[VAL_369]] : i32
+ // CHECK: %[[VAL_370:.*]] = llvm.xor %[[VAL_182]], %[[VAL_369]] : i32
  // CHECK: %[[VAL_371:.*]] = llvm.mlir.constant(13 : i32) : i32
- // CHECK: %[[VAL_372:.*]] = llvm.add %[[VAL_181]], %[[VAL_371]] : i32
+ // CHECK: %[[VAL_372:.*]] = llvm.xor %[[VAL_181]], %[[VAL_371]] : i32
  // CHECK: %[[VAL_373:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_374:.*]] = llvm.add %[[VAL_182]], %[[VAL_373]] : i32
+ // CHECK: %[[VAL_374:.*]] = llvm.xor %[[VAL_182]], %[[VAL_373]] : i32
  // COM: Offsets of rep [1, 2].
  // CHECK: %[[VAL_375:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_376:.*]] = llvm.add %[[VAL_181]], %[[VAL_375]] : i32
+ // CHECK: %[[VAL_376:.*]] = llvm.xor %[[VAL_181]], %[[VAL_375]] : i32
  // CHECK: %[[VAL_377:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_378:.*]] = llvm.add %[[VAL_182]], %[[VAL_377]] : i32
+ // CHECK: %[[VAL_378:.*]] = llvm.xor %[[VAL_182]], %[[VAL_377]] : i32
  // CHECK: %[[VAL_379:.*]] = llvm.mlir.constant(17 : i32) : i32
- // CHECK: %[[VAL_380:.*]] = llvm.add %[[VAL_181]], %[[VAL_379]] : i32
+ // CHECK: %[[VAL_380:.*]] = llvm.xor %[[VAL_181]], %[[VAL_379]] : i32
  // CHECK: %[[VAL_381:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_382:.*]] = llvm.add %[[VAL_182]], %[[VAL_381]] : i32
+ // CHECK: %[[VAL_382:.*]] = llvm.xor %[[VAL_182]], %[[VAL_381]] : i32
  // CHECK: %[[VAL_383:.*]] = llvm.mlir.constant(20 : i32) : i32
- // CHECK: %[[VAL_384:.*]] = llvm.add %[[VAL_181]], %[[VAL_383]] : i32
+ // CHECK: %[[VAL_384:.*]] = llvm.xor %[[VAL_181]], %[[VAL_383]] : i32
  // CHECK: %[[VAL_385:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_386:.*]] = llvm.add %[[VAL_182]], %[[VAL_385]] : i32
+ // CHECK: %[[VAL_386:.*]] = llvm.xor %[[VAL_182]], %[[VAL_385]] : i32
  // CHECK: %[[VAL_387:.*]] = llvm.mlir.constant(21 : i32) : i32
- // CHECK: %[[VAL_388:.*]] = llvm.add %[[VAL_181]], %[[VAL_387]] : i32
+ // CHECK: %[[VAL_388:.*]] = llvm.xor %[[VAL_181]], %[[VAL_387]] : i32
  // CHECK: %[[VAL_389:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_390:.*]] = llvm.add %[[VAL_182]], %[[VAL_389]] : i32
+ // CHECK: %[[VAL_390:.*]] = llvm.xor %[[VAL_182]], %[[VAL_389]] : i32
  // CHECK: %[[VAL_391:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_392:.*]] = llvm.add %[[VAL_181]], %[[VAL_391]] : i32
+ // CHECK: %[[VAL_392:.*]] = llvm.xor %[[VAL_181]], %[[VAL_391]] : i32
  // CHECK: %[[VAL_393:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_394:.*]] = llvm.add %[[VAL_182]], %[[VAL_393]] : i32
+ // CHECK: %[[VAL_394:.*]] = llvm.xor %[[VAL_182]], %[[VAL_393]] : i32
  // CHECK: %[[VAL_395:.*]] = llvm.mlir.constant(25 : i32) : i32
- // CHECK: %[[VAL_396:.*]] = llvm.add %[[VAL_181]], %[[VAL_395]] : i32
+ // CHECK: %[[VAL_396:.*]] = llvm.xor %[[VAL_181]], %[[VAL_395]] : i32
  // CHECK: %[[VAL_397:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_398:.*]] = llvm.add %[[VAL_182]], %[[VAL_397]] : i32
+ // CHECK: %[[VAL_398:.*]] = llvm.xor %[[VAL_182]], %[[VAL_397]] : i32
  // CHECK: %[[VAL_399:.*]] = llvm.mlir.constant(28 : i32) : i32
- // CHECK: %[[VAL_400:.*]] = llvm.add %[[VAL_181]], %[[VAL_399]] : i32
+ // CHECK: %[[VAL_400:.*]] = llvm.xor %[[VAL_181]], %[[VAL_399]] : i32
  // CHECK: %[[VAL_401:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_402:.*]] = llvm.add %[[VAL_182]], %[[VAL_401]] : i32
+ // CHECK: %[[VAL_402:.*]] = llvm.xor %[[VAL_182]], %[[VAL_401]] : i32
  // CHECK: %[[VAL_403:.*]] = llvm.mlir.constant(29 : i32) : i32
- // CHECK: %[[VAL_404:.*]] = llvm.add %[[VAL_181]], %[[VAL_403]] : i32
+ // CHECK: %[[VAL_404:.*]] = llvm.xor %[[VAL_181]], %[[VAL_403]] : i32
  // CHECK: %[[VAL_405:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_406:.*]] = llvm.add %[[VAL_182]], %[[VAL_405]] : i32
+ // CHECK: %[[VAL_406:.*]] = llvm.xor %[[VAL_182]], %[[VAL_405]] : i32
  // COM: Offsets of rep [1, 3].
  // CHECK: %[[VAL_407:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_408:.*]] = llvm.add %[[VAL_181]], %[[VAL_407]] : i32
+ // CHECK: %[[VAL_408:.*]] = llvm.xor %[[VAL_181]], %[[VAL_407]] : i32
  // CHECK: %[[VAL_409:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_410:.*]] = llvm.add %[[VAL_182]], %[[VAL_409]] : i32
+ // CHECK: %[[VAL_410:.*]] = llvm.xor %[[VAL_182]], %[[VAL_409]] : i32
  // CHECK: %[[VAL_411:.*]] = llvm.mlir.constant(17 : i32) : i32
- // CHECK: %[[VAL_412:.*]] = llvm.add %[[VAL_181]], %[[VAL_411]] : i32
+ // CHECK: %[[VAL_412:.*]] = llvm.xor %[[VAL_181]], %[[VAL_411]] : i32
  // CHECK: %[[VAL_413:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_414:.*]] = llvm.add %[[VAL_182]], %[[VAL_413]] : i32
+ // CHECK: %[[VAL_414:.*]] = llvm.xor %[[VAL_182]], %[[VAL_413]] : i32
  // CHECK: %[[VAL_415:.*]] = llvm.mlir.constant(20 : i32) : i32
- // CHECK: %[[VAL_416:.*]] = llvm.add %[[VAL_181]], %[[VAL_415]] : i32
+ // CHECK: %[[VAL_416:.*]] = llvm.xor %[[VAL_181]], %[[VAL_415]] : i32
  // CHECK: %[[VAL_417:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_418:.*]] = llvm.add %[[VAL_182]], %[[VAL_417]] : i32
+ // CHECK: %[[VAL_418:.*]] = llvm.xor %[[VAL_182]], %[[VAL_417]] : i32
  // CHECK: %[[VAL_419:.*]] = llvm.mlir.constant(21 : i32) : i32
- // CHECK: %[[VAL_420:.*]] = llvm.add %[[VAL_181]], %[[VAL_419]] : i32
+ // CHECK: %[[VAL_420:.*]] = llvm.xor %[[VAL_181]], %[[VAL_419]] : i32
  // CHECK: %[[VAL_421:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_422:.*]] = llvm.add %[[VAL_182]], %[[VAL_421]] : i32
+ // CHECK: %[[VAL_422:.*]] = llvm.xor %[[VAL_182]], %[[VAL_421]] : i32
  // CHECK: %[[VAL_423:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_424:.*]] = llvm.add %[[VAL_181]], %[[VAL_423]] : i32
+ // CHECK: %[[VAL_424:.*]] = llvm.xor %[[VAL_181]], %[[VAL_423]] : i32
  // CHECK: %[[VAL_425:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_426:.*]] = llvm.add %[[VAL_182]], %[[VAL_425]] : i32
+ // CHECK: %[[VAL_426:.*]] = llvm.xor %[[VAL_182]], %[[VAL_425]] : i32
  // CHECK: %[[VAL_427:.*]] = llvm.mlir.constant(25 : i32) : i32
- // CHECK: %[[VAL_428:.*]] = llvm.add %[[VAL_181]], %[[VAL_427]] : i32
+ // CHECK: %[[VAL_428:.*]] = llvm.xor %[[VAL_181]], %[[VAL_427]] : i32
  // CHECK: %[[VAL_429:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_430:.*]] = llvm.add %[[VAL_182]], %[[VAL_429]] : i32
+ // CHECK: %[[VAL_430:.*]] = llvm.xor %[[VAL_182]], %[[VAL_429]] : i32
  // CHECK: %[[VAL_431:.*]] = llvm.mlir.constant(28 : i32) : i32
- // CHECK: %[[VAL_432:.*]] = llvm.add %[[VAL_181]], %[[VAL_431]] : i32
+ // CHECK: %[[VAL_432:.*]] = llvm.xor %[[VAL_181]], %[[VAL_431]] : i32
  // CHECK: %[[VAL_433:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_434:.*]] = llvm.add %[[VAL_182]], %[[VAL_433]] : i32
+ // CHECK: %[[VAL_434:.*]] = llvm.xor %[[VAL_182]], %[[VAL_433]] : i32
  // CHECK: %[[VAL_435:.*]] = llvm.mlir.constant(29 : i32) : i32
- // CHECK: %[[VAL_436:.*]] = llvm.add %[[VAL_181]], %[[VAL_435]] : i32
+ // CHECK: %[[VAL_436:.*]] = llvm.xor %[[VAL_181]], %[[VAL_435]] : i32
  // CHECK: %[[VAL_437:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_438:.*]] = llvm.add %[[VAL_182]], %[[VAL_437]] : i32
+ // CHECK: %[[VAL_438:.*]] = llvm.xor %[[VAL_182]], %[[VAL_437]] : i32
  tt.print " x: " {hex = false, isSigned = array} : %cst : tensor<32x32xf16, #dot_operand_b>
  tt.return
}
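The hunk above rewrites the expected offset arithmetic from llvm.add chains to llvm.xor chains: the new lowering assembles the lane's base offset from and/icmp/select/xor sequences instead of udiv/urem, and then folds in the per-repetition constants with xor. The two forms agree here because every xor combines values whose set bits never overlap (the lane-derived base occupies different bit positions than the repetition constants), so no carry can occur. A minimal standalone sketch of that equivalence, illustrative only and not part of the patch:

// Illustration: for a base aligned to a power of two, adding an offset
// smaller than the alignment equals xor-ing it in, because the operands
// share no bit positions and addition produces no carries.
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t base = 0; base < 256; base += 16) // base uses bits 4..7 only
    for (uint32_t off = 0; off < 16; ++off)       // off uses bits 0..3 only
      assert(base + off == (base ^ off));
  return 0;
}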
diff --git a/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir b/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir
new file mode 100644
index 0000000000..49128064a8
--- /dev/null
+++ b/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir
@@ -0,0 +1,47 @@
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s
+
+// CHECK-LABEL: blocked_to_dot_op_shortcut_warp32
+#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [0, 1]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func @blocked_to_dot_op_shortcut_warp32(%arg0: tensor<32x32xf16, #blocked>, %arg1: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK-NOT: load
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: blocked_to_dot_op_shortcut_warp64
+#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [2, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func @blocked_to_dot_op_shortcut_warp64(%arg0: tensor<32x32xf16, #blocked>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK-NOT: load
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp32
+#blocked = #triton_gpu.blocked<{sizePerThread = [2, 32, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func @blocked_to_dot3d_op_shortcut_warp32(%arg0: tensor<8x32x32xf16, #blocked>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK-NOT: load
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp64
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32, 1], threadsPerWarp = [1, 2, 32], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func @blocked_to_dot3d_op_shortcut_warp64(%arg0: tensor<8x32x32xf16, #blocked>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK-NOT: load
+    tt.return
+  }
+}
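All four tests assert, via CHECK-NOT: load, that the blocked-to-dot-operand conversion lowers to a pure register remapping with no shared-memory round trip. A hypothetical sketch of the idea behind the shortcut (the function name and signature are illustrative, not the actual Triton API): the conversion can stay in registers when each thread's tile of the blocked layout already spans the K dimension it must feed to the dot operand.

// Hypothetical illustration only; not the real Triton predicate.
#include <cstdint>
#include <vector>

// A thread needs no data from other threads when its per-thread tile covers
// the whole K extent of the tensor, so the layout conversion is a shortcut.
bool isRegisterOnlyShortcut(const std::vector<int64_t> &sizePerThread,
                            const std::vector<int64_t> &shape, int kDim) {
  return sizePerThread[kDim] >= shape[kDim];
}

In the warp32 test above, for instance, sizePerThread = [32, 1] covers the full K extent (32) of the B operand, so no cross-thread exchange is needed.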
diff --git a/test/Triton/rewrite-tensor-pointer.mlir b/test/Triton/rewrite-tensor-pointer.mlir
index 1d659fa031..26625c3a0f 100644
--- a/test/Triton/rewrite-tensor-pointer.mlir
+++ b/test/Triton/rewrite-tensor-pointer.mlir
@@ -111,7 +111,7 @@ tt.func public @rewrite_for(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>) {
     %4 = arith.addf %arg3, %3 : tensor<128x32xf16>
     %5 = tt.advance %arg4, [%c32_i32, %c0_i32] : !tt.ptr<tensor<128x32xf16>>
     scf.yield %4, %5 : tensor<128x32xf16>, !tt.ptr<tensor<128x32xf16>>
-  }
+  } {tt.num_stages = 3 : i32}
   %2 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>>
   tt.store %2, %1#0 : tensor<128x32x!tt.ptr<f16>>
   tt.return
@@ -138,6 +138,7 @@ tt.func public @rewrite_for(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>) {
 // CHECK: %[[EXTSI3:.*]] = arith.extsi %[[C0_I32]] : i32 to i64
 // CHECK: %[[ADDI1:.*]] = arith.addi %[[ARG5]], %[[EXTSI3]] : i64
 // CHECK: scf.yield %{{.*}}, %[[ADDI0]], %[[ADDI1]] : tensor<128x32xf16>, i64, i64
+// CHECK: tt.num_stages = 3

 // -----

 tt.func public @rewrite_if(%arg0: !tt.ptr<f16>, %arg1: i1) -> tensor<128x32xf16> {
diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir
index 55e1b65faa..28c815febb 100644
--- a/test/TritonGPU/loop-pipeline-hip.mlir
+++ b/test/TritonGPU/loop-pipeline-hip.mlir
@@ -198,3 +198,38 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 } // end module
+
+// -----
+
+// CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1]
+// CHECK: #triton_gpu.shared<{{.*}} order = [2, 1, 0]
+// CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1]
+
+// CHECK-LABEL: tt.func public @slowest_dim_is_batch
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [4, 1, 16], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}>
+#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @slowest_dim_is_batch(%arg0: tensor<1x512x!tt.ptr<f32>, #blocked2>, %arg1: tensor<64x8x32x!tt.ptr<f32>, #blocked1>, %arg2: tensor<64x1x32x!tt.ptr<f32>, #blocked>) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<64x1x32xf32, #blocked>
+    %cst_0 = arith.constant dense<512> : tensor<1x512xi32, #blocked2>
+    %cst_1 = arith.constant dense<128> : tensor<64x8x32xi32, #blocked1>
+    %c1_i32 = arith.constant 1 : i32
+    %c5_i32 = arith.constant 2 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %33:3 = scf.for %arg7 = %c0_i32 to %c5_i32 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %arg0, %arg10 = %arg1) -> (tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr<f32>, #blocked2>, tensor<64x8x32x!tt.ptr<f32>, #blocked1>) : i32 {
+      %39 = tt.load %arg9 : tensor<1x512x!tt.ptr<f32>, #blocked2>
+      %40 = tt.load %arg10 : tensor<64x8x32x!tt.ptr<f32>, #blocked1>
+      %41 = tt.reshape %39 {allow_reorder = true} : tensor<1x512xf32, #blocked2> -> tensor<64x1x8xf32, #blocked5>
+      %43 = triton_gpu.convert_layout %41 : tensor<64x1x8xf32, #blocked5> -> tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
+      %44 = triton_gpu.convert_layout %40 : tensor<64x8x32xf32, #blocked1> -> tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+      %45 = tt.dot %43, %44, %arg8 : tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x1x32xf32, #blocked>
+      %46 = tt.addptr %arg9, %cst_0 : tensor<1x512x!tt.ptr<f32>, #blocked2>, tensor<1x512xi32, #blocked2>
+      %47 = tt.addptr %arg10, %cst_1 : tensor<64x8x32x!tt.ptr<f32>, #blocked1>, tensor<64x8x32xi32, #blocked1>
+      scf.yield %45, %46, %47 : tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr<f32>, #blocked2>, tensor<64x8x32x!tt.ptr<f32>, #blocked1>
+    }
+    tt.store %arg2, %33#0 : tensor<64x1x32x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
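The CHECK lines above pin the pipelined shared encoding's order to [2, 1, 0] and rule out [2, 0, 1]: with the batch dimension moved to the back of the order it becomes the slowest-varying dimension of the shared-memory buffer. The pipeliner change that produces this appears later in this diff; below is a self-contained sketch of the same permutation (assumes LLVM's ADT headers are available; the function name is illustrative):

#include <cassert>
#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/SmallVector.h>

// Mirrors the stream-pipeliner logic later in this patch: for rank-3 tensors,
// dim 0 (batch) moves to the end of the order so it varies slowest in shared
// memory; other ranks pass through unchanged.
llvm::SmallVector<unsigned> toSharedOrder(llvm::ArrayRef<unsigned> srcOrder) {
  llvm::SmallVector<unsigned> sharedOrder;
  if (srcOrder.size() == 3) {
    for (unsigned d : srcOrder)
      if (d != 0)
        sharedOrder.push_back(d);
    sharedOrder.push_back(0);
  } else {
    sharedOrder.assign(srcOrder.begin(), srcOrder.end());
  }
  return sharedOrder;
}

int main() {
  // The order [2, 0, 1] used by #blocked1 above becomes [2, 1, 0].
  assert(toSharedOrder({2, 0, 1}) == llvm::SmallVector<unsigned>({2, 1, 0}));
  return 0;
}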
diff --git a/test/TritonGPU/reduce-data-duplication.mlir b/test/TritonGPU/reduce-data-duplication.mlir
index e98a6108d7..9fca92c9b0 100644
--- a/test/TritonGPU/reduce-data-duplication.mlir
+++ b/test/TritonGPU/reduce-data-duplication.mlir
@@ -1,8 +1,8 @@
 // RUN: triton-opt %s -split-input-file -tritongpu-reduce-data-duplication | FileCheck %s

-// CHECK: #[[SHARED:.*]] = #triton_gpu.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}
-// CHECK: apply_swizzle
-// CHECK: %{{.*}} = triton_gpu.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !tt.memdesc<16x256xf16, #[[SHARED]], #triton_gpu.shared_memory>
+// CHECK: #[[$SHARED:.*]] = #triton_gpu.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}
+// CHECK-LABEL: apply_swizzle
+// CHECK: %{{.*}} = triton_gpu.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !tt.memdesc<16x256xf16, #[[$SHARED]], #triton_gpu.shared_memory>

 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}>
 #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
@@ -12,3 +12,31 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
     tt.return
   }
 }
+
+// -----
+
+// CHECK-LABEL: conversion_shortcut_blocked_dotop_warp32
+// CHECK-NOT: triton_gpu.local_alloc
+// CHECK: triton_gpu.convert_layout
+// CHECK-NOT: triton_gpu.local_alloc
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 64], threadsPerWarp = [16, 2], warpsPerCTA = [2, 2], order = [0, 1]}>
+module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func @conversion_shortcut_blocked_dotop_warp32(%arg0: tensor<64x64xf16, #blocked>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: conversion_shortcut_blocked_dotop_warp64
+// CHECK-NOT: triton_gpu.local_alloc
+// CHECK: triton_gpu.convert_layout
+// CHECK-NOT: triton_gpu.local_alloc
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [0, 1]}>
+module attributes {"triton_gpu.target" = "hip:gfx940", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func @conversion_shortcut_blocked_dotop_warp64(%arg0: tensor<64x64xf16, #blocked>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
+    tt.return
+  }
+}
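A brief note on the test pattern used above: the CHECK-NOT / CHECK / CHECK-NOT sandwich asserts that the pass leaves the convert_layout in place and emits no triton_gpu.local_alloc either before or after it. In other words, when the conversion qualifies for the register shortcut, the duplication-reducing shared-memory staging is skipped entirely.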
diff --git a/test/TritonIntelGPU/schedule-load.mlir b/test/TritonIntelGPU/schedule-load.mlir
index c984cfabd5..6352b6d033 100644
--- a/test/TritonIntelGPU/schedule-load.mlir
+++ b/test/TritonIntelGPU/schedule-load.mlir
@@ -303,3 +303,27 @@ module attributes {"triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-wa
     tt.return
   }
 }
+
+// -----
+
+tt.func public @test(%arg0: !tt.ptr<tensor<16x16xf16>>, %arg1: !tt.ptr<tensor<8x32xf16>>) {
+  %lb = arith.constant 0 : i32
+  %ub = tt.get_program_id x : i32
+  %st = arith.constant 32 : i32
+  %zero = arith.constant dense<0.000000e+00> : tensor<8x16xf32>
+  %common = tt.load %arg1 {DotIdx = 0 : i32} : !tt.ptr<tensor<8x32xf16>>
+  // COM: Check that %common is not moved into the loops.
+  // CHECK: tt.load %arg1
+  // CHECK-COUNT-2: scf.for
+  scf.for %iv0 = %lb to %ub step %st : i32 {
+    %load1 = tt.load %arg0 {DotIdx = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
+    %extract1 = triton_intel_gpu.extract %common[0] : tensor<8x32xf16> -> tensor<8x16xf16>
+    %dot1 = tt.dot %extract1, %load1, %zero, inputPrecision = tf32 {"schedule-group" = 0 : i32} : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32>
+  }
+  scf.for %iv1 = %lb to %ub step %st : i32 {
+    %load2 = tt.load %arg0 {DotIdx = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
+    %extract2 = triton_intel_gpu.extract %common[0] : tensor<8x32xf16> -> tensor<8x16xf16>
+    %dot2 = tt.dot %extract2, %load2, %zero, inputPrecision = tf32 {"schedule-group" = 0 : i32} : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32>
+  }
+  tt.return
+}
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
index 2c93d6f0ee..9368443255 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
@@ -162,6 +162,8 @@ std::string getTypeStr(Type ty) {
     scalarName = "bf16";
   } else if (ty.isInteger(32)) {
     scalarName = "i32";
+  } else if (ty.isInteger(16)) {
+    scalarName = "i16";
   } else if (ty.isInteger(8)) {
     scalarName = "iu8";
   } else if (ty.isInteger(4)) {
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
index 224df90283..784ce52e1b 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
@@ -403,9 +403,24 @@ void LoopPipeliner::createBufferTypes() {
   // unsigned bitWidth = dotOpEnc.getMMAv2kWidth()
   //                         ? 32 / dotOpEnc.getMMAv2kWidth()
   //                         : ty.getElementType().getIntOrFloatBitWidth();
-  auto sharedEnc = ttg::SharedEncodingAttr::get(
-      ty.getContext(), dotOpEnc, ty.getShape(),
-      ttg::getOrder(ty.getEncoding()), CTALayout, eType);
+  auto srcOrder = ttg::getOrder(ty.getEncoding());
+  SmallVector<unsigned> sharedOrder;
+  int rank = srcOrder.size();
+  // TODO: rework this when shared -> dotOp conversions support arbitrary
+  // shared memory ordering.
+  if (rank == 3) {
+    // Move the batch dimension (dim #0) to be the last so that it will be the
+    // slowest varying dimension.
+    for (unsigned i = 0; i < rank; ++i)
+      if (srcOrder[i] != 0)
+        sharedOrder.emplace_back(srcOrder[i]);
+    sharedOrder.emplace_back(0);
+  } else {
+    sharedOrder = srcOrder;
+  }
+  auto sharedEnc =
+      ttg::SharedEncodingAttr::get(ty.getContext(), dotOpEnc, ty.getShape(),
+                                   sharedOrder, CTALayout, eType);
   loadsBufferType[loadOp] = triton::MemDescType::get(
       bufferShape, eType, sharedEnc,
       triton::gpu::SharedMemorySpaceAttr::get(ty.getContext()),
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp
index f1d04b7270..027f06652f 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp
@@ -207,8 +207,22 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
     auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
     auto order = ttg::getOrder(srcTy.getEncoding());
     unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
+    SmallVector<unsigned> sharedOrder;
+    int rank = order.size();
+    // TODO: rework this when shared -> dotOp conversions support arbitrary
+    // shared memory ordering.
+    if (rank == 3) {
+      // Move the batch dimension (dim #0) to be the last so that it will be
+      // the slowest varying dimension.
+      for (unsigned i = 0; i < rank; ++i)
+        if (order[i] != 0)
+          sharedOrder.emplace_back(order[i]);
+      sharedOrder.emplace_back(0);
+    } else {
+      sharedOrder = order;
+    }
     tempAttr = ttg::SharedEncodingAttr::get(
-        val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
+        val.getContext(), dotOpEnc, srcTy.getShape(), sharedOrder, CTALayout,
         bitWidth, /*needTrans=*/false);
   }
   // Check that the shared encodings needed by the users are compatible.
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h b/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h
index 2758e6341a..249849520f 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h
@@ -18,6 +18,10 @@ namespace mlir::triton::gpu {
 LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
                                 unsigned opIdx = 2);

+std::optional<LinearLayout>
+dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout,
+                             ArrayRef<int64_t> shape);
+
 } // namespace mlir::triton::gpu

 #endif // TRITON_DIALECT_TRITONINTELGPU_IR_LINEARLAYOUTCONVERSIONS_H
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
index d056fb2293..90e950bd0c 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
@@ -582,4 +582,13 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
                        CTALayoutAttr::getDefault(ctx, rank), shape);
 }

+std::optional<LinearLayout>
+dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout,
+                             ArrayRef<int64_t> shape) {
+  auto dpasLayout = cast<DpasEncodingAttr>(dotDpasLayout.getParent());
+  if (dotDpasLayout.getOpIdx() == 0)
+    return std::nullopt;
+  return DPAStoLinearLayout(shape, dpasLayout, dotDpasLayout.getOpIdx());
+}
+
 } // namespace mlir::triton::gpu
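The new helper only produces a layout for the B operand; for opIdx == 0 it returns std::nullopt, so callers must be prepared to fall back. A hypothetical call site, where only the helper's signature comes from the header above and the wrapper name and tensor variable are illustrative:

// Hypothetical usage sketch; assumes the MLIR/Triton headers used by the
// patch are available in the same translation unit.
std::optional<LinearLayout>
toLinearLayoutIfDotOperand(RankedTensorType tensorTy) {
  auto dotEnc =
      dyn_cast_or_null<DotOperandEncodingAttr>(tensorTy.getEncoding());
  if (!dotEnc)
    return std::nullopt;
  // The helper itself returns std::nullopt for opIdx == 0 (the A operand).
  return dotOperandDpasToLinearLayout(dotEnc, tensorTy.getShape());
}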
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/ScheduleLoad.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/ScheduleLoad.cpp
index 07c9e06109..41e975f9f9 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/ScheduleLoad.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/ScheduleLoad.cpp
@@ -71,8 +71,15 @@ class ScheduleLoadPass
     for (SmallVector &dots : dotsGroup) {
       SmallVector<Value> notVisited = getNotVisitedUses(dots);
       for (Value val : notVisited) {
-        if (Operation *op = val.getDefiningOp())
+        if (Operation *op = val.getDefiningOp()) {
+          // Do not move an op that is also used by ops in another region.
+          Region *rgn = dots.begin()->getOperation()->getParentRegion();
+          if (any_of(val.getUsers(), [&](Operation *user) {
+                return user->getParentRegion() != rgn;
+              }))
+            continue;
           op->moveBefore(dots.begin()->getOperation());
+        }
       }
     }
   });
diff --git a/third_party/proton/proton/viewer.py b/third_party/proton/proton/viewer.py
index 9fe0e7e67d..ae689589c6 100644
--- a/third_party/proton/proton/viewer.py
+++ b/third_party/proton/proton/viewer.py
@@ -105,6 +105,7 @@ def get_min_time_bytes(df, device_info):
 def derive_metrics(gf, metrics, raw_metrics, device_info):
     derived_metrics = []
     original_metrics = []
+    exclusive_metrics = ["util"] + list(derivable_metrics.keys()) + list(avg_time_factor_dict.factor.keys())
     internal_frame_indices = gf.dataframe["device_id"].isna()

     def get_time_seconds(df):
@@ -133,6 +134,7 @@ def get_time_seconds(df):
             gf.dataframe[f"{metric} (inc)"] = (get_time_seconds(gf.dataframe) /
                                                time_factor_dict.factor[metric_time_unit])
             derived_metrics.append(f"{metric} (inc)")
+            metric_name = match_available_metrics([time_factor_dict.name], raw_metrics)[0]
         elif metric in avg_time_factor_dict.factor:
             metric_time_unit = avg_time_factor_dict.name + "/" + metric.split("/")[1]
             gf.dataframe[f"{metric} (inc)"] = (get_time_seconds(gf.dataframe) / gf.dataframe['count'] /
@@ -141,7 +143,12 @@ def get_time_seconds(df):
             derived_metrics.append(f"{metric} (inc)")
         else:
             original_metrics.append(metric)
-
+        if metric not in exclusive_metrics:
+            single_frame = gf.dataframe[metric_name]
+            total = gf.dataframe[metric_name].iloc[0]
+            metric = metric.split("/")[0]
+            gf.dataframe[f"{metric}/% (inc)"] = (single_frame / total) * 100.0
+            derived_metrics.append(f"{metric}/% (inc)")
     if original_metrics:
         original_metrics = match_available_metrics(original_metrics, raw_metrics)
     return derived_metrics + original_metrics
@@ -227,6 +234,10 @@ def main():
 - flop[<8/16/32/64>]/s, gflop[<8/16/32/64>]/s, tflop[<8/16/32/64>]/s: flops / time
 - byte/s, gbyte/s, tbyte/s: bytes / time
 - util: max(sum(flops) / peak_flops_time, sum(bytes) / peak_bandwidth_time)
+
+For inclusive metrics (e.g. time), an additional column is printed showing the
+percentage each frame contributes to the total.
+
 """,
     )
     argparser.add_argument(
diff --git a/third_party/proton/test/test_viewer.py b/third_party/proton/test/test_viewer.py
index 998825bbc8..b2d4d39f9b 100644
--- a/third_party/proton/test/test_viewer.py
+++ b/third_party/proton/test/test_viewer.py
@@ -118,8 +118,11 @@ def test_util():
 def test_time_derivation():
     derivation_metrics_test(
         metrics=["time/s", "time/ms", "time/us", "time/ns"], expected_data={
-            'time/s (inc)': [0.0004096, 0.0002048, 0.0002048], 'time/ms (inc)': [0.4096, 0.2048, 0.2048],
-            'time/us (inc)': [409.6, 204.8, 204.8], 'time/ns (inc)': [409600.0, 204800.0, 204800.0]
+            'time/s (inc)': [0.0004096, 0.0002048, 0.0002048],
+            'time/ms (inc)': [0.4096, 0.2048, 0.2048],
+            'time/us (inc)': [409.6, 204.8, 204.8],
+            'time/ns (inc)': [409600.0, 204800.0, 204800.0],
+            'time/% (inc)': [100.0, 50.0, 50.0],
         }, sample_file=cuda_example_file)
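Worked check of the new expected row: in the sample profile the root frame's inclusive time (409.6 us) is the total, since the percentage is computed against the first row of the dataframe, and each of the two child kernels takes 204.8 us. The derived time/% (inc) column is therefore 100.0 for the root and 204.8 / 409.6 = 50.0 for each child, exactly the [100.0, 50.0, 50.0] asserted above.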