Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GEN] Further restrict DPAS input types #1062

Merged
merged 3 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 68 additions & 19 deletions test/TritonGEN/tritongen-invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,64 +1,113 @@
// RUN: triton-opt -split-input-file -verify-diagnostics %s

llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<32xi8>) {
llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
// expected-error @+1 {{'triton_gen.dpas' op expecting repeat count to be 1, 2, 4, or 8}}
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=6} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=16} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<32xi8>) {
llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
// expected-error @+1 {{'triton_gen.dpas' op expecting precision of matrix A and B to be the same}}
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=u8, rc=8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=f16, rc=8} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<8xi8>, %a : vector<16xi8>, %b : vector<32xi8>) {
llvm.func @triton_gen.dpas(%c : vector<8xi8>, %a : vector<8xi16>, %b : vector<8xi32>) {
// expected-error @+1 {{'triton_gen.dpas' op 1st operand (C) and result (D) should have the same type}}
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi8>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi8>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<16xi32>, %a : vector<16xi8>, %b : vector<32xi8>) {
llvm.func @triton_gen.dpas(%c : vector<16xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
// expected-error @+1 {{'triton_gen.dpas' op the dimension for 1st operand (C) and result (D) should match repeat count}}
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<16xi32>, vector<16xi8>, vector<32xi8>) -> vector<16xi32>
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<16xi32>, vector<8xi16>, vector<8xi32>) -> vector<16xi32>
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi8>, %b : vector<8xi8>) {
// expected-error @+1 {{'triton_gen.dpas' op 2nd operand (A) bit-size should be repeat count times 16}}
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi32>, vector<8xi8>, vector<8xi8>) -> vector<8xi32>
llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<16xi32>) {
// expected-error @+1 {{'triton_gen.dpas' op the dimension for the 3rd operand (B) should match the systolic depth of 8}}
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi32>, vector<8xi16>, vector<16xi32>) -> vector<8xi32>
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<16xi8>) {
// expected-error @+1 {{'triton_gen.dpas' op 3rd operand (B) bit-size should be systolic depth (8) times 32}}
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi32>, vector<16xi8>, vector<16xi8>) -> vector<8xi32>
llvm.func @triton_gen.dpas(%c : vector<8xi8>, %a : vector<8xi16>, %b : vector<8xi32>) {
// expected-error @+1 {{'triton_gen.dpas' op the element type for 1st operand (C) and the result should be i32}}
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi8>, vector<8xi16>, vector<8xi32>) -> vector<8xi8>
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<8xi8>, %a : vector<16xi8>, %b : vector<32xi8>) {
// expected-error @+1 {{'triton_gen.dpas' op the element type for 1st operand (C) and the result should be i32}}
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi8>, vector<16xi8>, vector<32xi8>) -> vector<8xi8>
llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
// expected-error @+1 {{'triton_gen.dpas' op the element type for 1st operand (C) and the result should be f32}}
%0 = triton_gen.dpas %c, %a, %b {pa=f16, pb=f16, rc=8} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>
llvm.return
}

// -----

// Negative test (tf32 precision): operand A is given 8 elements while rc = 8.
// The expected diagnostic requires A to have half of the repeat count
// elements, so verification must fail on the op below.
llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<8xf32>, %b : vector<8xf32>) {
// expected-error @+1 {{'triton_gen.dpas' op the dimension for the 2nd operand (A) should be equal to half of the repeat count}}
%0 = triton_gen.dpas %c, %a, %b {pa = tf32, pb = tf32, rc = 8} : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
llvm.return
}

// -----

// Negative test (tf32 precision): operand A has rc / 2 = 4 elements, but its
// element type is i16. The expected diagnostic only accepts f32 or i32
// elements for A when the precision is tf32.
llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<4xi16>, %b : vector<8xf32>) {
// expected-error @+1 {{'triton_gen.dpas' op 2nd operand (A) element type should be f32 or i32 when the precision type is tf32}}
%0 = triton_gen.dpas %c, %a, %b {pa = tf32, pb = tf32, rc = 8} : (vector<8xf32>, vector<4xi16>, vector<8xf32>) -> vector<8xf32>
llvm.return
}

// -----


// Negative test (tf32 precision): operand A is vector<4xf32>, but operand B
// uses i16 elements. The expected diagnostic only accepts f32 or i32 elements
// for B when the precision is tf32.
llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<4xf32>, %b : vector<8xi16>) {
// expected-error @+1 {{'triton_gen.dpas' op 3rd operand (B) element type should be f32 or i32 when the precision type is tf32}}
%0 = triton_gen.dpas %c, %a, %b {pa = tf32, pb = tf32, rc = 8} : (vector<8xf32>, vector<4xf32>, vector<8xi16>) -> vector<8xf32>
llvm.return
}

// -----

// Negative test (i8 precision): operand A has 4 elements while rc = 8. The
// expected diagnostic requires A to have as many elements as the repeat count.
llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<4xi16>, %b : vector<8xi32>) {
// expected-error @+1 {{'triton_gen.dpas' op 2nd operand (A) should have the same number of elements as repeat count}}
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi32>, vector<4xi16>, vector<8xi32>) -> vector<8xi32>
llvm.return
}

// -----

// Negative test (i8 precision): operand A has the right element count (8 for
// rc = 8) but uses i8 elements. The expected diagnostic requires i16 elements
// for A whenever the precision is not tf32.
llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi8>, %b : vector<8xi32>) {
// expected-error @+1 {{'triton_gen.dpas' op 2nd operand (A) element type should be i16 when the precision type is not tf32}}
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi32>, vector<8xi8>, vector<8xi32>) -> vector<8xi32>
llvm.return
}

// -----

// Negative test (f16 precision): operand A is vector<8xi16>, but operand B
// uses f32 elements. The expected diagnostic requires i32 elements for B
// whenever the precision is not tf32.
llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<8xi16>, %b : vector<8xf32>) {
// expected-error @+1 {{'triton_gen.dpas' op 3rd operand (B) element type should be i32 when the precision type is not tf32}}
%0 = triton_gen.dpas %c, %a, %b {pa=f16, pb=f16, rc=8} : (vector<8xf32>, vector<8xi16>, vector<8xf32>) -> vector<8xf32>
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<8xf16>, %b : vector<16xf16>) {
llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<8xi16>, %b : vector<8xi32>) {
// expected-error @+1 {{'triton_gen.dpas' op expecting precision type to be tf32, bf16, fp16, u8, or s8}}
%0 = triton_gen.dpas %c, %a, %b {pa=i4, pb=i4, rc=8} : (vector<8xf32>, vector<8xf16>, vector<16xf16>) -> vector<8xf32>
%0 = triton_gen.dpas %c, %a, %b {pa=i4, pb=i4, rc=8} : (vector<8xf32>, vector<8xi16>, vector<8xi32>) -> vector<8xf32>
llvm.return
}

Expand Down
40 changes: 16 additions & 24 deletions test/TritonGEN/tritongen-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -161,51 +161,43 @@ llvm.func @triton_gen.sub_group_shuffle() {

// CHECK: llvm.func spir_funccc @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iS0_(vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32> attributes {passthrough = ["convergent"]}

llvm.func @triton_gen.dpas.i8(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<32xi8>) {
// CHECK: llvm.func @triton_gen.dpas.i8(%arg0: vector<8xi32>, %arg1: vector<16xi8>, %arg2: vector<32xi8>) {
// CHECK-DAG: [[A:%.*]] = llvm.bitcast %arg1 : vector<16xi8> to vector<8xi16>
// CHECK-DAG: [[B:%.*]] = llvm.bitcast %arg2 : vector<32xi8> to vector<8xi32>
// CHECK-NEXT: llvm.call @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iS0_([[A]], [[B]], %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32>
%0 = triton_gen.dpas %c, %a, %b {pa = i8, pb = i8, rc = 8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>
llvm.func @triton_gen.dpas.i8(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
// CHECK: llvm.func @triton_gen.dpas.i8(%arg0: vector<8xi32>, %arg1: vector<8xi16>, %arg2: vector<8xi32>) {
// CHECK-NEXT: llvm.call @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iS0_(%arg1, %arg2, %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32>
%0 = triton_gen.dpas %c, %a, %b {pa = i8, pb = i8, rc = 8} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>
llvm.return
}

// -----

// CHECK: llvm.func spir_funccc @_Z36intel_sub_group_u8_u8_matrix_mad_k32Dv8_sDv8_iS0_(vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32> attributes {passthrough = ["convergent"]}

llvm.func @triton_gen.dpas.u8(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<32xi8>) {
// CHECK: llvm.func @triton_gen.dpas.u8(%arg0: vector<8xi32>, %arg1: vector<16xi8>, %arg2: vector<32xi8>) {
// CHECK-DAG: [[A:%.*]] = llvm.bitcast %arg1 : vector<16xi8> to vector<8xi16>
// CHECK-DAG: [[B:%.*]] = llvm.bitcast %arg2 : vector<32xi8> to vector<8xi32>
// CHECK-NEXT: llvm.call @_Z36intel_sub_group_u8_u8_matrix_mad_k32Dv8_sDv8_iS0_([[A]], [[B]], %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32>
%0 = triton_gen.dpas %c, %a, %b {pa = u8, pb = u8, rc = 8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>
llvm.func @triton_gen.dpas.u8(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
// CHECK: llvm.func @triton_gen.dpas.u8(%arg0: vector<8xi32>, %arg1: vector<8xi16>, %arg2: vector<8xi32>) {
// CHECK-NEXT: llvm.call @_Z36intel_sub_group_u8_u8_matrix_mad_k32Dv8_sDv8_iS0_(%arg1, %arg2, %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32>
%0 = triton_gen.dpas %c, %a, %b {pa = u8, pb = u8, rc = 8} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>
llvm.return
}

// -----

// CHECK: llvm.func spir_funccc @_Z40intel_sub_group_bf16_bf16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {passthrough = ["convergent"]}

llvm.func @triton_gen.dpas.bf16(%c : vector<8xf32>, %a : vector<8xbf16>, %b : vector<16xbf16>) {
// CHECK: llvm.func @triton_gen.dpas.bf16(%arg0: vector<8xf32>, %arg1: vector<8xbf16>, %arg2: vector<16xbf16>) {
// CHECK-DAG: [[A:%.*]] = llvm.bitcast %arg1 : vector<8xbf16> to vector<8xi16>
// CHECK-DAG: [[B:%.*]] = llvm.bitcast %arg2 : vector<16xbf16> to vector<8xi32>
// CHECK-NEXT: llvm.call @_Z40intel_sub_group_bf16_bf16_matrix_mad_k16Dv8_sDv8_iDv8_f([[A]], [[B]], %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
%0 = triton_gen.dpas %c, %a, %b {pa = bf16, pb = bf16, rc = 8} : (vector<8xf32>, vector<8xbf16>, vector<16xbf16>) -> vector<8xf32>
llvm.func @triton_gen.dpas.bf16(%c : vector<8xf32>, %a : vector<8xi16>, %b : vector<8xi32>) {
// CHECK: llvm.func @triton_gen.dpas.bf16(%arg0: vector<8xf32>, %arg1: vector<8xi16>, %arg2: vector<8xi32>) {
// CHECK-NEXT: llvm.call @_Z40intel_sub_group_bf16_bf16_matrix_mad_k16Dv8_sDv8_iDv8_f(%arg1, %arg2, %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
%0 = triton_gen.dpas %c, %a, %b {pa = bf16, pb = bf16, rc = 8} : (vector<8xf32>, vector<8xi16>, vector<8xi32>) -> vector<8xf32>
llvm.return
}

// -----

// CHECK: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {passthrough = ["convergent"]}

llvm.func @triton_gen.dpas.f16(%c : vector<8xf32>, %a : vector<8xf16>, %b : vector<16xf16>) {
// CHECK: llvm.func @triton_gen.dpas.f16(%arg0: vector<8xf32>, %arg1: vector<8xf16>, %arg2: vector<16xf16>) {
// CHECK-DAG: [[A:%.*]] = llvm.bitcast %arg1 : vector<8xf16> to vector<8xi16>
// CHECK-DAG: [[B:%.*]] = llvm.bitcast %arg2 : vector<16xf16> to vector<8xi32>
// CHECK-NEXT: llvm.call @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f([[A]], [[B]], %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
%0 = triton_gen.dpas %c, %a, %b {pa = f16, pb = f16, rc = 8} : (vector<8xf32>, vector<8xf16>, vector<16xf16>) -> vector<8xf32>
llvm.func @triton_gen.dpas.f16(%c : vector<8xf32>, %a : vector<8xi16>, %b : vector<8xi32>) {
// CHECK: llvm.func @triton_gen.dpas.f16(%arg0: vector<8xf32>, %arg1: vector<8xi16>, %arg2: vector<8xi32>) {
// CHECK-NEXT: llvm.call @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%arg1, %arg2, %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
%0 = triton_gen.dpas %c, %a, %b {pa = f16, pb = f16, rc = 8} : (vector<8xf32>, vector<8xi16>, vector<8xi32>) -> vector<8xf32>
llvm.return
}

Expand Down
8 changes: 4 additions & 4 deletions test/TritonGEN/tritongen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ llvm.func @triton_gen.sub_group_shuffle() {
llvm.return
}

llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<32xi8>) {
// CHECK: llvm.func @triton_gen.dpas(%arg0: vector<8xi32>, %arg1: vector<16xi8>, %arg2: vector<32xi8>) {
// CHECK-NEXT: %0 = triton_gen.dpas %arg0, %arg1, %arg2 {pa = i8, pb = i8, rc = 8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>
llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
// CHECK: llvm.func @triton_gen.dpas(%arg0: vector<8xi32>, %arg1: vector<8xi16>, %arg2: vector<8xi32>) {
// CHECK-NEXT: %0 = triton_gen.dpas %arg0, %arg1, %arg2 {pa = i8, pb = i8, rc = 8} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=8} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>
llvm.return
}

Expand Down
54 changes: 24 additions & 30 deletions third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,15 @@ LogicalResult TritonGEN::MatrixDPASOp::verify() {
return this->emitOpError("the dimension for 1st operand (C) and "
"result (D) should match repeat count");

constexpr unsigned SD = 8;
if (BTy.getNumElements() != SD)
return this->emitOpError("the dimension for the 3rd operand (B) should "
"match the systolic depth of 8");

Type AElemTy = ATy.getElementType();
Type BElemTy = BTy.getElementType();
Type CElemTy = CTy.getElementType();

// ATy is required to be vector<RC x i16> as hard coded by IGC.
if (ATy.getNumElements() * AElemTy.getIntOrFloatBitWidth() != getRc() * 16)
return this->emitOpError(
"2nd operand (A) bit-size should be repeat count times 16");

// BTy is required to be vector<SD x i32> as hard coded by IGC.
constexpr unsigned SD = 8;
if (BTy.getNumElements() * BElemTy.getIntOrFloatBitWidth() != SD * 32)
return this->emitOpError(
"3rd operand (B) bit-size should be systolic depth (8) times 32");

if (precision == TritonGEN::PrecisionType::U8 ||
precision == TritonGEN::PrecisionType::S8) {
if (!CElemTy.isInteger(32))
Expand All @@ -121,31 +115,31 @@ LogicalResult TritonGEN::MatrixDPASOp::verify() {

switch (precision) {
case TritonGEN::PrecisionType::TF32:
if (ATy.getNumElements() != getRc() / 2)
return this->emitOpError("the dimension for the 2nd operand (A) should "
"be equal to half of the repeat count");
if (!AElemTy.isa<Float32Type>() && !AElemTy.isInteger(32))
return this->emitOpError("A and B operand element type should be f32 or "
"i32 when precision type is tf32");
return this->emitOpError("2nd operand (A) element type should be f32 or "
"i32 when the precision type is tf32");
if (!BElemTy.isa<Float32Type>() && !BElemTy.isInteger(32))
return this->emitOpError("3rd operand (B) element type should be f32 or "
"i32 when the precision type is tf32");
break;
case TritonGEN::PrecisionType::BF16:
if (!AElemTy.isa<BFloat16Type>() && !AElemTy.isInteger(16))
return this->emitOpError("A and B operand element type should be bf16 or "
"i16 when precision type is bf16");
break;
case TritonGEN::PrecisionType::FP16:
if (!AElemTy.isa<Float16Type>() && !AElemTy.isInteger(16))
return this->emitOpError("A and B operand element type should be f16 or "
"i16 when precision type is f16");
break;
case TritonGEN::PrecisionType::U8:
if (!(AElemTy.isInteger(8) && !AElemTy.cast<IntegerType>().isSigned()) &&
!AElemTy.isInteger(16))
return this->emitOpError("A and B operand element type should be u8, i8, "
"or i16 when precision type is u8");
break;
case TritonGEN::PrecisionType::S8:
if (!(AElemTy.isInteger(8) && !AElemTy.cast<IntegerType>().isUnsigned()) &&
!AElemTy.isInteger(16))
return this->emitOpError("A and B operand element type should be s8, i8, "
"or i16 when precision type is s8");
if (ATy.getNumElements() != getRc())
return this->emitOpError("2nd operand (A) should have the same number of "
"elements as repeat count");
if (!AElemTy.isInteger(16))
return this->emitOpError(
"2nd operand (A) element type should be i16 when "
"the precision type is not tf32");
if (!BElemTy.isInteger(32))
return this->emitOpError(
"3rd operand (B) element type should be i32 when "
"the precision type is not tf32");
break;
default:
return this->emitOpError(
Expand Down