[ROCM] fix layout for WMMA_F16_16x16x16_F16 intrinsic (iree-org#18206)
The existing layout for the intrinsic was written for subgroup=64, but we are
using subgroup=32, which led to the error reported in iree-org#18060.
This PR switches to the correct layout for subgroup=32 and hence fixes
iree-org#18060 and iree-org#17807.
nirvedhmeshram committed Aug 13, 2024
1 parent 08583d5 commit 868f41e
Showing 5 changed files with 91 additions and 15 deletions.
37 changes: 31 additions & 6 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -353,8 +353,7 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
-case MMAIntrinsic::WMMA_F32_16x16x16_F16:
-case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [1, 16]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
@@ -372,6 +371,24 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
+case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
+// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [1, 16]>
+// #layout_a = #iree_vector_ext.layout<#outer, #inner>
+// #layout_b = #iree_vector_ext.layout<#inner, #outer>
+
+auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
+auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {1, 16});
+auto aMLayout = outer;
+auto aKLayout = inner;
+auto bKLayout = inner;
+auto bNLayout = outer;
+auto cMLayout =
+PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {16, 1, 1});
+auto cNLayout = PerDimLayoutAttr::get(context, {laneX}, {16});
+return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
+bNLayout, cMLayout, cNLayout};
+}
default: {
break;
}
@@ -463,13 +480,18 @@ MMAAttr::getABCVectorTypes() const {
auto cType = VectorType::get({16}, getCType());
return std::make_tuple(aType, bType, cType);
}
-case MMAIntrinsic::WMMA_F32_16x16x16_F16:
-case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
auto aType = VectorType::get({16}, getAType());
auto bType = VectorType::get({16}, getBType());
auto cType = VectorType::get({8}, getCType());
return std::make_tuple(aType, bType, cType);
}
+case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+auto aType = VectorType::get({16}, getAType());
+auto bType = VectorType::get({16}, getBType());
+auto cType = VectorType::get({16}, getCType());
+return std::make_tuple(aType, bType, cType);
+}
}
// This should not happen but just to make GCC happy.
return std::make_tuple(VectorType{}, VectorType{}, VectorType{});
@@ -597,11 +619,14 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
-case MMAIntrinsic::WMMA_F32_16x16x16_F16:
-case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
return {/*outer=*/{8, 1}, /*thread=*/{2, 16}, /*strides=*/{16, 1},
/*element=*/{1, 1}};
}
+case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+return {/*outer=*/{16, 1}, /*thread=*/{1, 16}, /*strides=*/{16, 1},
+/*element=*/{1, 1}};
+}
}
return {};
}
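
To see why the lane-distribution tests below change, here is a minimal sketch of the thread/strides semantics as I read them (laneToCCoord is a hypothetical helper, not an IREE API); it reproduces the XMAP/YMAP affine maps checked in the distribution test:

#include <utility>

// Per dimension d: coord[d] = (laneId / strides[d]) % thread[d].
constexpr std::pair<int, int> laneToCCoord(int lane) {
  // New C layout: thread = {1, 16}, strides = {16, 1}.
  int row = (lane / 16) % 1;  // size 1 -> always 0, so YMAP collapses to () -> ()
  int col = (lane / 1) % 16;  // XMAP stays (d0) -> (d0 mod 16)
  return {row, col};
}
// Under the old thread = {2, 16}, row was (lane / 16) % 2, i.e. the
// (d0 floordiv 16) mod 2 YMAP that the distribution test used to expect.
static_assert(laneToCCoord(21).first == 0, "row never depends on the lane id");
static_assert(laneToCCoord(21).second == 5, "lane 21 owns column 5");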
@@ -238,8 +238,8 @@ func.func @concretize_WMMA_F16_16x16x16_F16(%lhs: tensor<16x16xf16>, %rhs: tenso
// CHECK-INPUTS: %[[MMA:.+]] = iree_gpu.multi_mma
// CHECK-INPUTS: return %[[MMA]]

-// CHECK-RESULT: %[[EXPANDED_ACC:.+]] = tensor.expand_shape %[[ACC]] {{\[}}[0, 1], [2]] output_shape [8, 2, 16]
+// CHECK-RESULT: %[[EXPANDED_ACC:.+]] = tensor.expand_shape %[[ACC]] {{\[}}[0, 1], [2]] output_shape [16, 1, 16]
// CHECK-RESULT: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[EXPANDED_ACC]]
-// CHECK-RESULT-SAME: : tensor<16x16xf16>, tensor<16x16xf16> into tensor<8x2x16xf16>
+// CHECK-RESULT-SAME: : tensor<16x16xf16>, tensor<16x16xf16> into tensor<16x1x16xf16>
// CHECK-RESULT: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMA]] {{\[}}[0, 1], [2]]
// CHECK-RESULT: return %[[COLLAPSED]]
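
The new output_shape follows directly from the corrected C layout: concretization factors the accumulator's M dimension into (outer_m x thread_m), i.e. 8 x 2 under the old layout and 16 x 1 now. A one-line check (sketch):

static_assert(8 * 2 == 16 && 16 * 1 == 16,
              "both factorizations cover the 16 rows; only the split changes");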
@@ -239,7 +239,7 @@ func.func @distribute_WMMA_F16_16x16x16_F16(%lhs: tensor<16x16xf16>, %rhs: tenso
}

// CHECK-DAG: #[[$XMAP:.+]] = affine_map<(d0) -> (d0 mod 16)>
-// CHECK-DAG: #[[$YMAP:.+]] = affine_map<(d0) -> ((d0 floordiv 16) mod 2)>
+// CHECK-DAG: #[[$YMAP:.+]] = affine_map<() -> ()>

// CHECK-LABEL: func @distribute_WMMA_F16_16x16x16_F16
// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<16x16xf16>
@@ -248,10 +248,9 @@ func.func @distribute_WMMA_F16_16x16x16_F16(%lhs: tensor<16x16xf16>, %rhs: tenso
// CHECK-DAG: %[[IDX:.+]] = affine.apply #[[$XMAP]](%[[LANEID]])
// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[IDX]], 0] [1, 16]
// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, %[[IDX]]] [16, 1]
-// CHECK-DAG: %[[IDY:.+]] = affine.apply #[[$YMAP]](%[[LANEID]])
-// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, %[[IDY]], %[[IDX]]] [8, 1, 1]
+// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[IDX]]] [16, 1, 1]
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]]
// CHECK-SAME: kind = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F16>
-// CHECK-SAME: : tensor<1x16xf16>, tensor<16x1xf16> into tensor<8x1x1xf16>
-// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, %[[IDY]], %[[IDX]]] [8, 1, 1]
+// CHECK-SAME: : tensor<1x16xf16>, tensor<16x1xf16> into tensor<16x1x1xf16>
+// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[IDX]]] [16, 1, 1]
// CHECK: mapping = [#iree_gpu.lane_id<0>]
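
With the corrected layout each lane owns a full 16-element accumulator column at its IDX, so a [16, 1, 1] slice replaces the old IDY-selected [8, 1, 1] half-column and the IDY computation disappears. Coverage check (sketch):

static_assert(16 * 16 == 256,
              "16 lanes each holding a 16-element column cover the 16x16 accumulator");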
@@ -495,6 +495,6 @@ hal.executable public @main {
// CHECK-LABEL: func @matmul_transpose_b_wmma_f16_16x16x16_f16
// CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
// CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
-// CHECK: scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x8x1x1xf16>)
-// CHECK-COUNT-8: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf16>
+// CHECK: scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x16x1x1xf16>)
+// CHECK-COUNT-8: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<16xf16>
// CHECK: scf.yield
@@ -498,6 +498,58 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {

// -----

+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+#hal.descriptor_set.layout<0, bindings = [
+#hal.descriptor_set.binding<0, storage_buffer>,
+#hal.descriptor_set.binding<1, storage_buffer>,
+#hal.descriptor_set.binding<2, storage_buffer>
+]>
+]>
+hal.executable @matmul_256x256x256_f16_f16 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+hal.executable.export @matmul_256x256x256_f16_f16 layout(#pipeline_layout) {
+^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+hal.return %x, %y, %z : index, index, index
+}
+builtin.module {
+func.func @matmul_256x256x256_f16_f16() {
+%cst = arith.constant 0.000000e+00 : f16
+%c0 = arith.constant 0 : index
+%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+%1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+%2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf16>>
+%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+%5 = tensor.empty() : tensor<256x256xf16>
+%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<256x256xf16>) -> tensor<256x256xf16>
+%7 = linalg.matmul ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf16>) -> tensor<256x256xf16>
+flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf16> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf16>>
+return
+}
+}
+}
+}
+
+// RDNA3: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 2, 1] subgroup_size = 32
+// RDNA3-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F16>,
+// RDNA3-SAME: subgroup_m_count = 2, subgroup_n_count = 2>
+// RDNA3-SAME: prefetch_shared_memory
+
+// RDNA3-LABEL: func.func @matmul_256x256x256_f16_f16
+// RDNA3-SAME: translation_info = #[[$TRANSLATION]]
+// RDNA3: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x16x1x1x1xf16>)
+// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
+// along the K dimension. So in total 32 wmma ops.
+// RDNA3-COUNT-32: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<16xf16>
+// RDNA3: scf.yield %{{.+}} : vector<2x2x16x1x1x1xf16>
+// Since each subgroup handles 2 * 2 tiles, and for each tile each lane holds 8 values,
+// we will have 32 writes. We cannot do contiguous writes since the output columns have
+// interleaved thread ids.
+// RDNA3-COUNT-32: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<1x1xf16>, memref<256x256xf16, #hal.descriptor_type<storage_buffer>>
+
// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
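
A closing sanity check on the new matmul_256x256x256_f16_f16 test (editor's sketch, not part of the commit): the CHECK comments count 32 wmma ops per loop iteration, which is consistent with the schedule.

constexpr int kSubgroupTilesM = 2, kSubgroupTilesN = 2;  // 2 * 2 tiles per subgroup, per the CHECK comment
constexpr int kStepK = 128, kIntrinsicK = 16;            // scf.for step and wmma K
static_assert(kSubgroupTilesM * kSubgroupTilesN * (kStepK / kIntrinsicK) == 32,
              "2 * 2 tiles x 8 K-accumulations = 32 amdgpu.wmma ops per iteration");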
