From cd460dbda99c080b9f9ec1cbdcd233f25ba8c189 Mon Sep 17 00:00:00 2001 From: Tanner Gooding Date: Tue, 30 Jan 2024 21:21:27 -0800 Subject: [PATCH] Add the barebones support for using embedded masking with AVX512 (#97675) * Add the barebones support for using embedded masking with AVX512 * Applying formatting patch * Add some basic asserts to ensure _idCustom# isn't used incorrectly * Ensure that the instruction check is correct for TlsGD --- src/coreclr/jit/emit.h | 68 +++++++++++-- src/coreclr/jit/emitxarch.cpp | 97 ++++++++++++------ src/coreclr/jit/emitxarch.h | 30 +++++- src/coreclr/jit/gentree.cpp | 62 ++++++++++-- src/coreclr/jit/gentree.h | 30 +++++- src/coreclr/jit/hwintrinsic.h | 11 ++- src/coreclr/jit/hwintrinsiccodegenxarch.cpp | 103 +++++++++++++++++++- src/coreclr/jit/instr.cpp | 2 +- src/coreclr/jit/instr.h | 51 +++++++++- src/coreclr/jit/lowerxarch.cpp | 61 ++++++++++++ src/coreclr/jit/lsraxarch.cpp | 46 +++++++++ 11 files changed, 501 insertions(+), 60 deletions(-) diff --git a/src/coreclr/jit/emit.h b/src/coreclr/jit/emit.h index e0ce3adb529ee..87a540564245e 100644 --- a/src/coreclr/jit/emit.h +++ b/src/coreclr/jit/emit.h @@ -768,12 +768,26 @@ class emitter unsigned _idLargeDsp : 1; // does a large displacement follow? unsigned _idLargeCall : 1; // large call descriptor used - unsigned _idBound : 1; // jump target / frame offset bound -#ifndef TARGET_ARMARCH - unsigned _idCallRegPtr : 1; // IL indirect calls: addr in reg -#endif - unsigned _idTlsGD : 1; // Used to store information related to TLS GD access on linux - unsigned _idNoGC : 1; // Some helpers don't get recorded in GC tables + // We have several pieces of information we need to encode but which are only applicable + // to a subset of instrDescs. To accommodate that, we define a several _idCustom# bitfields + // and then some defineds to make accessing them simpler + + unsigned _idCustom1 : 1; + unsigned _idCustom2 : 1; + unsigned _idCustom3 : 1; + +#define _idBound _idCustom1 /* jump target / frame offset bound */ +#define _idTlsGD _idCustom2 /* Used to store information related to TLS GD access on linux */ +#define _idNoGC _idCustom3 /* Some helpers don't get recorded in GC tables */ +#define _idEvexAaaContext (_idCustom3 << 2) | (_idCustom2 << 1) | _idCustom1 /* bits used for the EVEX.aaa context */ + +#if !defined(TARGET_ARMARCH) + unsigned _idCustom4 : 1; + +#define _idCallRegPtr _idCustom4 /* IL indirect calls : addr in reg */ +#define _idEvexZContext _idCustom4 /* bits used for the EVEX.z context */ +#endif // !TARGET_ARMARCH + #if defined(TARGET_XARCH) // EVEX.b can indicate several context: embedded broadcast, embedded rounding. // For normal and embedded broadcast intrinsics, EVEX.L'L has the same semantic, vector length. @@ -1578,30 +1592,36 @@ class emitter bool idIsBound() const { + assert(!IsAvx512OrPriorInstruction(_idIns)); return _idBound != 0; } void idSetIsBound() { + assert(!IsAvx512OrPriorInstruction(_idIns)); _idBound = 1; } #ifndef TARGET_ARMARCH bool idIsCallRegPtr() const { + assert(!IsAvx512OrPriorInstruction(_idIns)); return _idCallRegPtr != 0; } void idSetIsCallRegPtr() { + assert(!IsAvx512OrPriorInstruction(_idIns)); _idCallRegPtr = 1; } -#endif +#endif // !TARGET_ARMARCH bool idIsTlsGD() const { + assert(!IsAvx512OrPriorInstruction(_idIns)); return _idTlsGD != 0; } void idSetTlsGD() { + assert(!IsAvx512OrPriorInstruction(_idIns)); _idTlsGD = 1; } @@ -1610,10 +1630,12 @@ class emitter // code, it is not necessary to generate GC info for a call so labeled. 
bool idIsNoGC() const { + assert(!IsAvx512OrPriorInstruction(_idIns)); return _idNoGC != 0; } void idSetIsNoGC(bool val) { + assert(!IsAvx512OrPriorInstruction(_idIns)); _idNoGC = val; } @@ -1625,7 +1647,8 @@ class emitter void idSetEvexbContext(insOpts instOptions) { - assert(_idEvexbContext == 0); + assert(!idIsEvexbContextSet()); + if (instOptions == INS_OPTS_EVEX_eb_er_rd) { _idEvexbContext = 1; @@ -1648,6 +1671,34 @@ class emitter { return _idEvexbContext; } + + unsigned idGetEvexAaaContext() const + { + assert(IsAvx512OrPriorInstruction(_idIns)); + return _idEvexAaaContext; + } + + void idSetEvexAaaContext(insOpts instOptions) + { + assert(idGetEvexAaaContext() == 0); + unsigned value = static_cast((instOptions & INS_OPTS_EVEX_aaa_MASK) >> 2); + + _idCustom1 = ((value >> 0) & 1); + _idCustom2 = ((value >> 1) & 1); + _idCustom3 = ((value >> 2) & 1); + } + + bool idIsEvexZContextSet() const + { + assert(IsAvx512OrPriorInstruction(_idIns)); + return _idEvexZContext != 0; + } + + void idSetEvexZContext() + { + assert(!idIsEvexZContextSet()); + _idEvexZContext = 1; + } #endif #ifdef TARGET_ARMARCH @@ -2222,6 +2273,7 @@ class emitter void emitDispInsHex(instrDesc* id, BYTE* code, size_t sz); void emitDispEmbBroadcastCount(instrDesc* id); void emitDispEmbRounding(instrDesc* id); + void emitDispEmbMasking(instrDesc* id); void emitDispIns(instrDesc* id, bool isNew, bool doffs, diff --git a/src/coreclr/jit/emitxarch.cpp b/src/coreclr/jit/emitxarch.cpp index 896681e72ada8..d5dc2fd9530a6 100644 --- a/src/coreclr/jit/emitxarch.cpp +++ b/src/coreclr/jit/emitxarch.cpp @@ -49,23 +49,6 @@ bool emitter::IsKInstruction(instruction ins) return (flags & KInstruction) != 0; } -//------------------------------------------------------------------------ -// IsAvx512OrPriorInstruction: Is this an Avx512 or Avx or Sse or K (opmask) instruction. -// Technically, K instructions would be considered under the VEX encoding umbrella, but due to -// the instruction table encoding had to be pulled out with the rest of the `INST5` definitions. -// -// Arguments: -// ins - The instruction to check. -// -// Returns: -// `true` if it is a sse or avx or avx512 instruction. -// -bool emitter::IsAvx512OrPriorInstruction(instruction ins) -{ - // TODO-XArch-AVX512: Fix check once AVX512 instructions are added. - return ((ins >= INS_FIRST_SSE_INSTRUCTION) && (ins <= INS_LAST_AVX512_INSTRUCTION)); -} - bool emitter::IsAVXOnlyInstruction(instruction ins) { return (ins >= INS_FIRST_AVX_INSTRUCTION) && (ins <= INS_LAST_AVX_INSTRUCTION); @@ -1304,9 +1287,10 @@ bool emitter::TakesEvexPrefix(const instrDesc* id) const #define DEFAULT_BYTE_EVEX_PREFIX 0x62F07C0800000000ULL #define DEFAULT_BYTE_EVEX_PREFIX_MASK 0xFFFFFFFF00000000ULL +#define BBIT_IN_BYTE_EVEX_PREFIX 0x0000001000000000ULL #define LBIT_IN_BYTE_EVEX_PREFIX 0x0000002000000000ULL #define LPRIMEBIT_IN_BYTE_EVEX_PREFIX 0x0000004000000000ULL -#define EVEX_B_BIT 0x0000001000000000ULL +#define ZBIT_IN_BYTE_EVEX_PREFIX 0x0000008000000000ULL //------------------------------------------------------------------------ // AddEvexPrefix: Add default EVEX prefix with only LL' bits set. 
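// For reference, a minimal sketch of the bit layout these masks assume (the standard EVEX payload byte P2,
// not spelled out in the change itself): with the default template 0x62F07C0800000000ULL, P2 (0x08)
// occupies bits 32-39 of the 64-bit code value, so the defines above map onto it as follows:
//   bit 39     = EVEX.z   (ZBIT_IN_BYTE_EVEX_PREFIX)     - zero-masking vs. merge-masking
//   bit 38     = EVEX.L'  (LPRIMEBIT_IN_BYTE_EVEX_PREFIX)
//   bit 37     = EVEX.L   (LBIT_IN_BYTE_EVEX_PREFIX)
//   bit 36     = EVEX.b   (BBIT_IN_BYTE_EVEX_PREFIX)      - embedded broadcast / rounding / SAE
//   bits 32-34 = EVEX.aaa                                 - opmask register index (k0-k7, 0 = unmasked)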
@@ -1344,7 +1328,7 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt if (id->idIsEvexbContextSet()) { - code |= EVEX_B_BIT; + code |= BBIT_IN_BYTE_EVEX_PREFIX; if (!id->idHasMem()) { @@ -1385,6 +1369,8 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt { case IF_RWR_RRD_ARD_RRD: { + assert(id->idGetEvexAaaContext() == 0); + CnsVal cnsVal; emitGetInsAmdCns(id, &cnsVal); @@ -1394,6 +1380,8 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt case IF_RWR_RRD_MRD_RRD: { + assert(id->idGetEvexAaaContext() == 0); + CnsVal cnsVal; emitGetInsDcmCns(id, &cnsVal); @@ -1403,6 +1391,8 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt case IF_RWR_RRD_SRD_RRD: { + assert(id->idGetEvexAaaContext() == 0); + CnsVal cnsVal; emitGetInsCns(id, &cnsVal); @@ -1412,12 +1402,24 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt case IF_RWR_RRD_RRD_RRD: { + assert(id->idGetEvexAaaContext() == 0); maskReg = id->idReg4(); break; } default: { + unsigned aaaContext = id->idGetEvexAaaContext(); + + if (aaaContext != 0) + { + maskReg = static_cast(aaaContext + KBASE); + + if (id->idIsEvexZContextSet()) + { + code |= ZBIT_IN_BYTE_EVEX_PREFIX; + } + } break; } } @@ -4170,9 +4172,8 @@ UNATIVE_OFFSET emitter::emitInsSizeAM(instrDesc* id, code_t code) } // If this is just "call reg", we're done. - if (id->idIsCallRegPtr()) + if (((ins == INS_call) || (ins == INS_tail_i_jmp)) && id->idIsCallRegPtr()) { - assert(ins == INS_call || ins == INS_tail_i_jmp); assert(dsp == 0); return size; } @@ -6822,7 +6823,9 @@ void emitter::emitIns_R_R_A( id->idIns(ins); id->idReg1(reg1); id->idReg2(reg2); + SetEvexBroadcastIfNeeded(id, instOptions); + SetEvexEmbMaskIfNeeded(id, instOptions); emitHandleMemOp(indir, id, (ins == INS_mulx) ? IF_RWR_RWR_ARD : emitInsModeFormat(ins, IF_RRD_RRD_ARD), ins); @@ -6947,7 +6950,9 @@ void emitter::emitIns_R_R_C(instruction ins, id->idReg1(reg1); id->idReg2(reg2); id->idAddr()->iiaFieldHnd = fldHnd; + SetEvexBroadcastIfNeeded(id, instOptions); + SetEvexEmbMaskIfNeeded(id, instOptions); UNATIVE_OFFSET sz = emitInsSizeCV(id, insCodeRM(ins)); id->idCodeSize(sz); @@ -6974,12 +6979,13 @@ void emitter::emitIns_R_R_R( id->idReg2(reg1); id->idReg3(reg2); - if ((instOptions & INS_OPTS_b_MASK) != INS_OPTS_NONE) + if ((instOptions & INS_OPTS_EVEX_b_MASK) != 0) { // if EVEX.b needs to be set in this path, then it should be embedded rounding. 
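// (This emitIns_R_R_R path is register-to-register, so EVEX.b here can only mean embedded rounding;
// embedded broadcast requires a memory operand and is handled by SetEvexBroadcastIfNeeded on the
// memory forms above.)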
assert(UseEvexEncoding()); id->idSetEvexbContext(instOptions); } + SetEvexEmbMaskIfNeeded(id, instOptions); UNATIVE_OFFSET sz = emitInsSizeRR(id, insCodeRM(ins)); id->idCodeSize(sz); @@ -7001,7 +7007,9 @@ void emitter::emitIns_R_R_S( id->idReg1(reg1); id->idReg2(reg2); id->idAddr()->iiaLclVar.initLclVarAddr(varx, offs); + SetEvexBroadcastIfNeeded(id, instOptions); + SetEvexEmbMaskIfNeeded(id, instOptions); #ifdef DEBUG id->idDebugOnlyInfo()->idVarRefOffs = emitVarRefOffs; @@ -10785,6 +10793,28 @@ void emitter::emitDispEmbRounding(instrDesc* id) } } +// emitDispEmbMasking: Display the tag where embedded masking is activated +// +// Arguments: +// id - The instruction descriptor +// +void emitter::emitDispEmbMasking(instrDesc* id) +{ + regNumber maskReg = static_cast(id->idGetEvexAaaContext() + KBASE); + + if (maskReg == REG_K0) + { + return; + } + + printf(" {%s}", emitRegName(maskReg)); + + if (id->idIsEvexZContextSet()) + { + printf(" {z}"); + } +} + //-------------------------------------------------------------------- // emitDispIns: Dump the given instruction to jitstdout. // @@ -11033,7 +11063,7 @@ void emitter::emitDispIns( case IF_AWR: case IF_ARW: { - if (id->idIsCallRegPtr()) + if (((ins == INS_call) || (ins == INS_tail_i_jmp)) && id->idIsCallRegPtr()) { printf("%s", emitRegName(id->idAddr()->iiaAddrMode.amBaseReg)); } @@ -11184,7 +11214,9 @@ void emitter::emitDispIns( case IF_RRW_RRD_ARD: case IF_RWR_RWR_ARD: { - printf("%s, %s, %s", emitRegName(id->idReg1(), attr), emitRegName(id->idReg2(), attr), sstr); + printf("%s", emitRegName(id->idReg1(), attr)); + emitDispEmbMasking(id); + printf(", %s, %s", emitRegName(id->idReg2(), attr), sstr); emitDispAddrMode(id); emitDispEmbBroadcastCount(id); break; @@ -11458,7 +11490,9 @@ void emitter::emitDispIns( case IF_RRW_RRD_SRD: case IF_RWR_RWR_SRD: { - printf("%s, %s, %s", emitRegName(id->idReg1(), attr), emitRegName(id->idReg2(), attr), sstr); + printf("%s", emitRegName(id->idReg1(), attr)); + emitDispEmbMasking(id); + printf(", %s, %s", emitRegName(id->idReg2(), attr), sstr); emitDispFrameRef(id->idAddr()->iiaLclVar.lvaVarNum(), id->idAddr()->iiaLclVar.lvaOffset(), id->idDebugOnlyInfo()->idVarRefOffs, asmfm); emitDispEmbBroadcastCount(id); @@ -11652,8 +11686,9 @@ void emitter::emitDispIns( reg2 = reg3; reg3 = tmp; } - printf("%s, ", emitRegName(id->idReg1(), attr)); - printf("%s, ", emitRegName(reg2, attr)); + printf("%s", emitRegName(id->idReg1(), attr)); + emitDispEmbMasking(id); + printf(", %s, ", emitRegName(reg2, attr)); printf("%s", emitRegName(reg3, attr)); emitDispEmbRounding(id); break; @@ -11964,7 +11999,9 @@ void emitter::emitDispIns( case IF_RRW_RRD_MRD: case IF_RWR_RWR_MRD: { - printf("%s, %s, %s", emitRegName(id->idReg1(), attr), emitRegName(id->idReg2(), attr), sstr); + printf("%s", emitRegName(id->idReg1(), attr)); + emitDispEmbMasking(id); + printf(", %s, %s", emitRegName(id->idReg2(), attr), sstr); offs = emitGetInsDsp(id); emitDispClsVar(id->idAddr()->iiaFieldHnd, offs, ID_INFO_DSP_RELOC); emitDispEmbBroadcastCount(id); @@ -12918,7 +12955,7 @@ BYTE* emitter::emitOutputAM(BYTE* dst, instrDesc* id, code_t code, CnsVal* addc) #else dst += emitOutputLong(dst, dsp); #endif - if (id->idIsTlsGD()) + if (!IsAvx512OrPriorInstruction(ins) && id->idIsTlsGD()) { addlDelta = -4; emitRecordRelocationWithAddlDelta((void*)(dst - sizeof(INT32)), (void*)dsp, IMAGE_REL_TLSGD, @@ -16648,7 +16685,7 @@ size_t emitter::emitOutputInstr(insGroup* ig, instrDesc* id, BYTE** dp) } #ifdef DEBUG - if (ins == INS_call && !id->idIsTlsGD()) + if ((ins 
== INS_call) && !id->idIsTlsGD()) { emitRecordCallSite(emitCurCodeOffs(*dp), id->idDebugOnlyInfo()->idCallSig, (CORINFO_METHOD_HANDLE)id->idDebugOnlyInfo()->idMemCookie); diff --git a/src/coreclr/jit/emitxarch.h b/src/coreclr/jit/emitxarch.h index 1ce01cde59b02..d842f91f06a5d 100644 --- a/src/coreclr/jit/emitxarch.h +++ b/src/coreclr/jit/emitxarch.h @@ -106,7 +106,6 @@ unsigned insSSval(unsigned scale); static bool IsSSEInstruction(instruction ins); static bool IsSSEOrAVXInstruction(instruction ins); -static bool IsAvx512OrPriorInstruction(instruction ins); static bool IsAVXOnlyInstruction(instruction ins); static bool IsAvx512OnlyInstruction(instruction ins); static bool IsFMAInstruction(instruction ins); @@ -346,14 +345,39 @@ code_t AddSimdPrefixIfNeeded(const instrDesc* id, code_t code, emitAttr size) // instOptions - emit options void SetEvexBroadcastIfNeeded(instrDesc* id, insOpts instOptions) { - if ((instOptions & INS_OPTS_b_MASK) == INS_OPTS_EVEX_eb_er_rd) + if ((instOptions & INS_OPTS_EVEX_b_MASK) == INS_OPTS_EVEX_eb_er_rd) { assert(UseEvexEncoding()); id->idSetEvexbContext(instOptions); } else { - assert(instOptions == 0); + assert((instOptions & INS_OPTS_EVEX_b_MASK) == 0); + } +} + +//------------------------------------------------------------------------ +// SetEvexEmbMaskIfNeeded: set embedded mask if needed. +// +// Arguments: +// id - instruction descriptor +// instOptions - emit options +// +void SetEvexEmbMaskIfNeeded(instrDesc* id, insOpts instOptions) +{ + if ((instOptions & INS_OPTS_EVEX_aaa_MASK) != 0) + { + assert(UseEvexEncoding()); + id->idSetEvexAaaContext(instOptions); + + if ((instOptions & INS_OPTS_EVEX_z_MASK) == INS_OPTS_EVEX_em_zero) + { + id->idSetEvexZContext(); + } + } + else + { + assert((instOptions & INS_OPTS_EVEX_z_MASK) == 0); } } diff --git a/src/coreclr/jit/gentree.cpp b/src/coreclr/jit/gentree.cpp index 31afa8983c9e1..b23eb5d0eb432 100644 --- a/src/coreclr/jit/gentree.cpp +++ b/src/coreclr/jit/gentree.cpp @@ -17920,10 +17920,14 @@ bool GenTree::canBeContained() const } // It is not possible for nodes that do not produce values or that are not containable values to be contained. 
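// (One exception is introduced below: a hardware intrinsic that is not normally containable can still be
// contained when it is embedded-masking compatible, since containment is the signal lowering and codegen
// use to fold it under a BlendVariableMask. Illustrative shape, assuming a float base type:
//   BlendVariableMask(zero, Add(x, y), mask)  ->  vaddps zmm0 {k1}{z}, zmm1, zmm2 .)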
- if (!IsValue() || ((DebugOperKind() & DBK_NOCONTAIN) != 0) || (OperIsHWIntrinsic() && !isContainableHWIntrinsic())) + if (!IsValue() || ((DebugOperKind() & DBK_NOCONTAIN) != 0)) { return false; } + else if (OperIsHWIntrinsic() && !isContainableHWIntrinsic()) + { + return isEvexEmbeddedMaskingCompatibleHWIntrinsic(); + } return true; } @@ -19655,6 +19659,28 @@ bool GenTree::isRMWHWIntrinsic(Compiler* comp) switch (intrinsicId) { + case NI_AVX512F_BlendVariableMask: + { + GenTree* op2 = hwintrinsic->Op(2); + + if (op2->IsEmbMaskOp()) + { + GenTree* op1 = hwintrinsic->Op(1); + + if (op1->isContained()) + { + assert(op1->IsVectorZero()); + return false; + } + else + { + return true; + } + } + + return false; + } + case NI_AVX512F_Fixup: case NI_AVX512F_FixupScalar: case NI_AVX512F_VL_Fixup: @@ -19745,18 +19771,34 @@ bool GenTree::isRMWHWIntrinsic(Compiler* comp) // Return Value: // true if the intrinsic node lowering instruction has an EVEX form // -bool GenTree::isEvexCompatibleHWIntrinsic() +bool GenTree::isEvexCompatibleHWIntrinsic() const { - assert(gtOper == GT_HWINTRINSIC); - -// TODO-XARCH-AVX512 remove the ReturnsPerElementMask check once K registers have been properly -// implemented in the register allocator -#if defined(TARGET_AMD64) - return HWIntrinsicInfo::HasEvexSemantics(AsHWIntrinsic()->GetHWIntrinsicId()) && + // TODO-XARCH-AVX512 remove the ReturnsPerElementMask check once K registers have been properly + // implemented in the register allocator + return OperIsHWIntrinsic() && HWIntrinsicInfo::HasEvexSemantics(AsHWIntrinsic()->GetHWIntrinsicId()) && !HWIntrinsicInfo::ReturnsPerElementMask(AsHWIntrinsic()->GetHWIntrinsicId()); -#else +} + +//------------------------------------------------------------------------ +// isEvexEmbeddedMaskingCompatibleHWIntrinsic: Checks if the intrinsic is compatible +// with the EVEX embedded masking form for its intended lowering instruction. +// +// Return Value: +// true if the intrinsic node lowering instruction has an EVEX embedded masking form +// +bool GenTree::isEvexEmbeddedMaskingCompatibleHWIntrinsic() const +{ +#if defined(TARGET_XARCH) + if (OperIsHWIntrinsic()) + { + // TODO-AVX512F-CQ: Expand this to the full set of APIs and make it table driven + // using IsEmbMaskingCompatible. For now, however, limit it to some explicit ids + // for prototyping purposes. + return (AsHWIntrinsic()->GetHWIntrinsicId() == NI_AVX512F_Add); + } +#endif // TARGET_XARCH + return false; -#endif } GenTreeHWIntrinsic* Compiler::gtNewSimdHWIntrinsicNode(var_types type, diff --git a/src/coreclr/jit/gentree.h b/src/coreclr/jit/gentree.h index 7e7883d9ef594..89e924220bdd0 100644 --- a/src/coreclr/jit/gentree.h +++ b/src/coreclr/jit/gentree.h @@ -556,6 +556,10 @@ enum GenTreeFlags : unsigned int GTF_MDARRLEN_NONFAULTING = 0x20000000, // GT_MDARR_LENGTH -- An MD array length operation that cannot fault. Same as GT_IND_NONFAULTING. GTF_MDARRLOWERBOUND_NONFAULTING = 0x20000000, // GT_MDARR_LOWER_BOUND -- An MD array lower bound operation that cannot fault. Same as GT_IND_NONFAULTING.
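// (The flag added below, GTF_HW_EM_OP, reuses an already overloaded bit value; that is safe because the
// IsEmbMaskOp/MakeEmbMaskOp accessors introduced by this change only consult it on GT_HWINTRINSIC nodes -
// IsEmbMaskOp asserts exactly that.)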
+ +#if defined(TARGET_XARCH) && defined(FEATURE_HW_INTRINSICS) + GTF_HW_EM_OP = 0x10000000, // GT_HWINTRINSIC -- node is used as an operand to an embedded mask +#endif // TARGET_XARCH && FEATURE_HW_INTRINSICS }; inline constexpr GenTreeFlags operator ~(GenTreeFlags a) @@ -1486,7 +1490,8 @@ struct GenTree bool isCommutativeHWIntrinsic() const; bool isContainableHWIntrinsic() const; bool isRMWHWIntrinsic(Compiler* comp); - bool isEvexCompatibleHWIntrinsic(); + bool isEvexCompatibleHWIntrinsic() const; + bool isEvexEmbeddedMaskingCompatibleHWIntrinsic() const; #else bool isCommutativeHWIntrinsic() const { @@ -1503,7 +1508,12 @@ struct GenTree return false; } - bool isEvexCompatibleHWIntrinsic() + bool isEvexCompatibleHWIntrinsic() const + { + return false; + } + + bool isEvexEmbeddedMaskingCompatibleHWIntrinsic() const { return false; } @@ -2232,6 +2242,22 @@ struct GenTree gtFlags &= ~GTF_ICON_HDL_MASK; } +#if defined(TARGET_XARCH) && defined(FEATURE_HW_INTRINSICS) + bool IsEmbMaskOp() + { + bool result = (gtFlags & GTF_HW_EM_OP) != 0; + assert(!result || (gtOper == GT_HWINTRINSIC)); + return result; + } + + void MakeEmbMaskOp() + { + assert(!IsEmbMaskOp()); + gtFlags |= GTF_HW_EM_OP; + } + +#endif // TARGET_XARCH && FEATURE_HW_INTRINSICS + bool IsCall() const { return OperGet() == GT_CALL; diff --git a/src/coreclr/jit/hwintrinsic.h b/src/coreclr/jit/hwintrinsic.h index dcd5c86129b74..08b1ed1f229f0 100644 --- a/src/coreclr/jit/hwintrinsic.h +++ b/src/coreclr/jit/hwintrinsic.h @@ -205,7 +205,10 @@ enum HWIntrinsicFlag : unsigned int HW_Flag_EmbBroadcastCompatible = 0x8000000, // The intrinsic is an embedded rounding compatible intrinsic - HW_Flag_EmbRoundingCompatible = 0x10000000 + HW_Flag_EmbRoundingCompatible = 0x10000000, + + // The intrinsic is an embedded masking incompatible intrinsic + HW_Flag_EmbMaskingIncompatible = 0x20000000, #endif // TARGET_XARCH }; @@ -597,6 +600,12 @@ struct HWIntrinsicInfo return (flags & HW_Flag_EmbRoundingCompatible) != 0; } + static bool IsEmbMaskingCompatible(NamedIntrinsic id) + { + HWIntrinsicFlag flags = lookupFlags(id); + return (flags & HW_Flag_EmbMaskingIncompatible) == 0; + } + static size_t EmbRoundingArgPos(NamedIntrinsic id) { // This helper function returns the expected position, diff --git a/src/coreclr/jit/hwintrinsiccodegenxarch.cpp b/src/coreclr/jit/hwintrinsiccodegenxarch.cpp index ec936c349c8f2..6e85d98fea122 100644 --- a/src/coreclr/jit/hwintrinsiccodegenxarch.cpp +++ b/src/coreclr/jit/hwintrinsiccodegenxarch.cpp @@ -108,7 +108,7 @@ static insOpts AddEmbRoundingMode(insOpts instOptions, int8_t mode) // .NET doesn't support raising IEEE 754 floating-point exceptions, // we simplify the handling below to only consider the 2-bits of RC. 
- assert((instOptions & INS_OPTS_b_MASK) == 0); + assert((instOptions & INS_OPTS_EVEX_b_MASK) == 0); unsigned result = static_cast(instOptions); switch (mode & 0x03) @@ -140,6 +140,35 @@ static insOpts AddEmbRoundingMode(insOpts instOptions, int8_t mode) return static_cast(result); } +//------------------------------------------------------------------------ +// AddEmbMaskingMode: Adds the embedded masking mode to the insOpts +// +// Arguments: +// instOptions - The existing insOpts +// maskReg - The register to use for the embedded mask +// mergeWithZero - true if the mask merges with zero; otherwise, false +// +// Return Value: +// The modified insOpts +// +static insOpts AddEmbMaskingMode(insOpts instOptions, regNumber maskReg, bool mergeWithZero) +{ + assert((instOptions & INS_OPTS_EVEX_aaa_MASK) == 0); + assert((instOptions & INS_OPTS_EVEX_z_MASK) == 0); + + unsigned result = static_cast(instOptions); + unsigned em_k = (maskReg - KBASE) << 2; + unsigned em_z = mergeWithZero ? INS_OPTS_EVEX_em_zero : 0; + + assert(emitter::isMaskReg(maskReg)); + assert((em_k & INS_OPTS_EVEX_aaa_MASK) == em_k); + + result |= em_k; + result |= em_z; + + return static_cast(result); +} + //------------------------------------------------------------------------ // genHWIntrinsic: Generates the code for a given hardware intrinsic node. // @@ -152,6 +181,7 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node) CORINFO_InstructionSet isa = HWIntrinsicInfo::lookupIsa(intrinsicId); HWIntrinsicCategory category = HWIntrinsicInfo::lookupCategory(intrinsicId); size_t numArgs = node->GetOperandCount(); + GenTree* embMaskOp = nullptr; // We need to validate that other phases of the compiler haven't introduced unsupported intrinsics assert(compiler->compIsaSupportedDebugOnly(isa)); @@ -162,6 +192,67 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node) if (GetEmitter()->UseEvexEncoding()) { + if (numArgs == 3) + { + GenTree* op2 = node->Op(2); + + if (op2->IsEmbMaskOp()) + { + assert(intrinsicId == NI_AVX512F_BlendVariableMask); + assert(op2->isContained()); + assert(op2->OperIsHWIntrinsic()); + + // We currently only support this for table driven intrinsics + assert(isTableDriven); + + GenTree* op1 = node->Op(1); + GenTree* op3 = node->Op(3); + + regNumber targetReg = node->GetRegNum(); + regNumber mergeReg = op1->GetRegNum(); + regNumber maskReg = op3->GetRegNum(); + + // TODO-AVX512-CQ: Ensure we can support embedded operations on RMW intrinsics + assert(!op2->isRMWHWIntrinsic(compiler)); + + bool mergeWithZero = op1->isContained(); + + if (mergeWithZero) + { + // We're merging with zero, so the target register isn't RMW + assert(op1->IsVectorZero()); + mergeWithZero = true; + } + else + { + // We're merging with a non-zero value, so the target register is RMW + emitAttr attr = emitActualTypeSize(Compiler::getSIMDTypeForSize(node->GetSimdSize())); + GetEmitter()->emitIns_Mov(INS_movaps, attr, targetReg, mergeReg, /* canSkip */ true); + } + + // Update op2 to use the actual target register + op2->ClearContained(); + op2->SetRegNum(targetReg); + + // Fixup all the already initialized variables + node = op2->AsHWIntrinsic(); + intrinsicId = node->GetHWIntrinsicId(); + isa = HWIntrinsicInfo::lookupIsa(intrinsicId); + category = HWIntrinsicInfo::lookupCategory(intrinsicId); + numArgs = node->GetOperandCount(); + + // Add the embedded masking info to the insOpts + instOptions = AddEmbMaskingMode(instOptions, maskReg, mergeWithZero); + + // We don't need to genProduceReg(node) since that will be handled by
processing op2 + // likewise, processing op2 will ensure its own registers are consumed + + // Make sure we consume the registers that are getting specially handled + genConsumeReg(op1); + embMaskOp = op3; + } + } + if (HWIntrinsicInfo::IsEmbRoundingCompatible(intrinsicId)) { assert(isTableDriven); @@ -562,10 +653,20 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node) break; } + if (embMaskOp != nullptr) + { + // Handle an extra operand we need to consume so that + // embedded masking can work without making the overall + // logic significantly more complex. + genConsumeReg(embMaskOp); + } + genProduceReg(node); return; } + assert(embMaskOp == nullptr); + switch (isa) { case InstructionSet_Vector128: diff --git a/src/coreclr/jit/instr.cpp b/src/coreclr/jit/instr.cpp index 80d326bbcf2a0..caee21bff8d86 100644 --- a/src/coreclr/jit/instr.cpp +++ b/src/coreclr/jit/instr.cpp @@ -1223,7 +1223,7 @@ bool CodeGenInterface::IsEmbeddedBroadcastEnabled(instruction ins, GenTree* op) // static insOpts AddEmbBroadcastMode(insOpts instOptions) { - assert((instOptions & INS_OPTS_b_MASK) == 0); + assert((instOptions & INS_OPTS_EVEX_b_MASK) == 0); unsigned result = static_cast(instOptions); return static_cast(result | INS_OPTS_EVEX_eb_er_rd); } diff --git a/src/coreclr/jit/instr.h b/src/coreclr/jit/instr.h index 0f5d8eed5eb71..700ab55fbac8c 100644 --- a/src/coreclr/jit/instr.h +++ b/src/coreclr/jit/instr.h @@ -84,6 +84,26 @@ enum instruction : uint32_t INS_count = INS_none }; +//------------------------------------------------------------------------ +// IsAvx512OrPriorInstruction: Is this an Avx512 or Avx or Sse or K (opmask) instruction. +// Technically, K instructions would be considered under the VEX encoding umbrella, but due to +// the instruction table encoding had to be pulled out with the rest of the `INST5` definitions. +// +// Arguments: +// ins - The instruction to check. +// +// Returns: +// `true` if it is a sse or avx or avx512 instruction. +// +inline bool IsAvx512OrPriorInstruction(instruction ins) +{ +#if defined(TARGET_XARCH) + return (ins >= INS_FIRST_SSE_INSTRUCTION) && (ins <= INS_LAST_AVX512_INSTRUCTION); +#else + return false; +#endif // TARGET_XARCH +} + /*****************************************************************************/ enum insUpdateModes { @@ -205,13 +225,36 @@ enum insOpts: unsigned { INS_OPTS_NONE = 0, - INS_OPTS_EVEX_eb_er_rd = 1, // Embedded Broadcast or Round down + // Two-bits: 0b0000_0011 + INS_OPTS_EVEX_b_MASK = 0x03, // mask for EVEX.b related features. + + INS_OPTS_EVEX_eb_er_rd = 1, // Embedded Broadcast or Round down + + INS_OPTS_EVEX_er_ru = 2, // Round up + + INS_OPTS_EVEX_er_rz = 3, // Round towards zero + + // Three-bits: 0b0001_1100 + INS_OPTS_EVEX_aaa_MASK = 0x1C, // mask for EVEX.aaa related features + + INS_OPTS_EVEX_em_k1 = 1 << 2, // Embedded mask uses K1 + + INS_OPTS_EVEX_em_k2 = 2 << 2, // Embedded mask uses K2 + + INS_OPTS_EVEX_em_k3 = 3 << 2, // Embedded mask uses K3 + + INS_OPTS_EVEX_em_k4 = 4 << 2, // Embedded mask uses K4 + + INS_OPTS_EVEX_em_k5 = 5 << 2, // Embedded mask uses K5 + + INS_OPTS_EVEX_em_k6 = 6 << 2, // Embedded mask uses K6 - INS_OPTS_EVEX_er_ru = 2, // Round up + INS_OPTS_EVEX_em_k7 = 7 << 2, // Embedded mask uses K7 - INS_OPTS_EVEX_er_rz = 3, // Round towards zero + // One-bit: 0b0010_0000 + INS_OPTS_EVEX_z_MASK = 0x20, // mask for EVEX.z related features - INS_OPTS_b_MASK = (INS_OPTS_EVEX_eb_er_rd | INS_OPTS_EVEX_er_ru | INS_OPTS_EVEX_er_rz), // mask for Evex.b related features.
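// Worked example (illustrative): a zero-merging operation masked by k3 is encoded by AddEmbMaskingMode in
// hwintrinsiccodegenxarch.cpp as aaa = (REG_K3 - KBASE) << 2 = 0x0C plus INS_OPTS_EVEX_em_zero = 0x20,
// i.e. instOptions = 0x2C; SetEvexEmbMaskIfNeeded then unpacks the same bits into the instrDesc via
// idSetEvexAaaContext / idSetEvexZContext when the instruction is emitted.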
+ INS_OPTS_EVEX_em_zero = 1 << 5, // Embedded mask merges with zero }; #elif defined(TARGET_ARM) || defined(TARGET_ARM64) || defined(TARGET_LOONGARCH64) || defined(TARGET_RISCV64) diff --git a/src/coreclr/jit/lowerxarch.cpp b/src/coreclr/jit/lowerxarch.cpp index ff9cd371570fa..d58854f857258 100644 --- a/src/coreclr/jit/lowerxarch.cpp +++ b/src/coreclr/jit/lowerxarch.cpp @@ -9709,8 +9709,69 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node) case NI_SSE41_BlendVariable: case NI_AVX_BlendVariable: case NI_AVX2_BlendVariable: + { + if (IsContainableHWIntrinsicOp(node, op2, &supportsRegOptional)) + { + MakeSrcContained(node, op2); + } + else if (supportsRegOptional) + { + MakeSrcRegOptional(node, op2); + } + break; + } + case NI_AVX512F_BlendVariableMask: { + // BlendVariableMask represents one of the following instructions: + // * vblendmpd + // * vblendmps + // * vpblendmb + // * vpblendmd + // * vpblendmq + // * vpblendmw + // + // In all cases, the node operands are ordered: + // * op1: selectFalse + // * op2: selectTrue + // * op3: condition + // + // The managed API surface we expose doesn't directly support TYP_MASK + // and we don't directly expose overloads for APIs like `vaddps` which + // support embedded masking. Instead, we have decided to do pattern + // recognition over the relevant ternary select APIs which functionally + // execute `cond ? selectTrue : selectFalse` on a per-element basis. + // + // To facilitate this, the mentioned ternary select APIs, such as + // ConditionalSelect or TernaryLogic, with a correct control word, will + // all compile down to BlendVariableMask when the condition is of TYP_MASK. + // + // So, before we do the normal containment checks for memory operands, we + // instead want to check if `selectTrue` (op2) supports embedded masking and + // if so, we want to mark it as contained. Codegen will then see that it is + // contained and not a memory operand and know to invoke the special handling + // so that the embedded masking can work as expected. + + if (op2->isEvexEmbeddedMaskingCompatibleHWIntrinsic()) + { + uint32_t maskSize = genTypeSize(simdBaseType); + uint32_t operSize = genTypeSize(op2->AsHWIntrinsic()->GetSimdBaseType()); + + if ((maskSize == operSize) && IsInvariantInRange(op2, node)) + { + MakeSrcContained(node, op2); + op2->MakeEmbMaskOp(); + + if (op1->IsVectorZero()) + { + // When we are merging with zero, we can specialize + // and avoid instantiating the vector constant.
+ MakeSrcContained(node, op1); + } + break; + } + } + if (IsContainableHWIntrinsicOp(node, op2, &supportsRegOptional)) { MakeSrcContained(node, op2); diff --git a/src/coreclr/jit/lsraxarch.cpp b/src/coreclr/jit/lsraxarch.cpp index 33122e79777b9..4fc56947478aa 100644 --- a/src/coreclr/jit/lsraxarch.cpp +++ b/src/coreclr/jit/lsraxarch.cpp @@ -2500,6 +2500,52 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou break; } + case NI_AVX512F_BlendVariableMask: + { + assert(numArgs == 3); + + if (op2->IsEmbMaskOp()) + { + // TODO-AVX512-CQ: Ensure we can support embedded operations on RMW intrinsics + assert(!op2->isRMWHWIntrinsic(compiler)); + + if (isRMW) + { + assert(!op1->isContained()); + + tgtPrefUse = BuildUse(op1); + srcCount += 1; + + assert(op2->isContained()); + + for (GenTree* operand : op2->AsHWIntrinsic()->Operands()) + { + assert(varTypeIsSIMD(operand)); + srcCount += BuildDelayFreeUses(operand, op1); + } + } + else + { + assert(op1->isContained() && op1->IsVectorZero()); + srcCount += BuildOperandUses(op1); + + assert(op2->isContained()); + + for (GenTree* operand : op2->AsHWIntrinsic()->Operands()) + { + assert(varTypeIsSIMD(operand)); + srcCount += BuildOperandUses(operand); + } + } + + assert(!op3->isContained()); + srcCount += BuildOperandUses(op3); + + buildUses = false; + } + break; + } + case NI_AVX512F_PermuteVar8x64x2: case NI_AVX512F_PermuteVar16x32x2: case NI_AVX512F_VL_PermuteVar2x64x2:
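// End-to-end sketch of the pattern this change enables (illustrative; for now only NI_AVX512F_Add passes
// isEvexEmbeddedMaskingCompatibleHWIntrinsic). Given a tree of the form
//
//   BlendVariableMask(op1, Add(x, y), mask)      // op1 = selectFalse, mask is TYP_MASK
//
// lowering contains the Add under the blend (and op1 as well when it is a zero vector), LSRA builds the
// mask operand in a k register, and codegen re-targets emission to the Add node with the aaa/z bits added
// to insOpts, so a single masked instruction is emitted, e.g. for a float base type:
//
//   vaddps zmm0 {k1}{z}, zmm1, zmm2     ; op1 was the zero vector -> zero-masking
//
// or, when op1 is not zero, selectFalse is first copied to the target register and the add is emitted
// with merge-masking:
//
//   vmovaps zmm0, zmm3
//   vaddps  zmm0 {k1}, zmm1, zmm2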