Skip to content

Commit

Permalink
Merge pull request #6904 from Akira1Saitoh/aarch64ArraycmpLDP
Browse files Browse the repository at this point in the history
AArch64: Implement arraycmp evaluator
  • Loading branch information
knn-k authored Mar 2, 2023
2 parents 53b3b1b + 7c991a6 commit 2952aea
Show file tree
Hide file tree
Showing 2 changed files with 360 additions and 4 deletions.
8 changes: 8 additions & 0 deletions compiler/aarch64/codegen/OMRCodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,14 @@ OMR::ARM64::CodeGenerator::initialize()
// Enable compaction of local stack slots. i.e. variables with non-overlapping live ranges
// can share the same slot.
cg->setSupportsCompactedLocals();
if (!TR::Compiler->om.canGenerateArraylets())
{
static const bool disableArrayCmp = feGetEnv("TR_aarch64DisableArrayCmp") != NULL;
if (!disableArrayCmp)
{
cg->setSupportsArrayCmp();
}
}
}

void
Expand Down
356 changes: 352 additions & 4 deletions compiler/aarch64/codegen/OMRTreeEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5624,10 +5624,358 @@ OMR::ARM64::TreeEvaluator::arraysetEvaluator(TR::Node *node, TR::CodeGenerator *

/**
 * Inline evaluator for arraycmp on AArch64.
 *
 * Children of node: first and second are the two source addresses, third is
 * the length. The length is consumed in bytes: 16 per ldp pair in the main
 * loop and 1 per ldrb in the tail loop.
 *
 * Result register:
 *  - node->isArrayCmpLen(): the byte offset of the first mismatch, computed as
 *    (advanced src1 pointer - original src1 address). On the early exit
 *    (same address, or length 0) the result is the length itself.
 *  - otherwise: 0 when no mismatch is found; else (data1 != data2) plus one
 *    more if data1 is unsigned-higher than data2, i.e. 1 or 2.
 *
 * When the length is a constant greater than 15 the short-length checks and
 * the byte loop are not generated; the residue is handled by re-running the
 * 16-byte loop once on the last 16 bytes of both arrays.
 *
 * All three children have their reference counts decremented before return.
 */
TR::Register *
OMR::ARM64::TreeEvaluator::arraycmpEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
// NOTE(review): the next three lines appear to be the previous stub body that the diff view
// shows as removed by this commit (4 deletions); the new body starts at the following brace.
// TODO:ARM64: Enable TR::TreeEvaluator::arraycmpEvaluator in compiler/aarch64/codegen/TreeEvaluatorTable.hpp when Implemented.
return OMR::ARM64::TreeEvaluator::unImpOpEvaluator(node, cg);
}
{
/*
* Generating following instruction sequence
*
* ; arrayLen case
* mov resultReg, lengthReg
* ; non-arrayLen case
* mov resultReg, #0
*
* cmp src1Reg, src2Reg
* if !(length is constant and length > 15) {
* ccmp lengthReg, #0, #4, ne ; Sets Z flag if src1Reg and src2Reg are the same or length is 0
* }
* b.eq LDONE
* ; arrayLen case
* mov savedSrc1Reg, src1Reg
*
* if !(length is constant and length > 15) {
* cmp lengthReg, #16
* b.cc LessThan16
* }
* sub lengthReg, lengthReg, #16
* ; Main loop reads 16 bytes from each array using ldp instruction and if any mismatch found, branches to LnotEqual16
* loop16:
* ldp data1Reg, data3Reg, [src1Reg], #16
* ldp data2Reg, data4Reg, [src2Reg], #16
* ; arrayLen case
* subs data1Reg, data1Reg, data2Reg
* ; non-arrayLen case
* cmp data1Reg, data2Reg
*
* ccmp data3Reg, data4Reg, 0, eq
* b.ne Lnotequal16
* subs lengthReg, lengthReg, #16
* b.cs loop16
* if (length is constant and length > 15) {
* cmn lengthReg, #16
* ; arrayLen case
* b.eq LDONE0
* ; non-arrayLen case
* b.eq LDONE
*
* add src1Reg, src1Reg, lengthReg ; src1Reg points to 16bytes before the end
* add src2Reg, src2Reg, lengthReg
* mov lengthReg, #0
* b loop16
* } else {
* add lengthReg, lengthReg, #16
* ; arrayLen case
* cbz lengthReg, LDONE0
* ; non-arrayLen case
* cbz lengthReg, LDONE
*
* b lessThan16
* }
*
* Lnotequal16: ; src1Reg points 16-byte ahead of the location where data1reg was read.
* ; arrayLen case
* eor data3Reg, data3Reg, data4Reg
* cmp data1Reg, #0
* csel data1Reg, data1Reg, data3Reg, ne
* mov data2Reg, #1
* cinc data2Reg, data2Reg, ne
* sub src1Reg, src1Reg, data2Reg, lsl #3 ; Adjusts post incremented address
* rbit data1Reg, data1Reg
* clz data1Reg, data1Reg
* add src1Reg, src1Reg, data1Reg, lsr #3
*
* ; non-arrayLen case
* cmp data1Reg, data2Reg
* csel data1Reg, data1Reg, data3Reg, ne
* csel data2Reg, data2Reg, data4Reg, ne
* rev64 data1Reg, data1Reg
* rev64 data2Reg, data2Reg
*
* if !(length is constant and length > 15) {
* b LDONE0
*
* LessThan16:
* byteloop:
* subs lengthReg, lengthReg, #1
* ldrb data1Reg, [src1Reg], #1
* ldrb data2Reg, [src2Reg], #1
* ccmp data1Reg, data2Reg, #0, hi
* b.eq byteloop
* ; arrayLen case
* cmp data1Reg, data2Reg
* cset offReg, ne
* sub src1Reg, src1Reg, offReg
* }
* ; arrayLen case
* LDONE0:
* sub resultReg, src1Reg, savedSrc1Reg
* ; non-arrayLen case
* LDONE0:
* cmp data1Reg, data2Reg
* cset resultReg, ne
* cinc resultReg, resultReg, hi ; Returns 1 or 2
*
* LDONE:
*/
TR::Node *src1Node = node->getFirstChild();
TR::Node *src2Node = node->getSecondChild();
TR::Node *lengthNode = node->getThirdChild();
// A constant length of at least 16 guarantees one full 16-byte iteration, so the
// "length < 16" checks and the byte loop below are not generated in that case.
bool isLengthGreaterThan15 = lengthNode->getOpCode().isLoadConst() && lengthNode->getConstValue() > 15;
const bool isArrayCmpLen = node->isArrayCmpLen();
TR_ARM64ScratchRegisterManager *srm = cg->generateScratchRegisterManager(12);

TR::Register *savedSrc1Reg = cg->evaluate(src1Node);
TR::Register *src1Reg;
// src1Reg is post-incremented by the loads. Work on a copy when the node's register is
// referenced elsewhere, or when the original address is needed to compute the mismatch
// offset (arrayLen case, see done0Label).
if ((src1Node->getReferenceCount() > 1) || isArrayCmpLen)
{
src1Reg = srm->findOrCreateScratchRegister();
generateMovInstruction(cg, node, src1Reg, savedSrc1Reg);
}
else
{
src1Reg = savedSrc1Reg;
}
TR::Register *src2Reg = cg->gprClobberEvaluate(src2Node);
TR::Register *lengthReg = cg->gprClobberEvaluate(lengthNode);
TR::Register *resultReg = cg->allocateRegister();
TR::LabelSymbol *startLabel = generateLabelSymbol(cg);
TR::LabelSymbol *doneLabel = generateLabelSymbol(cg);
TR_Debug *debugObj = cg->getDebug();

startLabel->setStartInternalControlFlow();
doneLabel->setEndInternalControlFlow();

generateLabelInstruction(cg, TR::InstOpCode::label, node, startLabel);
// Preset the result for the early exit to doneLabel: the length (arrayLen case) or 0.
if (isArrayCmpLen)
{
generateMovInstruction(cg, node, resultReg, lengthReg, false);
}
else
{
loadConstant32(cg, node, 0, resultReg);
}
generateCompareInstruction(cg, node, src1Reg, src2Reg, true);
if (!isLengthGreaterThan15)
{
auto ccmpLengthInstr = generateConditionalCompareImmInstruction(cg, node, lengthReg, 0, 4, TR::CC_NE); /* 4 for Z flag */
if (debugObj)
{
debugObj->addInstructionComment(ccmpLengthInstr, "Compares lengthReg with 0 if src1 and src2 are not the same array. Otherwise, sets EQ flag.");
}
}
auto branchToDoneLabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, doneLabel, TR::CC_EQ);
if (debugObj)
{
debugObj->addInstructionComment(branchToDoneLabelInstr, "Done if src1 and src2 are the same array or length is 0.");
}

TR::LabelSymbol *notEqual16Label = generateLabelSymbol(cg);

TR::LabelSymbol *done0Label = generateLabelSymbol(cg);
// lessThan16Label is only placed in the instruction stream when !isLengthGreaterThan15.
TR::LabelSymbol *lessThan16Label = generateLabelSymbol(cg);
TR::Register *data1Reg = srm->findOrCreateScratchRegister();
TR::Register *data2Reg = srm->findOrCreateScratchRegister();
TR::Register *data3Reg = srm->findOrCreateScratchRegister();
TR::Register *data4Reg = srm->findOrCreateScratchRegister();
if (!isLengthGreaterThan15)
{
generateCompareImmInstruction(cg, node, lengthReg, 16);
auto branchToLessThan16LabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, lessThan16Label, TR::CC_CC);
if (debugObj)
{
debugObj->addInstructionComment(branchToLessThan16LabelInstr, "Jumps to lessThan16Label if length < 16.");
}
}
// Pre-bias the length by 16 so that the flag-setting subtract at the bottom of the loop
// keeps the carry set (b.cs taken) while at least 16 more bytes remain.
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::subimmw, node, lengthReg, lengthReg, 16);

TR::LabelSymbol *loop16Label = generateLabelSymbol(cg);
{
auto loop16LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, loop16Label);
generateTrg2MemInstruction(cg, TR::InstOpCode::ldppostx, node, data1Reg, data3Reg, TR::MemoryReference::createWithDisplacement(cg, src1Reg, 16));
generateTrg2MemInstruction(cg, TR::InstOpCode::ldppostx, node, data2Reg, data4Reg, TR::MemoryReference::createWithDisplacement(cg, src2Reg, 16));
if (isArrayCmpLen)
{
// Keeps the difference of the first 8-byte words in data1Reg; notEqual16Label uses it
// to locate the first differing bit.
generateTrg1Src2Instruction(cg, TR::InstOpCode::subsx, node, data1Reg, data1Reg, data2Reg);
}
else
{
generateCompareInstruction(cg, node, data1Reg, data2Reg, true);
}
// Second words are compared only if the first words were equal; otherwise flags are set to 0 (NE).
generateConditionalCompareInstruction(cg, node, data3Reg, data4Reg, 0, TR::CC_EQ, true);
auto branchToNotEqual16LabelInstr2 = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, notEqual16Label, TR::CC_NE);
auto subtractLengthInstr = generateTrg1Src1ImmInstruction(cg, isLengthGreaterThan15 ? TR::InstOpCode::subsimmx : TR::InstOpCode::subsimmw, node, lengthReg, lengthReg, 16);
auto branchBacktoLoop16LabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, loop16Label, TR::CC_CS);
if (debugObj)
{
debugObj->addInstructionComment(loop16LabelInstr, "loop16Label");
debugObj->addInstructionComment(branchToNotEqual16LabelInstr2, "Jumps to notEqual16Label if mismatch is found in the 16-byte data");
debugObj->addInstructionComment(branchBacktoLoop16LabelInstr, "Jumps to loop16Label if the remaining length >= 16 and no mismatch is found so far.");
if (isLengthGreaterThan15)
{
debugObj->addInstructionComment(subtractLengthInstr, "Treats length reg as a 64-bit reg as it is used as the 2nd source reg for 64-bit add later.");
}
}
}
if (isLengthGreaterThan15)
{
// Here lengthReg holds (remaining - 16), a negative value. Comparing against -16 (cmn #16)
// yields EQ exactly when nothing remains.
generateCompareImmInstruction(cg, node, lengthReg, -16, true);
auto branchToDoneLabelInstr3 = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, isArrayCmpLen ? done0Label : doneLabel, TR::CC_EQ);
// Otherwise move both pointers back by the (negative) lengthReg so they point 16 bytes
// before the end, clear the length, and run the 16-byte loop one last time.
auto adjustSrc1RegInstr = generateTrg1Src2Instruction(cg, TR::InstOpCode::addx, node, src1Reg, src1Reg, lengthReg);
generateTrg1Src2Instruction(cg, TR::InstOpCode::addx, node, src2Reg, src2Reg, lengthReg);
loadConstant32(cg, node, 0, lengthReg);
auto branchBacktoLoop16LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::b, node, loop16Label);
if (debugObj)
{
if (isArrayCmpLen)
{
debugObj->addInstructionComment(branchToDoneLabelInstr3, "Jumps to done0Label if the remaining length is 0.");
}
else
{
debugObj->addInstructionComment(branchToDoneLabelInstr3, "Jumps to doneLabel if the remaining length is 0.");
}
debugObj->addInstructionComment(adjustSrc1RegInstr, "Adjusts src registers so that they point to 16bytes before the end");
debugObj->addInstructionComment(branchBacktoLoop16LabelInstr, "Jumps to loop16Label.");
}
}
else
{
// Undo the bias to recover the true remaining length (0..15) for the byte loop.
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::addimmw, node, lengthReg, lengthReg, 16);
auto branchToDoneLabelInstr3 = generateCompareBranchInstruction(cg, TR::InstOpCode::cbzw, node, lengthReg, isArrayCmpLen ? done0Label : doneLabel);
auto branchToLessThan16Label2 = generateLabelInstruction(cg, TR::InstOpCode::b, node, lessThan16Label);

if (debugObj)
{
if (isArrayCmpLen)
{
debugObj->addInstructionComment(branchToDoneLabelInstr3, "Jumps to done0Label if the remaining length is 0.");
}
else
{
debugObj->addInstructionComment(branchToDoneLabelInstr3, "Jumps to doneLabel if the remaining length is 0.");
}
debugObj->addInstructionComment(branchToLessThan16Label2, "Jumps to lessThan16Label");
}
}

auto notEqual16LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, notEqual16Label);
if (debugObj)
{
debugObj->addInstructionComment(notEqual16LabelInstr, "notEqual16Label. src register points 16-byte ahead of the location where the data in the registers was read.");
}

if (isArrayCmpLen)
{
// data1Reg already holds (word1 of src1 - word1 of src2). If it is non-zero the mismatch
// is in the first word; otherwise use the XOR of the second words.
generateTrg1Src2Instruction(cg, TR::InstOpCode::eorx, node, data3Reg, data3Reg, data4Reg);
generateCompareImmInstruction(cg, node, data1Reg, 0, true);
generateCondTrg1Src2Instruction(cg, TR::InstOpCode::cselx, node, data1Reg, data1Reg, data3Reg, TR::CC_NE);
// data2Reg = number of 8-byte words to step back from the post-incremented src1Reg:
// 2 when the mismatch is in the first word (NE), 1 when it is in the second.
loadConstant32(cg, node, 1, data2Reg);
auto getOffsetInstr = generateCIncInstruction(cg, node, data2Reg, data2Reg, TR::CC_NE, true);
auto adjustingBaseAddrInstr = generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::subx, node, src1Reg, src1Reg, data2Reg, TR::SH_LSL, 3);
// rbit + clz counts trailing zero bits; shifting right by 3 turns the bit index into a byte index.
generateTrg1Src1Instruction(cg, TR::InstOpCode::rbitx, node, data1Reg, data1Reg);
auto getMismatchLocationInstr = generateTrg1Src1Instruction(cg, TR::InstOpCode::clzx, node, data1Reg, data1Reg);
generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::addx, node, src1Reg, src1Reg, data1Reg, TR::SH_LSR, 3);
if (debugObj)
{
debugObj->addInstructionComment(getOffsetInstr, "register has 2 if mismatch is in the first 8-byte data. Otherwise it has 1.");
debugObj->addInstructionComment(adjustingBaseAddrInstr, "Adjusts base register so that it points to the 8-byte boundary before the mismatch.");
debugObj->addInstructionComment(getMismatchLocationInstr, "Gets the bit position of the first mismatch.");
}
}
else
{
// Select the mismatching pair of words into data1Reg/data2Reg, then byte-reverse both.
// NOTE(review): the reversal presumably makes the unsigned compare at done0Label follow
// memory byte order on a little-endian target - confirm.
generateCompareInstruction(cg, node, data1Reg, data2Reg, true);
generateCondTrg1Src2Instruction(cg, TR::InstOpCode::cselx, node, data1Reg, data1Reg, data3Reg, TR::CC_NE);
generateCondTrg1Src2Instruction(cg, TR::InstOpCode::cselx, node, data2Reg, data2Reg, data4Reg, TR::CC_NE);
generateTrg1Src1Instruction(cg, TR::InstOpCode::revx, node, data1Reg, data1Reg);
generateTrg1Src1Instruction(cg, TR::InstOpCode::revx, node, data2Reg, data2Reg);
}
// data3Reg/data4Reg are not used past this point.
srm->reclaimScratchRegister(data3Reg);
srm->reclaimScratchRegister(data4Reg);

if (!isLengthGreaterThan15)
{
auto branchToDone0LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::b, node, done0Label);

// Byte loop: continues while bytes remain after this one (HI after subs) and the two bytes are equal.
// When the count reaches zero the ccmp forces flags to 0 (NE), ending the loop.
auto lessThan16LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, lessThan16Label);
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::subsimmw, node, lengthReg, lengthReg, 1);
generateTrg1MemInstruction(cg, TR::InstOpCode::ldrbpost, node, data1Reg, TR::MemoryReference::createWithDisplacement(cg, src1Reg, 1));
generateTrg1MemInstruction(cg, TR::InstOpCode::ldrbpost, node, data2Reg, TR::MemoryReference::createWithDisplacement(cg, src2Reg, 1));
generateConditionalCompareInstruction(cg, node, data1Reg, data2Reg, 0, TR::CC_HI);
auto branchBacktoLessThan16LabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, lessThan16Label, TR::CC_EQ);
if (debugObj)
{
debugObj->addInstructionComment(branchToDone0LabelInstr, "Jumps to done0Label.");
debugObj->addInstructionComment(lessThan16LabelInstr, "lessThan16Label");
debugObj->addInstructionComment(branchBacktoLessThan16LabelInstr, "Jumps to lessThan16Label (byteloop) if the remaining length > 0 and no mismatch is found");
}
if (isArrayCmpLen)
{
// src1Reg was post-incremented past the last byte read; step back by one if that byte mismatched
// so that (src1Reg - savedSrc1Reg) is the mismatch offset.
generateCompareInstruction(cg, node, data1Reg, data2Reg);

TR::Register *offReg = srm->findOrCreateScratchRegister();
generateCSetInstruction(cg, node, offReg, TR::CC_NE);
auto adjustSrc1AddrInstr = generateTrg1Src2Instruction(cg, TR::InstOpCode::subx, node, src1Reg, src1Reg, offReg);
if (debugObj)
{
debugObj->addInstructionComment(adjustSrc1AddrInstr, "Subtracts 1 from src1 if the mismatch is found in the last byte of data.");
}
}
}

if (isArrayCmpLen)
{
// Result is the distance the src1 pointer advanced from its original address.
auto done0LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, done0Label);
generateTrg1Src2Instruction(cg, TR::InstOpCode::subx, node, resultReg, src1Reg, savedSrc1Reg);
if (debugObj)
{
debugObj->addInstructionComment(done0LabelInstr, "done0Label");
}
}
else
{
// resultReg = (data1 != data2) + (data1 unsigned-higher than data2).
auto done0LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, done0Label); /* Result: 0, 1 or 2 */
generateCompareInstruction(cg, node, data1Reg, data2Reg, true);
generateCSetInstruction(cg, node, resultReg, TR::CC_NE);
generateCIncInstruction(cg, node, resultReg, resultReg, TR::CC_HI, false);
if (debugObj)
{
debugObj->addInstructionComment(done0LabelInstr, "done0Label");
}
}

/* savedSrc1Reg, src2Reg, lengthReg, resultReg, and registers allocated through SRM */
TR::RegisterDependencyConditions *conditions = new (cg->trHeapMemory()) TR::RegisterDependencyConditions(0, 4 + srm->numAvailableRegisters(), cg->trMemory());
conditions->addPostCondition(savedSrc1Reg, TR::RealRegister::NoReg);
conditions->addPostCondition(src2Reg, TR::RealRegister::NoReg);
conditions->addPostCondition(lengthReg, TR::RealRegister::NoReg);
conditions->addPostCondition(resultReg, TR::RealRegister::NoReg);
srm->addScratchRegistersToDependencyList(conditions);

// The dependency conditions are attached to doneLabel, which ends the internal control flow region.
auto doneLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, doneLabel, conditions);
if (debugObj)
{
debugObj->addInstructionComment(doneLabelInstr, "doneLabel");
}
srm->stopUsingRegisters();
cg->stopUsingRegister(src2Reg);
cg->stopUsingRegister(lengthReg);

node->setRegister(resultReg);
cg->decReferenceCount(src1Node);
cg->decReferenceCount(src2Node);
cg->decReferenceCount(lengthNode);

return resultReg;
}

static void
inlineConstantLengthForwardArrayCopy(TR::Node *node, int64_t byteLen, TR::Register *srcReg, TR::Register *dstReg, TR::CodeGenerator *cg)
Expand Down

0 comments on commit 2952aea

Please sign in to comment.