From 0e05cae1de410600fe68a3babd8c95e5ac3e5524 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk
Date: Fri, 1 Nov 2024 12:48:07 +0100
Subject: [PATCH] Matmul tutorial - cache padding (#14)

Adds extra optional padding that can be used to ensure that the input
matrices' strides are non-power-of-two, which improves cache behavior.

Currently, it is most useful with DYNAMIC_K_BLOCK enabled.
---
 python/tutorials/03-matrix-multiplication-cpu.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py
index 00706f5d3dd9..22e9f3932a07 100644
--- a/python/tutorials/03-matrix-multiplication-cpu.py
+++ b/python/tutorials/03-matrix-multiplication-cpu.py
@@ -165,6 +165,7 @@
 DATA_TYPE = torch.float32
 K_DIM_PADDING = False
 DYNAMIC_K_BLOCK = False
+CACHE_PADDING = False
 
 @triton.jit
 def matmul_kernel(
@@ -322,6 +323,13 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
         b = torch.nn.functional.pad(b, (0, 0, 0, padding_size), mode='constant', value=0)
         K = a.shape[1]
 
+    # TODO: Check if padding is needed at all.
+    # Currently, cache padding is most useful together with dynamic K blocking
+    # to ensure that stride is non-power-of-two to improve cache behavior.
+    if CACHE_PADDING:
+        a = torch.nn.functional.pad(a, (0, 32, 0, 0), mode='constant', value=0)
+        b = torch.nn.functional.pad(b, (0, 32, 0, 0), mode='constant', value=0)
+
     #TODO: Currently masked load is not supported yet.
     assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (
         K % k_block == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
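
Illustrative note (not part of the patch): a minimal sketch of what the CACHE_PADDING branch does to a tensor's layout. The 1024x1024 input size is a hypothetical example, not taken from the tutorial; the point is that padding the last dimension by 32 columns turns a power-of-two row stride into a non-power-of-two one, which tends to reduce cache-set aliasing between rows.

    import torch

    # Hypothetical input: a 1024x1024 float32 matrix has a row stride of
    # 1024 elements, a power of two that maps many rows onto the same
    # cache sets.
    a = torch.randn(1024, 1024, dtype=torch.float32)
    print(a.stride())         # (1024, 1)

    # Padding the last dimension by 32 columns, as the CACHE_PADDING
    # branch does, yields a 1024x1056 tensor whose row stride is no
    # longer a power of two.
    a_padded = torch.nn.functional.pad(a, (0, 32, 0, 0), mode='constant', value=0)
    print(a_padded.stride())  # (1056, 1)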