Skip to content

Commit

Permalink
[GPU] Updates mfma/wmma attribute names in the matmul test generator. (
Browse files Browse the repository at this point in the history
…iree-org#18134)

This is the follop-up for
iree-org@82012e6


iree-org@ef28786
only fixes the mlir files but not python test generator. The revision
includes the case.

Signed-off-by: hanhanW <[email protected]>
  • Loading branch information
hanhanW authored Aug 7, 2024
1 parent ef28786 commit ca24b96
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions tests/e2e/matmul/generate_e2e_matmul_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,29 +266,29 @@ def get_rocm_test_compilation_infos(
MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 2, 1),
MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 2, 1, 1),
MMASchedule("MFMA_F32_16x16x4_F32", 2, 2, 1, 1, 2),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 1),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 2),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 2, 1),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 2, 1, 1),
MMASchedule("MFMA_F16_16x16x16_F32", 2, 2, 1, 1, 1),
MMASchedule("MFMA_F16_16x16x16_F32", 2, 4, 2, 1, 2),
MMASchedule("MFMA_F16_16x16x16_F32", 4, 2, 4, 2, 2),
MMASchedule("MFMA_F16_32x32x8_F32", 1, 1, 1, 2, 2),
MMASchedule("MFMA_F16_32x32x8_F32", 2, 2, 1, 1, 1),
MMASchedule("MFMA_F16_32x32x8_F32", 1, 4, 2, 1, 2),
MMASchedule("MFMA_F16_32x32x8_F32", 4, 2, 1, 2, 4),
MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 1, 1, 1, 1, 1),
MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 2, 2, 1, 1, 2),
MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 4, 1, 4, 1, 1),
MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 4, 2, 4, 2, 1),
MMASchedule("MFMA_I8_16x16x32_I32", 1, 1, 1, 1, 1),
MMASchedule("MFMA_I8_16x16x32_I32", 2, 2, 1, 1, 2),
MMASchedule("MFMA_I8_16x16x32_I32", 4, 1, 4, 1, 1),
MMASchedule("MFMA_I8_16x16x32_I32", 4, 2, 4, 2, 1),
MMASchedule("MFMA_I8_32x32x16_I32", 1, 1, 1, 1, 1),
MMASchedule("MFMA_I8_32x32x16_I32", 2, 2, 1, 1, 2),
MMASchedule("MFMA_I8_32x32x16_I32", 4, 1, 1, 2, 2),
MMASchedule("MFMA_I8_32x32x16_I32", 4, 2, 2, 2, 2),
MMASchedule("MFMA_F32_16x16x16_F16", 1, 1, 1, 1, 1),
MMASchedule("MFMA_F32_16x16x16_F16", 1, 1, 1, 1, 2),
MMASchedule("MFMA_F32_16x16x16_F16", 1, 1, 1, 2, 1),
MMASchedule("MFMA_F32_16x16x16_F16", 1, 1, 2, 1, 1),
MMASchedule("MFMA_F32_16x16x16_F16", 2, 2, 1, 1, 1),
MMASchedule("MFMA_F32_16x16x16_F16", 2, 4, 2, 1, 2),
MMASchedule("MFMA_F32_16x16x16_F16", 4, 2, 4, 2, 2),
MMASchedule("MFMA_F32_32x32x8_F16", 1, 1, 1, 2, 2),
MMASchedule("MFMA_F32_32x32x8_F16", 2, 2, 1, 1, 1),
MMASchedule("MFMA_F32_32x32x8_F16", 1, 4, 2, 1, 2),
MMASchedule("MFMA_F32_32x32x8_F16", 4, 2, 1, 2, 4),
MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 1, 1, 1, 1, 1),
MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 2, 2, 1, 1, 2),
MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 4, 1, 4, 1, 1),
MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 4, 2, 4, 2, 1),
MMASchedule("MFMA_I32_16x16x32_I8", 1, 1, 1, 1, 1),
MMASchedule("MFMA_I32_16x16x32_I8", 2, 2, 1, 1, 2),
MMASchedule("MFMA_I32_16x16x32_I8", 4, 1, 4, 1, 1),
MMASchedule("MFMA_I32_16x16x32_I8", 4, 2, 4, 2, 1),
MMASchedule("MFMA_I32_32x32x16_I8", 1, 1, 1, 1, 1),
MMASchedule("MFMA_I32_32x32x16_I8", 2, 2, 1, 1, 2),
MMASchedule("MFMA_I32_32x32x16_I8", 4, 1, 1, 2, 2),
MMASchedule("MFMA_I32_32x32x16_I8", 4, 2, 2, 2, 2),
]
elif intrinsic == "WMMA":
schedules = [
Expand Down Expand Up @@ -319,22 +319,22 @@ def get_rocm_test_compilation_infos(
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 4
elif schedule.intrinsic == "MFMA_F16_16x16x16_F32":
elif schedule.intrinsic == "MFMA_F32_16x16x16_F16":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 16
elif schedule.intrinsic == "MFMA_F16_32x32x8_F32":
elif schedule.intrinsic == "MFMA_F32_32x32x8_F16":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
wg_tile_k = schedule.k_tile_count * 8
elif (
schedule.intrinsic == "MFMA_I8_16x16x32_I32"
or schedule.intrinsic == "MFMA_F8E4M3FNUZ_16x16x32_F32"
schedule.intrinsic == "MFMA_I32_16x16x32_I8"
or schedule.intrinsic == "MFMA_F32_16x16x32_F8E4M3FNUZ"
):
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 32
elif schedule.intrinsic == "MFMA_I8_32x32x16_I32":
elif schedule.intrinsic == "MFMA_I32_32x32x16_I8":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
wg_tile_k = schedule.k_tile_count * 16
Expand Down

0 comments on commit ca24b96

Please sign in to comment.