diff --git a/test/TritonGEN/tritongen-invalid.mlir b/test/TritonGEN/tritongen-invalid.mlir index a3b9a5790d..58fc37a032 100644 --- a/test/TritonGEN/tritongen-invalid.mlir +++ b/test/TritonGEN/tritongen-invalid.mlir @@ -1,64 +1,113 @@ // RUN: triton-opt -split-input-file -verify-diagnostics %s -llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<32xi8>) { +llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) { // expected-error @+1 {{'triton_gen.dpas' op expecting repeat count to be 1, 2, 4, or 8}} - %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=6} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32> + %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=16} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32> llvm.return } // ----- -llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<32xi8>) { +llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) { // expected-error @+1 {{'triton_gen.dpas' op expecting precision of matrix A and B to be the same}} - %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=u8, rc=8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32> + %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=f16, rc=8} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32> llvm.return } // ----- -llvm.func @triton_gen.dpas(%c : vector<8xi8>, %a : vector<16xi8>, %b : vector<32xi8>) { +llvm.func @triton_gen.dpas(%c : vector<8xi8>, %a : vector<8xi16>, %b : vector<8xi32>) { // expected-error @+1 {{'triton_gen.dpas' op 1st operand (C) and result (D) should have the same type}} - %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi8>, vector<16xi8>, vector<32xi8>) -> vector<8xi32> + %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi8>, vector<8xi16>, vector<8xi32>) -> vector<8xi32> llvm.return } // ----- -llvm.func @triton_gen.dpas(%c : vector<16xi32>, %a : vector<16xi8>, %b : 
vector<32xi8>) { +llvm.func @triton_gen.dpas(%c : vector<16xi32>, %a : vector<8xi16>, %b : vector<8xi32>) { // expected-error @+1 {{'triton_gen.dpas' op the dimension for 1st operand (C) and result (D) should match repeat count}} - %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<16xi32>, vector<16xi8>, vector<32xi8>) -> vector<16xi32> + %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<16xi32>, vector<8xi16>, vector<8xi32>) -> vector<16xi32> llvm.return } // ----- -llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi8>, %b : vector<8xi8>) { - // expected-error @+1 {{'triton_gen.dpas' op 2nd operand (A) bit-size should be repeat count times 16}} - %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi32>, vector<8xi8>, vector<8xi8>) -> vector<8xi32> +llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<16xi32>) { + // expected-error @+1 {{'triton_gen.dpas' op the dimension for the 3rd operand (B) should match the systolic depth of 8}} + %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi32>, vector<8xi16>, vector<16xi32>) -> vector<8xi32> llvm.return } // ----- -llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<16xi8>) { - // expected-error @+1 {{'triton_gen.dpas' op 3rd operand (B) bit-size should be systolic depth (8) times 32}} - %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi32>, vector<16xi8>, vector<16xi8>) -> vector<8xi32> +llvm.func @triton_gen.dpas(%c : vector<8xi8>, %a : vector<8xi16>, %b : vector<8xi32>) { + // expected-error @+1 {{'triton_gen.dpas' op the element type for 1st operand (C) and the result should be i32}} + %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi8>, vector<8xi16>, vector<8xi32>) -> vector<8xi8> llvm.return } // ----- -llvm.func @triton_gen.dpas(%c : vector<8xi8>, %a : vector<16xi8>, %b : vector<32xi8>) { - // expected-error @+1 {{'triton_gen.dpas' op the element type for 1st 
operand (C) and the result should be i32}} - %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi8>, vector<16xi8>, vector<32xi8>) -> vector<8xi8> +llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) { + // expected-error @+1 {{'triton_gen.dpas' op the element type for 1st operand (C) and the result should be f32}} + %0 = triton_gen.dpas %c, %a, %b {pa=f16, pb=f16, rc=8} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<8xf32>, %b : vector<8xf32>) { + // expected-error @+1 {{'triton_gen.dpas' op the dimension for the 2nd operand (A) should be equal to half of the repeat count}} + %0 = triton_gen.dpas %c, %a, %b {pa = tf32, pb = tf32, rc = 8} : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<4xi16>, %b : vector<8xf32>) { + // expected-error @+1 {{'triton_gen.dpas' op 2nd operand (A) element type should be f32 or i32 when the precision type is tf32}} + %0 = triton_gen.dpas %c, %a, %b {pa = tf32, pb = tf32, rc = 8} : (vector<8xf32>, vector<4xi16>, vector<8xf32>) -> vector<8xf32> + llvm.return +} + +// ----- + + +llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<4xf32>, %b : vector<8xi16>) { + // expected-error @+1 {{'triton_gen.dpas' op 3rd operand (B) element type should be f32 or i32 when the precision type is tf32}} + %0 = triton_gen.dpas %c, %a, %b {pa = tf32, pb = tf32, rc = 8} : (vector<8xf32>, vector<4xf32>, vector<8xi16>) -> vector<8xf32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<4xi16>, %b : vector<8xi32>) { + // expected-error @+1 {{'triton_gen.dpas' op 2nd operand (A) should have the same number of elements as repeat count}} + %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi32>, vector<4xi16>, vector<8xi32>) -> vector<8xi32> + 
llvm.return +} + +// ----- + +llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi8>, %b : vector<8xi32>) { + // expected-error @+1 {{'triton_gen.dpas' op 2nd operand (A) element type should be i16 when the precision type is not tf32}} + %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi32>, vector<8xi8>, vector<8xi32>) -> vector<8xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<8xi16>, %b : vector<8xf32>) { + // expected-error @+1 {{'triton_gen.dpas' op 3rd operand (B) element type should be i32 when the precision type is not tf32}} + %0 = triton_gen.dpas %c, %a, %b {pa=f16, pb=f16, rc=8} : (vector<8xf32>, vector<8xi16>, vector<8xf32>) -> vector<8xf32> llvm.return } // ----- -llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<8xf16>, %b : vector<16xf16>) { +llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<8xi16>, %b : vector<8xi32>) { // expected-error @+1 {{'triton_gen.dpas' op expecting precision type to be tf32, bf16, fp16, u8, or s8}} - %0 = triton_gen.dpas %c, %a, %b {pa=i4, pb=i4, rc=8} : (vector<8xf32>, vector<8xf16>, vector<16xf16>) -> vector<8xf32> + %0 = triton_gen.dpas %c, %a, %b {pa=i4, pb=i4, rc=8} : (vector<8xf32>, vector<8xi16>, vector<8xi32>) -> vector<8xf32> llvm.return } diff --git a/test/TritonGEN/tritongen-to-llvm.mlir b/test/TritonGEN/tritongen-to-llvm.mlir index 7ed000d7b1..320eecb637 100644 --- a/test/TritonGEN/tritongen-to-llvm.mlir +++ b/test/TritonGEN/tritongen-to-llvm.mlir @@ -161,12 +161,10 @@ llvm.func @triton_gen.sub_group_shuffle() { // CHECK: llvm.func spir_funccc @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iS0_(vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32> attributes {passthrough = ["convergent"]} -llvm.func @triton_gen.dpas.i8(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<32xi8>) { - // CHECK: llvm.func @triton_gen.dpas.i8(%arg0: vector<8xi32>, %arg1: vector<16xi8>, %arg2: vector<32xi8>) { - // CHECK-DAG: 
[[A:%.*]] = llvm.bitcast %arg1 : vector<16xi8> to vector<8xi16> - // CHECK-DAG: [[B:%.*]] = llvm.bitcast %arg2 : vector<32xi8> to vector<8xi32> - // CHECK-NEXT: llvm.call @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iS0_([[A]], [[B]], %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32> - %0 = triton_gen.dpas %c, %a, %b {pa = i8, pb = i8, rc = 8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32> +llvm.func @triton_gen.dpas.i8(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) { + // CHECK: llvm.func @triton_gen.dpas.i8(%arg0: vector<8xi32>, %arg1: vector<8xi16>, %arg2: vector<8xi32>) { + // CHECK-NEXT: llvm.call @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iS0_(%arg1, %arg2, %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32> + %0 = triton_gen.dpas %c, %a, %b {pa = i8, pb = i8, rc = 8} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32> llvm.return } @@ -174,12 +172,10 @@ llvm.func @triton_gen.dpas.i8(%c : vector<8xi32>, %a : vector<16xi8>, %b : vecto // CHECK: llvm.func spir_funccc @_Z36intel_sub_group_u8_u8_matrix_mad_k32Dv8_sDv8_iS0_(vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32> attributes {passthrough = ["convergent"]} -llvm.func @triton_gen.dpas.u8(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<32xi8>) { - // CHECK: llvm.func @triton_gen.dpas.u8(%arg0: vector<8xi32>, %arg1: vector<16xi8>, %arg2: vector<32xi8>) { - // CHECK-DAG: [[A:%.*]] = llvm.bitcast %arg1 : vector<16xi8> to vector<8xi16> - // CHECK-DAG: [[B:%.*]] = llvm.bitcast %arg2 : vector<32xi8> to vector<8xi32> - // CHECK-NEXT: llvm.call @_Z36intel_sub_group_u8_u8_matrix_mad_k32Dv8_sDv8_iS0_([[A]], [[B]], %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32> - %0 = triton_gen.dpas %c, %a, %b {pa = u8, pb = u8, rc = 8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32> 
+llvm.func @triton_gen.dpas.u8(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) { + // CHECK: llvm.func @triton_gen.dpas.u8(%arg0: vector<8xi32>, %arg1: vector<8xi16>, %arg2: vector<8xi32>) { + // CHECK-NEXT: llvm.call @_Z36intel_sub_group_u8_u8_matrix_mad_k32Dv8_sDv8_iS0_(%arg1, %arg2, %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32> + %0 = triton_gen.dpas %c, %a, %b {pa = u8, pb = u8, rc = 8} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32> llvm.return } @@ -187,12 +183,10 @@ llvm.func @triton_gen.dpas.u8(%c : vector<8xi32>, %a : vector<16xi8>, %b : vecto // CHECK: llvm.func spir_funccc @_Z40intel_sub_group_bf16_bf16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {passthrough = ["convergent"]} -llvm.func @triton_gen.dpas.bf16(%c : vector<8xf32>, %a : vector<8xbf16>, %b : vector<16xbf16>) { - // CHECK: llvm.func @triton_gen.dpas.bf16(%arg0: vector<8xf32>, %arg1: vector<8xbf16>, %arg2: vector<16xbf16>) { - // CHECK-DAG: [[A:%.*]] = llvm.bitcast %arg1 : vector<8xbf16> to vector<8xi16> - // CHECK-DAG: [[B:%.*]] = llvm.bitcast %arg2 : vector<16xbf16> to vector<8xi32> - // CHECK-NEXT: llvm.call @_Z40intel_sub_group_bf16_bf16_matrix_mad_k16Dv8_sDv8_iDv8_f([[A]], [[B]], %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> - %0 = triton_gen.dpas %c, %a, %b {pa = bf16, pb = bf16, rc = 8} : (vector<8xf32>, vector<8xbf16>, vector<16xbf16>) -> vector<8xf32> +llvm.func @triton_gen.dpas.bf16(%c : vector<8xf32>, %a : vector<8xi16>, %b : vector<8xi32>) { + // CHECK: llvm.func @triton_gen.dpas.bf16(%arg0: vector<8xf32>, %arg1: vector<8xi16>, %arg2: vector<8xi32>) { + // CHECK-NEXT: llvm.call @_Z40intel_sub_group_bf16_bf16_matrix_mad_k16Dv8_sDv8_iDv8_f(%arg1, %arg2, %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> + %0 = triton_gen.dpas %c, %a, 
%b {pa = bf16, pb = bf16, rc = 8} : (vector<8xf32>, vector<8xi16>, vector<8xi32>) -> vector<8xf32> llvm.return } @@ -200,12 +194,10 @@ llvm.func @triton_gen.dpas.bf16(%c : vector<8xf32>, %a : vector<8xbf16>, %b : ve // CHECK: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {passthrough = ["convergent"]} -llvm.func @triton_gen.dpas.f16(%c : vector<8xf32>, %a : vector<8xf16>, %b : vector<16xf16>) { - // CHECK: llvm.func @triton_gen.dpas.f16(%arg0: vector<8xf32>, %arg1: vector<8xf16>, %arg2: vector<16xf16>) { - // CHECK-DAG: [[A:%.*]] = llvm.bitcast %arg1 : vector<8xf16> to vector<8xi16> - // CHECK-DAG: [[B:%.*]] = llvm.bitcast %arg2 : vector<16xf16> to vector<8xi32> - // CHECK-NEXT: llvm.call @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f([[A]], [[B]], %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> - %0 = triton_gen.dpas %c, %a, %b {pa = f16, pb = f16, rc = 8} : (vector<8xf32>, vector<8xf16>, vector<16xf16>) -> vector<8xf32> +llvm.func @triton_gen.dpas.f16(%c : vector<8xf32>, %a : vector<8xi16>, %b : vector<8xi32>) { + // CHECK: llvm.func @triton_gen.dpas.f16(%arg0: vector<8xf32>, %arg1: vector<8xi16>, %arg2: vector<8xi32>) { + // CHECK-NEXT: llvm.call @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%arg1, %arg2, %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> + %0 = triton_gen.dpas %c, %a, %b {pa = f16, pb = f16, rc = 8} : (vector<8xf32>, vector<8xi16>, vector<8xi32>) -> vector<8xf32> llvm.return } diff --git a/test/TritonGEN/tritongen.mlir b/test/TritonGEN/tritongen.mlir index 961c31d63c..12da8687a4 100644 --- a/test/TritonGEN/tritongen.mlir +++ b/test/TritonGEN/tritongen.mlir @@ -98,10 +98,10 @@ llvm.func @triton_gen.sub_group_shuffle() { llvm.return } -llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<16xi8>, %b : 
vector<32xi8>) { - // CHECK: llvm.func @triton_gen.dpas(%arg0: vector<8xi32>, %arg1: vector<16xi8>, %arg2: vector<32xi8>) { - // CHECK-NEXT: %0 = triton_gen.dpas %arg0, %arg1, %arg2 {pa = i8, pb = i8, rc = 8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32> - %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32> +llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) { + // CHECK: llvm.func @triton_gen.dpas(%arg0: vector<8xi32>, %arg1: vector<8xi16>, %arg2: vector<8xi32>) { + // CHECK-NEXT: %0 = triton_gen.dpas %arg0, %arg1, %arg2 {pa = i8, pb = i8, rc = 8} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32> + %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32> llvm.return } diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp index 10b3efac4d..82267b06b1 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp @@ -95,21 +95,15 @@ LogicalResult TritonGEN::MatrixDPASOp::verify() { return this->emitOpError("the dimension for 1st operand (C) and " "result (D) should match repeat count"); + constexpr unsigned SD = 8; + if (BTy.getNumElements() != SD) + return this->emitOpError("the dimension for the 3rd operand (B) should " + "match the systolic depth of 8"); + Type AElemTy = ATy.getElementType(); Type BElemTy = BTy.getElementType(); Type CElemTy = CTy.getElementType(); - // ATy is required to be vector as hard coded by IGC. - if (ATy.getNumElements() * AElemTy.getIntOrFloatBitWidth() != getRc() * 16) - return this->emitOpError( - "2nd operand (A) bit-size should be repeat count times 16"); - - // BTy is required to be vector as hard coded by IGC. 
- constexpr unsigned SD = 8; - if (BTy.getNumElements() * BElemTy.getIntOrFloatBitWidth() != SD * 32) - return this->emitOpError( - "3rd operand (B) bit-size should be systolic depth (8) times 32"); - if (precision == TritonGEN::PrecisionType::U8 || precision == TritonGEN::PrecisionType::S8) { if (!CElemTy.isInteger(32)) @@ -121,31 +115,31 @@ LogicalResult TritonGEN::MatrixDPASOp::verify() { switch (precision) { case TritonGEN::PrecisionType::TF32: + if (ATy.getNumElements() != getRc() / 2) + return this->emitOpError("the dimension for the 2nd operand (A) should " + "be equal to half of the repeat count"); if (!AElemTy.isa<Float32Type>() && !AElemTy.isInteger(32)) - return this->emitOpError("A and B operand element type should be f32 or " - "i32 when precision type is tf32"); + return this->emitOpError("2nd operand (A) element type should be f32 or " + "i32 when the precision type is tf32"); + if (!BElemTy.isa<Float32Type>() && !BElemTy.isInteger(32)) + return this->emitOpError("3rd operand (B) element type should be f32 or " + "i32 when the precision type is tf32"); break; case TritonGEN::PrecisionType::BF16: - if (!AElemTy.isa<BFloat16Type>() && !AElemTy.isInteger(16)) - return this->emitOpError("A and B operand element type should be bf16 or " - "i16 when precision type is bf16"); - break; case TritonGEN::PrecisionType::FP16: - if (!AElemTy.isa<Float16Type>() && !AElemTy.isInteger(16)) - return this->emitOpError("A and B operand element type should be f16 or " - "i16 when precision type is f16"); - break; case TritonGEN::PrecisionType::U8: - if (!(AElemTy.isInteger(8) && !AElemTy.cast<IntegerType>().isSigned()) && - !AElemTy.isInteger(16)) - return this->emitOpError("A and B operand element type should be u8, i8, " - "or i16 when precision type is u8"); - break; case TritonGEN::PrecisionType::S8: - if (!(AElemTy.isInteger(8) && !AElemTy.cast<IntegerType>().isUnsigned()) && - !AElemTy.isInteger(16)) - return this->emitOpError("A and B operand element type should be s8, i8, " - "or i16 when precision type is s8"); + if (ATy.getNumElements() 
!= getRc()) + return this->emitOpError("2nd operand (A) should have the same number of " + "elements as repeat count"); + if (!AElemTy.isInteger(16)) + return this->emitOpError( + "2nd operand (A) element type should be i16 when " + "the precision type is not tf32"); + if (!BElemTy.isInteger(32)) + return this->emitOpError( + "3rd operand (B) element type should be i32 when " + "the precision type is not tf32"); break; default: return this->emitOpError(