diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py new file mode 100644 index 000000000000..0cc90a474052 --- /dev/null +++ b/python/tutorials/03-matrix-multiplication-cpu.py @@ -0,0 +1,394 @@ +""" +Matrix Multiplication +===================== +In this tutorial, you will write a very short high-performance FP32 matrix multiplication kernel. + +You will specifically learn about: + +* Block-level matrix multiplications. + +* Multi-dimensional pointer arithmetic. + +* Program re-ordering for improved L2 cache hit rate. + +* Automatic performance tuning. + +""" + +# %% +# Motivations +# ----------- +# +# Matrix multiplications are a key building block of most modern high-performance computing systems. +# They are notoriously hard to optimize, hence their implementation is generally done by +# hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS). +# Unfortunately, these libraries are often proprietary and cannot be easily customized +# to accommodate the needs of modern deep learning workloads (e.g., fused activation functions). +# In this tutorial, you will learn how to implement efficient matrix multiplications by +# yourself with Triton, in a way that is easy to customize and extend. +# +# Roughly speaking, the kernel that we will write will implement the following blocked +# algorithm to multiply a (M, K) by a (K, N) matrix: +# +# .. code-block:: python +# +# # Do in parallel +# for m in range(0, M, BLOCK_SIZE_M): +# # Do in parallel +# for n in range(0, N, BLOCK_SIZE_N): +# acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32) +# for k in range(0, K, BLOCK_SIZE_K): +# a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] +# b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] +# acc += dot(a, b) +# C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc +# +# where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance. + +# %% +# Compute Kernel +# -------------- +# +# The above algorithm is, actually, fairly straightforward to implement in Triton. +# The main difficulty comes from the computation of the memory locations at which blocks +# of :code:`A` and :code:`B` must be read in the inner loop. For that, we need +# multi-dimensional pointer arithmetic. +# +# Pointer Arithmetic +# ~~~~~~~~~~~~~~~~~~~ +# +# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given +# by :code:`&X[i, j] = X + i*stride_xi + j*stride_xj`. +# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and +# :code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as: +# +# .. code-block:: python +# +# &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1); +# &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1); +# +# Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as the following +# code. Also note that we need an extra modulo to handle the case where :code:`M` is not a multiple of +# :code:`BLOCK_SIZE_M` or :code:`N` is not a multiple of :code:`BLOCK_SIZE_N`, in which case we can pad the data with +# some useless values, which will not contribute to the results. For the :code:`K` dimension, we will handle that later +# using masking load semantics. +# +# .. code-block:: python +# +# offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M +# offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N +# offs_k = tl.arange(0, BLOCK_SIZE_K) +# a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak) +# b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn) +# +# And then updated in the inner loop as follows: +# +# .. code-block:: python +# +# a_ptrs += BLOCK_SIZE_K * stride_ak; +# b_ptrs += BLOCK_SIZE_K * stride_bk; +# +# +# L2 Cache Optimizations +# ~~~~~~~~~~~~~~~~~~~~~~ +# +# As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]` +# block of :code:`C`. +# It is important to remember that the order in which these blocks are computed does +# matter, since it affects the L2 cache hit rate of our program, and unfortunately, a +# simple row-major ordering +# +# .. code-block:: Python +# +# pid = triton.program_id(0); +# grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M; +# grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N; +# pid_m = pid / grid_n; +# pid_n = pid % grid_n; +# +# is just not going to cut it. +# +# One possible solution is to launch blocks in an order that promotes data reuse. +# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before +# switching to the next column: +# +# .. code-block:: python +# +# # Program ID +# pid = tl.program_id(axis=0) +# # Number of program ids along the M axis +# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) +# # Number of programs ids along the N axis +# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) +# # Number of programs in group +# num_pid_in_group = GROUP_SIZE_M * num_pid_n +# # Id of the group this program is in +# group_id = pid // num_pid_in_group +# # Row-id of the first program in the group +# first_pid_m = group_id * GROUP_SIZE_M +# # If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller +# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) +# # *Within groups*, programs are ordered in a column-major order +# # Row-id of the program in the *launch grid* +# pid_m = first_pid_m + (pid % group_size_m) +# # Col-id of the program in the *launch grid* +# pid_n = (pid % num_pid_in_group) // group_size_m +# +# For example, in the following matmul where each matrix is 9 blocks by 9 blocks, +# we can see that if we compute the output in row-major ordering, we need to load 90 +# blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped +# ordering, we only need to load 54 blocks. +# +# .. image:: grouped_vs_row_major_ordering.png +# +# In practice, this can improve the performance of our matrix multiplication kernel by +# more than 10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100). +# + +# %% +# Final Result +# ------------ + +import torch + +import triton +import triton.language as tl + + +BLOCK_SIZE_M = 32 +BLOCK_SIZE_N = 32 +BLOCK_SIZE_K = 32 +GROUP_SIZE_M = 8 +USE_GPU = True + +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to matrix C's type after the loop, if C has lower precision type (for example, float16 and bfloat16). + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + + #TODO: Currently masked load is not supported yet. + #a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + #b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # Convert the accumulator to the output matrix C's type if needed. + c = accumulator + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + + #TODO: Currently masked load is not supported yet. + #c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + #tl.store(c_ptrs, c, mask=c_mask) + tl.store(c_ptrs, c) + + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul(a, b): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + # 1D launch kernel where each block gets its own program. + grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), ) + matmul_kernel[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, # + GROUP_SIZE_M=GROUP_SIZE_M, # + ) + return c + + +# %% +# Unit Test +# --------- +# +# We can test our custom matrix multiplication operation against a native torch implementation. + +torch.manual_seed(0) + +triton.runtime.driver.set_active_to_cpu() + + +a = torch.randn((512, 512), device='cpu', dtype=torch.float32) +b = torch.randn((512, 512), device='cpu', dtype=torch.float32) +triton_output = matmul(a, b) +torch_output = torch.matmul(a, b) +print(f"triton_cpu_output_with_{a.dtype}_inputs={triton_output}") +print(f"torch_cpu_output_with_{a.dtype}_inputs={torch_output}") +rtol = 0 +if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + print("✅ TritonCPU and TorchCPU match") +else: + print("❌ TritonCPU and TorchCPU differ, the maximum difference is "f'{torch.max(torch.abs(triton_output - torch_output))}') + +# %% +# Benchmark +# --------- +# +# Square Matrix Performance +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We can now compare the performance of our kernel against that of Pytorch. Here we focus on square matrices, +# but feel free to arrange this script as you wish to benchmark any other matrix shape. + +LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu'] +LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU'] +LINE_STYLES = [('blue', '-'), ('green', '-'), ('cyan', '-')] + +if USE_GPU and triton.runtime.driver.get_active_gpus(): + triton.runtime.driver.set_active_to_gpu() + a = a.to('cuda') + b = b.to('cuda') + triton_output = matmul(a, b) + torch_output = torch.matmul(a, b) + print(f"triton_gpu_output_with_{a.dtype}_inputs={triton_output}") + print(f"torch_gpu_output_with_{a.dtype}_inputs={torch_output}") + rtol = 0 + if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + print("✅ TritonGPU and TorchGPU match") + else: + print("❌ TritonGPU and TorchGPU differ, the maximum difference is "f'{torch.max(torch.abs(triton_output - torch_output))}') + + LINE_VALS += ['triton-gpu', 'torch-gpu'] + LINE_NAMES += ['TritonGPU', 'TorchGPU'] + LINE_STYLES += [('yellow', '-'), ('red', '-')] + + +# %% +# Seems like we're good to go! + +# %% +# Benchmark +# --------- +# +# We can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch. +# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom ops. +# for different problem sizes. + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot + x_vals=[128 * i for i in range(2, 21)], # Different possible values for `x_name` + line_arg='provider', # Argument name whose value corresponds to a different line in the plot. + line_vals=LINE_VALS, # Possible values for `line_arg`. + line_names=LINE_NAMES, # Label name for the lines. + styles=LINE_STYLES, # Line styles. + ylabel='GFLOPS', # Label name for the y-axis. + plot_name= + # Name for the plot. Used also as a file name for saving the plot. + f'matmul-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})', + args={}, # Values for function arguments not in `x_names` and `y_name`. + )) + +def benchmark(M, N, K, provider): + import os + + device = 'cpu' if 'cpu' in provider else 'cuda' + a = torch.randn((M, K), device=device, dtype=torch.float32) + b = torch.randn((K, N), device=device, dtype=torch.float32) + + if device == 'cpu': + triton.runtime.driver.set_active_to_cpu() + if 'single' in provider: + os.environ['TRITON_CPU_SINGLE_CORE'] = '1' + else: + os.unsetenv('TRITON_CPU_SINGLE_CORE') + else: + triton.runtime.driver.set_active_to_gpu() + + quantiles = [0.5, 0.2, 0.8] + if provider == 'torch-gpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) + elif provider == 'triton-gpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) + elif provider == 'torch-cpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) + elif provider == 'triton-cpu-single': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) + elif provider == 'triton-cpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) + perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +# %% +# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or +# `save_path='/path/to/results/' to save them to disk along with raw CSV data: +benchmark.run(print_data=True, show_plots=True)