Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the barebones support for using embedded masking with AVX512 #97675

Merged
merged 4 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 60 additions & 8 deletions src/coreclr/jit/emit.h
Original file line number Diff line number Diff line change
Expand Up @@ -768,12 +768,26 @@ class emitter
unsigned _idLargeDsp : 1; // does a large displacement follow?
unsigned _idLargeCall : 1; // large call descriptor used

unsigned _idBound : 1; // jump target / frame offset bound
#ifndef TARGET_ARMARCH
unsigned _idCallRegPtr : 1; // IL indirect calls: addr in reg
#endif
unsigned _idTlsGD : 1; // Used to store information related to TLS GD access on linux
unsigned _idNoGC : 1; // Some helpers don't get recorded in GC tables
// We have several pieces of information we need to encode but which are only applicable
// to a subset of instrDescs. To accommodate that, we define several _idCustom# bitfields
// and then some defines to make accessing them simpler
Comment on lines +771 to +773
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These bits are "expensive" and impact the maximum size of "small" constants, so I opted to repurpose these existing 3 bits that are only used for IF_LABEL, IF_METHOD, and related formats. They will never conflict with the SIMD instructions so this ends up being a nice way to fit it in, IMO.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No TP impact, so this works!


unsigned _idCustom1 : 1;
unsigned _idCustom2 : 1;
unsigned _idCustom3 : 1;

// Aliases for the shared _idCustom1.._idCustom3 bits. The same three bits are read either as
// the individual flags below or, combined, as the 3-bit EVEX.aaa value.
#define _idBound _idCustom1 /* jump target / frame offset bound */
#define _idTlsGD _idCustom2 /* Used to store information related to TLS GD access on linux */
#define _idNoGC _idCustom3 /* Some helpers don't get recorded in GC tables */
// The whole expansion is parenthesized so the macro behaves as a single operand; without the
// outer parentheses, `_idEvexAaaContext == 0` or `_idEvexAaaContext & 1` would bind the
// operator to `_idCustom1` only, because `==` and `&` bind tighter than `|`.
#define _idEvexAaaContext ((_idCustom3 << 2) | (_idCustom2 << 1) | _idCustom1) /* bits used for the EVEX.aaa context */

#if !defined(TARGET_ARMARCH)
unsigned _idCustom4 : 1;

#define _idCallRegPtr _idCustom4 /* IL indirect calls : addr in reg */
#define _idEvexZContext _idCustom4 /* bits used for the EVEX.z context */
#endif // !TARGET_ARMARCH

#if defined(TARGET_XARCH)
// EVEX.b can indicate several contexts: embedded broadcast, embedded rounding.
// For normal and embedded broadcast intrinsics, EVEX.L'L has the same semantic, vector length.
Expand Down Expand Up @@ -1578,30 +1592,36 @@ class emitter

bool idIsBound() const
{
assert(!IsAvx512OrPriorInstruction(_idIns));
return _idBound != 0;
}
// Marks this descriptor as bound by setting the _idBound flag.
void idSetIsBound()
{
// _idBound is an alias for _idCustom1, which also holds one of the EVEX.aaa bits,
// so the flag must not be written for SSE/AVX/AVX-512 instructions.
assert(!IsAvx512OrPriorInstruction(_idIns));
_idBound = 1;
}

#ifndef TARGET_ARMARCH
bool idIsCallRegPtr() const
{
assert(!IsAvx512OrPriorInstruction(_idIns));
return _idCallRegPtr != 0;
}
// Sets the _idCallRegPtr flag on this descriptor.
void idSetIsCallRegPtr()
{
// _idCallRegPtr is an alias for _idCustom4, the same bit that stores the EVEX.z
// context, so it must not be written for SSE/AVX/AVX-512 instructions.
assert(!IsAvx512OrPriorInstruction(_idIns));
_idCallRegPtr = 1;
}
#endif
#endif // !TARGET_ARMARCH

bool idIsTlsGD() const
{
assert(!IsAvx512OrPriorInstruction(_idIns));
return _idTlsGD != 0;
}
// Sets the _idTlsGD flag on this descriptor.
void idSetTlsGD()
{
// _idTlsGD is an alias for _idCustom2, which also holds one of the EVEX.aaa bits,
// so the flag must not be written for SSE/AVX/AVX-512 instructions.
assert(!IsAvx512OrPriorInstruction(_idIns));
_idTlsGD = 1;
}

Expand All @@ -1610,10 +1630,12 @@ class emitter
// code, it is not necessary to generate GC info for a call so labeled.
bool idIsNoGC() const
{
assert(!IsAvx512OrPriorInstruction(_idIns));
return _idNoGC != 0;
}
void idSetIsNoGC(bool val)
{
    // _idNoGC is an alias for _idCustom3, which also holds one of the EVEX.aaa bits,
    // so the flag must not be written for SSE/AVX/AVX-512 instructions.
    assert(!IsAvx512OrPriorInstruction(_idIns));

    // Store the boolean as the single-bit value 1 (true) or 0 (false).
    _idNoGC = val ? 1 : 0;
}

Expand All @@ -1625,7 +1647,8 @@ class emitter

void idSetEvexbContext(insOpts instOptions)
{
assert(_idEvexbContext == 0);
assert(!idIsEvexbContextSet());

if (instOptions == INS_OPTS_EVEX_eb_er_rd)
{
_idEvexbContext = 1;
Expand All @@ -1648,6 +1671,34 @@ class emitter
{
return _idEvexbContext;
}

unsigned idGetEvexAaaContext() const
{
    // The EVEX.aaa value lives in _idCustom1.._idCustom3, which other instruction kinds use
    // for the bound / TLS GD / no-GC flags, so it is only read for SSE..AVX-512 instructions.
    assert(IsAvx512OrPriorInstruction(_idIns));

    // Combined 3-bit value: _idCustom3 is the high bit and _idCustom1 the low bit.
    const unsigned aaaContext = _idEvexAaaContext;
    return aaaContext;
}

void idSetEvexAaaContext(insOpts instOptions)
{
assert(idGetEvexAaaContext() == 0);
kunalspathak marked this conversation as resolved.
Show resolved Hide resolved
unsigned value = static_cast<unsigned>((instOptions & INS_OPTS_EVEX_aaa_MASK) >> 2);

_idCustom1 = ((value >> 0) & 1);
_idCustom2 = ((value >> 1) & 1);
_idCustom3 = ((value >> 2) & 1);
}

bool idIsEvexZContextSet() const
{
    // _idEvexZContext is an alias for _idCustom4, the bit other instruction kinds use for
    // _idCallRegPtr, so it is only read for SSE..AVX-512 instructions.
    assert(IsAvx512OrPriorInstruction(_idIns));

    const bool isZContextSet = (_idEvexZContext != 0);
    return isZContextSet;
}

void idSetEvexZContext()
{
assert(!idIsEvexZContextSet());
tannergooding marked this conversation as resolved.
Show resolved Hide resolved
_idEvexZContext = 1;
}
#endif

#ifdef TARGET_ARMARCH
Expand Down Expand Up @@ -2222,6 +2273,7 @@ class emitter
void emitDispInsHex(instrDesc* id, BYTE* code, size_t sz);
void emitDispEmbBroadcastCount(instrDesc* id);
void emitDispEmbRounding(instrDesc* id);
void emitDispEmbMasking(instrDesc* id);
void emitDispIns(instrDesc* id,
bool isNew,
bool doffs,
Expand Down
97 changes: 67 additions & 30 deletions src/coreclr/jit/emitxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,6 @@ bool emitter::IsKInstruction(instruction ins)
return (flags & KInstruction) != 0;
}

//------------------------------------------------------------------------
// IsAvx512OrPriorInstruction: Checks whether an instruction is an SSE, AVX, AVX-512
//    or K (opmask) instruction.
//
//    K instructions would technically fall under the VEX encoding umbrella, but because
//    of the instruction table encoding they had to be pulled out with the rest of the
//    `INST5` definitions.
//
// Arguments:
//    ins - The instruction to check.
//
// Returns:
//    `true` if `ins` lies in the inclusive range
//    [INS_FIRST_SSE_INSTRUCTION, INS_LAST_AVX512_INSTRUCTION]; `false` otherwise.
//
bool emitter::IsAvx512OrPriorInstruction(instruction ins)
{
    // TODO-XArch-AVX512: Fix check once AVX512 instructions are added.
    if (ins < INS_FIRST_SSE_INSTRUCTION)
    {
        return false;
    }

    return ins <= INS_LAST_AVX512_INSTRUCTION;
}

bool emitter::IsAVXOnlyInstruction(instruction ins)
{
return (ins >= INS_FIRST_AVX_INSTRUCTION) && (ins <= INS_LAST_AVX_INSTRUCTION);
Expand Down Expand Up @@ -1304,9 +1287,10 @@ bool emitter::TakesEvexPrefix(const instrDesc* id) const
#define DEFAULT_BYTE_EVEX_PREFIX 0x62F07C0800000000ULL

#define DEFAULT_BYTE_EVEX_PREFIX_MASK 0xFFFFFFFF00000000ULL
#define BBIT_IN_BYTE_EVEX_PREFIX 0x0000001000000000ULL
#define LBIT_IN_BYTE_EVEX_PREFIX 0x0000002000000000ULL
#define LPRIMEBIT_IN_BYTE_EVEX_PREFIX 0x0000004000000000ULL
#define EVEX_B_BIT 0x0000001000000000ULL
#define ZBIT_IN_BYTE_EVEX_PREFIX 0x0000008000000000ULL

//------------------------------------------------------------------------
// AddEvexPrefix: Add default EVEX prefix with only LL' bits set.
Expand Down Expand Up @@ -1344,7 +1328,7 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt

if (id->idIsEvexbContextSet())
{
code |= EVEX_B_BIT;
code |= BBIT_IN_BYTE_EVEX_PREFIX;

if (!id->idHasMem())
{
Expand Down Expand Up @@ -1385,6 +1369,8 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt
{
case IF_RWR_RRD_ARD_RRD:
{
assert(id->idGetEvexAaaContext() == 0);

CnsVal cnsVal;
emitGetInsAmdCns(id, &cnsVal);

Expand All @@ -1394,6 +1380,8 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt

case IF_RWR_RRD_MRD_RRD:
{
assert(id->idGetEvexAaaContext() == 0);

CnsVal cnsVal;
emitGetInsDcmCns(id, &cnsVal);

Expand All @@ -1403,6 +1391,8 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt

case IF_RWR_RRD_SRD_RRD:
{
assert(id->idGetEvexAaaContext() == 0);

CnsVal cnsVal;
emitGetInsCns(id, &cnsVal);

Expand All @@ -1412,12 +1402,24 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt

case IF_RWR_RRD_RRD_RRD:
{
assert(id->idGetEvexAaaContext() == 0);
maskReg = id->idReg4();
break;
}

default:
{
unsigned aaaContext = id->idGetEvexAaaContext();

if (aaaContext != 0)
{
maskReg = static_cast<regNumber>(aaaContext + KBASE);

if (id->idIsEvexZContextSet())
{
code |= ZBIT_IN_BYTE_EVEX_PREFIX;
}
}
break;
}
}
Expand Down Expand Up @@ -4170,9 +4172,8 @@ UNATIVE_OFFSET emitter::emitInsSizeAM(instrDesc* id, code_t code)
}

// If this is just "call reg", we're done.
if (id->idIsCallRegPtr())
if (((ins == INS_call) || (ins == INS_tail_i_jmp)) && id->idIsCallRegPtr())
{
assert(ins == INS_call || ins == INS_tail_i_jmp);
assert(dsp == 0);
return size;
}
Expand Down Expand Up @@ -6822,7 +6823,9 @@ void emitter::emitIns_R_R_A(
id->idIns(ins);
id->idReg1(reg1);
id->idReg2(reg2);

SetEvexBroadcastIfNeeded(id, instOptions);
SetEvexEmbMaskIfNeeded(id, instOptions);
kunalspathak marked this conversation as resolved.
Show resolved Hide resolved

emitHandleMemOp(indir, id, (ins == INS_mulx) ? IF_RWR_RWR_ARD : emitInsModeFormat(ins, IF_RRD_RRD_ARD), ins);

Expand Down Expand Up @@ -6947,7 +6950,9 @@ void emitter::emitIns_R_R_C(instruction ins,
id->idReg1(reg1);
id->idReg2(reg2);
id->idAddr()->iiaFieldHnd = fldHnd;

SetEvexBroadcastIfNeeded(id, instOptions);
SetEvexEmbMaskIfNeeded(id, instOptions);

UNATIVE_OFFSET sz = emitInsSizeCV(id, insCodeRM(ins));
id->idCodeSize(sz);
Expand All @@ -6974,12 +6979,13 @@ void emitter::emitIns_R_R_R(
id->idReg2(reg1);
id->idReg3(reg2);

if ((instOptions & INS_OPTS_b_MASK) != INS_OPTS_NONE)
if ((instOptions & INS_OPTS_EVEX_b_MASK) != 0)
{
// if EVEX.b needs to be set in this path, then it should be embedded rounding.
assert(UseEvexEncoding());
id->idSetEvexbContext(instOptions);
}
SetEvexEmbMaskIfNeeded(id, instOptions);

UNATIVE_OFFSET sz = emitInsSizeRR(id, insCodeRM(ins));
id->idCodeSize(sz);
Expand All @@ -7001,7 +7007,9 @@ void emitter::emitIns_R_R_S(
id->idReg1(reg1);
id->idReg2(reg2);
id->idAddr()->iiaLclVar.initLclVarAddr(varx, offs);

SetEvexBroadcastIfNeeded(id, instOptions);
SetEvexEmbMaskIfNeeded(id, instOptions);

#ifdef DEBUG
id->idDebugOnlyInfo()->idVarRefOffs = emitVarRefOffs;
Expand Down Expand Up @@ -10785,6 +10793,28 @@ void emitter::emitDispEmbRounding(instrDesc* id)
}
}

// emitDispEmbMasking: Display the tag where embedded masking is activated
//
// Arguments:
// id - The instruction descriptor
//
void emitter::emitDispEmbMasking(instrDesc* id)
{
regNumber maskReg = static_cast<regNumber>(id->idGetEvexAaaContext() + KBASE);

if (maskReg == REG_K0)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

K0 is special and basically means "don't mask"

{
return;
}

printf(" {%s}", emitRegName(maskReg));

if (id->idIsEvexZContextSet())
{
printf(" {z}");
}
}

//--------------------------------------------------------------------
// emitDispIns: Dump the given instruction to jitstdout.
//
Expand Down Expand Up @@ -11033,7 +11063,7 @@ void emitter::emitDispIns(
case IF_AWR:
case IF_ARW:
{
if (id->idIsCallRegPtr())
if (((ins == INS_call) || (ins == INS_tail_i_jmp)) && id->idIsCallRegPtr())
{
printf("%s", emitRegName(id->idAddr()->iiaAddrMode.amBaseReg));
}
Expand Down Expand Up @@ -11184,7 +11214,9 @@ void emitter::emitDispIns(
case IF_RRW_RRD_ARD:
case IF_RWR_RWR_ARD:
{
printf("%s, %s, %s", emitRegName(id->idReg1(), attr), emitRegName(id->idReg2(), attr), sstr);
printf("%s", emitRegName(id->idReg1(), attr));
emitDispEmbMasking(id);
printf(", %s, %s", emitRegName(id->idReg2(), attr), sstr);
emitDispAddrMode(id);
emitDispEmbBroadcastCount(id);
break;
Expand Down Expand Up @@ -11458,7 +11490,9 @@ void emitter::emitDispIns(
case IF_RRW_RRD_SRD:
case IF_RWR_RWR_SRD:
{
printf("%s, %s, %s", emitRegName(id->idReg1(), attr), emitRegName(id->idReg2(), attr), sstr);
printf("%s", emitRegName(id->idReg1(), attr));
emitDispEmbMasking(id);
printf(", %s, %s", emitRegName(id->idReg2(), attr), sstr);
emitDispFrameRef(id->idAddr()->iiaLclVar.lvaVarNum(), id->idAddr()->iiaLclVar.lvaOffset(),
id->idDebugOnlyInfo()->idVarRefOffs, asmfm);
emitDispEmbBroadcastCount(id);
Expand Down Expand Up @@ -11652,8 +11686,9 @@ void emitter::emitDispIns(
reg2 = reg3;
reg3 = tmp;
}
printf("%s, ", emitRegName(id->idReg1(), attr));
printf("%s, ", emitRegName(reg2, attr));
printf("%s", emitRegName(id->idReg1(), attr));
emitDispEmbMasking(id);
printf(", %s, ", emitRegName(reg2, attr));
printf("%s", emitRegName(reg3, attr));
emitDispEmbRounding(id);
break;
Expand Down Expand Up @@ -11964,7 +11999,9 @@ void emitter::emitDispIns(
case IF_RRW_RRD_MRD:
case IF_RWR_RWR_MRD:
{
printf("%s, %s, %s", emitRegName(id->idReg1(), attr), emitRegName(id->idReg2(), attr), sstr);
printf("%s", emitRegName(id->idReg1(), attr));
emitDispEmbMasking(id);
printf(", %s, %s", emitRegName(id->idReg2(), attr), sstr);
offs = emitGetInsDsp(id);
emitDispClsVar(id->idAddr()->iiaFieldHnd, offs, ID_INFO_DSP_RELOC);
emitDispEmbBroadcastCount(id);
Expand Down Expand Up @@ -12918,7 +12955,7 @@ BYTE* emitter::emitOutputAM(BYTE* dst, instrDesc* id, code_t code, CnsVal* addc)
#else
dst += emitOutputLong(dst, dsp);
#endif
if (id->idIsTlsGD())
if (!IsAvx512OrPriorInstruction(ins) && id->idIsTlsGD())
{
addlDelta = -4;
emitRecordRelocationWithAddlDelta((void*)(dst - sizeof(INT32)), (void*)dsp, IMAGE_REL_TLSGD,
Expand Down Expand Up @@ -16648,7 +16685,7 @@ size_t emitter::emitOutputInstr(insGroup* ig, instrDesc* id, BYTE** dp)
}

#ifdef DEBUG
if (ins == INS_call && !id->idIsTlsGD())
if ((ins == INS_call) && !id->idIsTlsGD())
{
emitRecordCallSite(emitCurCodeOffs(*dp), id->idDebugOnlyInfo()->idCallSig,
(CORINFO_METHOD_HANDLE)id->idDebugOnlyInfo()->idMemCookie);
Expand Down
Loading
Loading