Skip to content

Commit

Permalink
fixed codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
Egor-Krivov committed Oct 7, 2024
1 parent 4068377 commit c6075cd
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401

TRANSPOSE_A = os.getenv("TRANSPOSE_A", "0") == "1"
TRANSPOSE_B = os.getenv("TRANSPOSE_B", "0") == "1"
TRANSPOSE_A = os.getenv('TRANSPOSE_A', '0') == '1'
TRANSPOSE_B = os.getenv('TRANSPOSE_B', '0') == '1'
use_xetla = not (TRANSPOSE_A or TRANSPOSE_B)


Expand Down Expand Up @@ -203,6 +203,21 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False):
return c


def get_shapes(B, M, N, K, transpose_a, transpose_b):
a_shape = (M, K)
if transpose_a:
a_shape = (K, M)

b_shape = (K, N)
if transpose_b:
b_shape = (N, K)

if B != 1:
a_shape = (B, *a_shape)
b_shape = (B, *b_shape)
return a_shape, b_shape


# Benchmark Performance
@benchmark_suit.perf_report(
benchmark_suit.Benchmark(
Expand Down Expand Up @@ -246,17 +261,7 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False):
args={},
))
def benchmark(B, M, N, K, provider):
a_shape = (M, K)
if TRANSPOSE_A:
a_shape = (K, M)

b_shape = (K, N)
if TRANSPOSE_B:
b_shape = (N, K)

if B != 1:
a_shape = (B, *a_shape)
b_shape = (B, *b_shape)
a_shape, b_shape = get_shapes(B, M, N, K, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)

a = torch.rand(a_shape, device='xpu', dtype=torch.bfloat16)
b = torch.rand(b_shape, device='xpu', dtype=torch.bfloat16)
Expand Down

0 comments on commit c6075cd

Please sign in to comment.