From c6075cd74e3975fb1a2e2c168d0df54b7cb37d02 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Mon, 7 Oct 2024 13:39:37 +0000 Subject: [PATCH] fixed codestyle --- .../gemm_benchmark.py | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py index 6d74a8e96..6bb87b871 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py @@ -18,8 +18,8 @@ if benchmark_suit.USE_IPEX_OPTION: import intel_extension_for_pytorch # type: ignore # noqa: F401 -TRANSPOSE_A = os.getenv("TRANSPOSE_A", "0") == "1" -TRANSPOSE_B = os.getenv("TRANSPOSE_B", "0") == "1" +TRANSPOSE_A = os.getenv('TRANSPOSE_A', '0') == '1' +TRANSPOSE_B = os.getenv('TRANSPOSE_B', '0') == '1' use_xetla = not (TRANSPOSE_A or TRANSPOSE_B) @@ -203,6 +203,21 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False): return c +def get_shapes(B, M, N, K, transpose_a, transpose_b): + a_shape = (M, K) + if transpose_a: + a_shape = (K, M) + + b_shape = (K, N) + if transpose_b: + b_shape = (N, K) + + if B != 1: + a_shape = (B, *a_shape) + b_shape = (B, *b_shape) + return a_shape, b_shape + + # Benchmark Performance @benchmark_suit.perf_report( benchmark_suit.Benchmark( @@ -246,17 +261,7 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False): args={}, )) def benchmark(B, M, N, K, provider): - a_shape = (M, K) - if TRANSPOSE_A: - a_shape = (K, M) - - b_shape = (K, N) - if TRANSPOSE_B: - b_shape = (N, K) - - if B != 1: - a_shape = (B, *a_shape) - b_shape = (B, *b_shape) + a_shape, b_shape = get_shapes(B, M, N, K, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B) a = torch.rand(a_shape, device='xpu', dtype=torch.bfloat16) b = torch.rand(b_shape, device='xpu', dtype=torch.bfloat16)