Add A^t@B benchmark (#2430)
Based on this feedback: #2408 (review)

Changed the GEMM benchmark to include transposed-matrix cases.

Closes #2424
Relates to #1795

The A@B^t case is important because the weight matrix is often stored in [M, K] format, as in torch.nn.Linear (https://pytorch.org/docs/stable/generated/torch.nn.Linear.html). Right now we are about 1.5x slower on XPU than raw torch for that case.
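
As a reminder of why this layout shows up (a minimal sketch, not from this PR; shapes are illustrative): torch.nn.Linear stores its weight as [out_features, in_features], so its forward pass is effectively A@B^t.

import torch

linear = torch.nn.Linear(in_features=512, out_features=256, bias=False)
x = torch.rand(64, 512)   # activations A: [M, K]
w = linear.weight         # stored as [out_features, in_features] = [256, 512]
# the forward pass multiplies by the transpose of the stored weight, i.e. A @ B^t
assert torch.allclose(linear(x), x @ w.t())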


The A^t@B case is important because it is part of matmul backpropagation. Right now we are about 4x slower on XPU than raw torch for that case.
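
Where A^t@B appears in the backward pass (a minimal sketch, not from this PR; shapes are illustrative): for C = A@B, autograd computes dB = A^t@dC (and dA = dC@B^t).

import torch

a = torch.rand(64, 128, requires_grad=True)   # A: [M, K]
b = torch.rand(128, 256, requires_grad=True)  # B: [K, N]
c = a @ b
grad_c = torch.ones_like(c)                   # upstream gradient dC: [M, N]
c.backward(grad_c)
# dB = A^t @ dC -> the A^t@B pattern benchmarked here
assert torch.allclose(b.grad, a.t() @ grad_c)
# dA = dC @ B^t -> the A@B^t pattern
assert torch.allclose(a.grad, grad_c @ b.t())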
Egor-Krivov authored Oct 8, 2024
1 parent 2202ca7 commit b12d0dd
Showing 3 changed files with 67 additions and 320 deletions.
13 changes: 12 additions & 1 deletion .github/workflows/triton-benchmarks.yml
@@ -163,13 +163,24 @@ jobs:
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
run: |
cd benchmarks/triton_kernels_benchmark
python gemm_bt_benchmark.py --reports $REPORTS
TRANSPOSE_B=1 python gemm_benchmark.py --reports $REPORTS
mv $REPORTS/matmul-performance.csv $REPORTS/matmul-performance-bt.csv
source ../../scripts/capture-hw-details.sh
TAG=${{ inputs.tag || 'ci' }}
python ../../scripts/build_report.py $REPORTS/matmul-performance-bt.csv $REPORTS/gemm-bt-triton-report.csv --benchmark gemm-bt --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
- name: Run Triton GEMM (A^t@B) kernel benchmark
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
run: |
cd benchmarks/triton_kernels_benchmark
TRANSPOSE_A=1 python gemm_benchmark.py --reports $REPORTS
mv $REPORTS/matmul-performance.csv $REPORTS/matmul-performance-at.csv
source ../../scripts/capture-hw-details.sh
TAG=${{ inputs.tag || 'ci' }}
python ../../scripts/build_report.py $REPORTS/matmul-performance-at.csv $REPORTS/gemm-at-triton-report.csv --benchmark gemm-at --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
- name: Run Triton GEMM (stream-k) kernel benchmark
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
run: |
82 changes: 55 additions & 27 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -6,6 +6,7 @@
To compare the performance to XeTLA kernel.
"""
import os

import torch
import triton
@@ -17,6 +18,10 @@
if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401

TRANSPOSE_A = os.getenv('TRANSPOSE_A', '0') == '1'
TRANSPOSE_B = os.getenv('TRANSPOSE_B', '0') == '1'
use_xetla = not (TRANSPOSE_A or TRANSPOSE_B)


@triton.autotune(
configs=[
@@ -158,15 +163,22 @@ def matmul_kernel_with_block_pointers_batched(

# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) launches the above kernel.
def matmul(a, b, c):
def matmul(a, b, c, transpose_a=False, transpose_b=False):
a_major, a_minor = -2, -1
if transpose_a:
a_major, a_minor = a_minor, a_major
b_minor, b_major = -2, -1
if transpose_b:
b_major, b_minor = b_minor, b_major

assert a.shape[a_minor] == b.shape[b_minor], 'Incompatible dimensions'
assert a.is_contiguous(), 'Matrix A must be contiguous'
assert b.is_contiguous(), 'Matrix B must be contiguous'
M, N, K = a.shape[a_major], b.shape[b_major], a.shape[a_minor]
# Check constraints.
if len(a.shape) == 3 and len(b.shape) == 3:
assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension'
assert a.shape[2] == b.shape[1], 'Incompatible dimensions'
assert a.is_contiguous(), 'Matrix A must be contiguous'
assert b.is_contiguous(), 'Matrix B must be contiguous'
B, M, K = a.shape
B, K, N = b.shape
B = a.shape[0]
# 1D launch kernel where each block gets its own program.
grid = lambda META: (
B,
@@ -175,27 +187,37 @@ def matmul(a, b, c):
matmul_kernel_with_block_pointers_batched[grid](
a, b, c, #
B, M, N, K, #
a.stride(0), a.stride(1), a.stride(2), #
b.stride(0), b.stride(1), b.stride(2), #
a.stride(0), a.stride(a_major), a.stride(a_minor), #
b.stride(0), b.stride(b_minor), b.stride(b_major), #
c.stride(0), c.stride(1), c.stride(2))
elif len(a.shape) == 2 and len(b.shape) == 2:
assert a.shape[1] == b.shape[0], 'Incompatible dimensions'
assert a.is_contiguous(), 'Matrix A must be contiguous'
assert b.is_contiguous(), 'Matrix B must be contiguous'
M, K = a.shape
K, N = b.shape
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
matmul_kernel_with_block_pointers[grid](
a, b, c, #
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
a.stride(a_major), a.stride(a_minor), #
b.stride(b_minor), b.stride(b_major), #
c.stride(0), c.stride(1))
else:
assert False, 'Input matrices dimensions mismatch'
return c


def get_shapes(B, M, N, K, transpose_a, transpose_b):
a_shape = (M, K)
if transpose_a:
a_shape = (K, M)

b_shape = (K, N)
if transpose_b:
b_shape = (N, K)

if B != 1:
a_shape = (B, *a_shape)
b_shape = (B, *b_shape)
return a_shape, b_shape


# Benchmark Performance
@benchmark_suit.perf_report(
benchmark_suit.Benchmark(
@@ -228,9 +250,9 @@ def matmul(a, b, c):
line_arg='provider',
# argument name whose value corresponds to a different line in the plot
# possible values for `line_arg`
line_vals=['triton', 'xetla'],
line_vals=['triton'] + (['xetla'] if use_xetla else []),
# label name for the lines
line_names=['Triton', 'XeTLA'],
line_names=['Triton'] + (['XeTLA'] if use_xetla else []),
# line styles
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
@@ -239,27 +261,33 @@ def matmul(a, b, c):
args={},
))
def benchmark(B, M, N, K, provider):
if B == 1:
a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16)
else:
a = torch.rand((B, M, K), device='xpu', dtype=torch.bfloat16)
b = torch.rand((B, K, N), device='xpu', dtype=torch.bfloat16)
a_shape, b_shape = get_shapes(B, M, N, K, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)

a = torch.rand(a_shape, device='xpu', dtype=torch.bfloat16)
b = torch.rand(b_shape, device='xpu', dtype=torch.bfloat16)

quantiles = [0.5, 0.0, 1.0]

torch_a = a
if TRANSPOSE_A:
torch_a = torch.transpose(torch_a, -2, -1)

torch_b = b
if TRANSPOSE_B:
torch_b = torch.transpose(torch_b, -2, -1)

if provider == 'onednn':
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
quantiles=quantiles)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(torch_a, torch_b), warmup=10,
rep=10, quantiles=quantiles)
elif provider == 'triton':
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
triton_fn = lambda: matmul(a, b, c, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
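
The wrapper above handles transposition by swapping the strides passed to the kernel rather than copying data. A minimal PyTorch sketch of the same idea (illustrative only, not from this PR):

import torch

a = torch.rand(64, 32)   # stored as [K, M], i.e. the TRANSPOSE_A=1 layout
b = torch.rand(64, 48)   # [K, N]
# reinterpret a as a logical [M, K] matrix by swapping its strides; no data is moved
at_view = a.as_strided((32, 64), (a.stride(1), a.stride(0)))
assert torch.allclose(at_view @ b, a.t() @ b)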
