Skip to content

Commit

Permalink
AArch64: Implement arraycmp evaluator
Browse files Browse the repository at this point in the history
This commit implements arraycmp evaluator.
It implements two variants. One of them returns the length of the identical data
from the beginning of the arrays.
The other returns 2/0/1 when the first array is greater than/equal to/less than
the second array.
The main loop reads a 16-byte chunk from the both array in a single interation.
It uses ldp instruction to read 16-byte data into two 64-bit registers.
Then, it compares each 64-bit registers to find mismatch.
If any mismatch is found in 16-byte chunks, the bit position of
the first mismatch is searched by clz instruction for the variant returning the length.
If no mismatch is found in 16-byte chunks and there is still remaining data
to be compared (which is smaller than 16 bytes),
the secondary loop, which reads a single byte in each iteration, is executed.

Signed-off-by: Akira Saitoh <[email protected]>
  • Loading branch information
Akira Saitoh committed Mar 2, 2023
1 parent 084e87a commit 7c991a6
Show file tree
Hide file tree
Showing 2 changed files with 360 additions and 4 deletions.
8 changes: 8 additions & 0 deletions compiler/aarch64/codegen/OMRCodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,14 @@ OMR::ARM64::CodeGenerator::initialize()
// Enable compaction of local stack slots. i.e. variables with non-overlapping live ranges
// can share the same slot.
cg->setSupportsCompactedLocals();
if (!TR::Compiler->om.canGenerateArraylets())
{
static const bool disableArrayCmp = feGetEnv("TR_aarch64DisableArrayCmp") != NULL;
if (!disableArrayCmp)
{
cg->setSupportsArrayCmp();
}
}
}

void
Expand Down
356 changes: 352 additions & 4 deletions compiler/aarch64/codegen/OMRTreeEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5624,10 +5624,358 @@ OMR::ARM64::TreeEvaluator::arraysetEvaluator(TR::Node *node, TR::CodeGenerator *

TR::Register *
OMR::ARM64::TreeEvaluator::arraycmpEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
// TODO:ARM64: Enable TR::TreeEvaluator::arraycmpEvaluator in compiler/aarch64/codegen/TreeEvaluatorTable.hpp when Implemented.
return OMR::ARM64::TreeEvaluator::unImpOpEvaluator(node, cg);
}
{
/*
* Generating following instruction sequence
*
* ; arrayLen case
* mov resultReg, lengthReg
* ; non-arrayLen case
* mov resultReg, #0
*
* cmp src1Reg, src2Reg
* if !(length is constant and length > 15) {
* ccmp lengthReg, #0, #4, ne ; Sets Z flag if src1Reg and src2Reg are the same or length is 0
* }
* b.eq LDONE
* ; arrayLen case
* mov savedSrc1Reg, src1Reg
*
* if !(length is constant and length > 15) {
* cmp lengthReg, #16
* b.cc LessThan16
* }
* sub lengthReg, lengthReg, #16
* ; Main loop reads 16 bytes from each array using ldp instruction and if any mismatch found, branches to LnotEqual16
* loop16:
* ldp data1Reg, data3Reg, [src1Reg], #16
* ldp data2Reg, data4Reg, [src2Reg], #16
* ; arrayLen case
* subs data1Reg, data1Reg, data2Reg
* ; non-arrayLen case
* cmp data1Reg, data2Reg
*
* ccmp data3Reg, data4Reg, 0, eq
* b.ne Lnotequal16
* sub lengthReg, lengthReg, #16
* b.cs loop16
* if (length is constant and length > 15) {
* cmn lengthReg, #16
* ; arrayLen case
* b LDONE0
* ; non-arrayLen case
* b LDONE
*
* add src1Reg, src1Reg, lengthReg ; src1Reg points to 16bytes before the end
* add src2Reg, src2Reg, lengthReg
* mov lengthReg, #0
* b loop16
* } else {
* add lengthReg, lengthReg, #16
* ; arrayLen case
* cbz lengthReg, LDONE0
* ; non-arrayLen case
* cbz lengthReg, LDONE
*
* b lessThan16
* }
*
* Lnotequal16: ; src1Reg points 16-byte ahead of the location where data1reg was read.
* ; arrayLen case
* eor data3Reg, data3Reg, data4Reg
* cmp data1Reg, #0
* csel data1Reg, data1Reg, data3Reg, ne
* mov data2Reg, #1
* cinc data2Reg, data2Reg, ne
* sub src1Reg, src1Reg, data2Reg, lsl #3 ; Adjusts post incremented address
* rbit data1Reg, data1Reg
* clz data1Reg, data1Reg
* add src1Reg, src1Reg, data1Reg, lsr #3
*
* ; non-arrayLen case
* cmp data1Reg, data2Reg
* csel data1Reg, data1Reg, data3Reg, ne
* csel data2Reg, data2Reg, data4Reg, ne
* rev64 data1Reg, data1Reg
* rev64 data2Reg, data2Reg
*
* if !(length is constant and length > 15) {
* b LDONE0
*
* LessThan16:
* byteloop:
* subs lengthReg, lengthReg, #1
* ldrb data1Reg, [src1Reg], #1
* ldrb data2Reg, [src2Reg], #1
* ccmp data1Reg, data2Reg, #0, hi
* b.eq byteloop
* ; arrayLen case
* cmp data1Reg, data2Reg
* cset offReg, ne
* sub src1Reg, src1Reg, offReg
* }
* ; arrayLen case
* LDONE0:
* sub resultReg, src1Reg, savedSrc1Reg
* ; non-arrayLen case
* LDONE0:
* cmp data1Reg, data2Reg
* cset resultReg, ne
* cinc resultReg, resultReg, hi ; Returns 1 or 2
*
* LDONE:
*/
TR::Node *src1Node = node->getFirstChild();
TR::Node *src2Node = node->getSecondChild();
TR::Node *lengthNode = node->getThirdChild();
bool isLengthGreaterThan15 = lengthNode->getOpCode().isLoadConst() && lengthNode->getConstValue() > 15;
const bool isArrayCmpLen = node->isArrayCmpLen();
TR_ARM64ScratchRegisterManager *srm = cg->generateScratchRegisterManager(12);

TR::Register *savedSrc1Reg = cg->evaluate(src1Node);
TR::Register *src1Reg;
if ((src1Node->getReferenceCount() > 1) || isArrayCmpLen)
{
src1Reg = srm->findOrCreateScratchRegister();
generateMovInstruction(cg, node, src1Reg, savedSrc1Reg);
}
else
{
src1Reg = savedSrc1Reg;
}
TR::Register *src2Reg = cg->gprClobberEvaluate(src2Node);
TR::Register *lengthReg = cg->gprClobberEvaluate(lengthNode);
TR::Register *resultReg = cg->allocateRegister();
TR::LabelSymbol *startLabel = generateLabelSymbol(cg);
TR::LabelSymbol *doneLabel = generateLabelSymbol(cg);
TR_Debug *debugObj = cg->getDebug();

startLabel->setStartInternalControlFlow();
doneLabel->setEndInternalControlFlow();

generateLabelInstruction(cg, TR::InstOpCode::label, node, startLabel);
if (isArrayCmpLen)
{
generateMovInstruction(cg, node, resultReg, lengthReg, false);
}
else
{
loadConstant32(cg, node, 0, resultReg);
}
generateCompareInstruction(cg, node, src1Reg, src2Reg, true);
if (!isLengthGreaterThan15)
{
auto ccmpLengthInstr = generateConditionalCompareImmInstruction(cg, node, lengthReg, 0, 4, TR::CC_NE); /* 4 for Z flag */
if (debugObj)
{
debugObj->addInstructionComment(ccmpLengthInstr, "Compares lengthReg with 0 if src1 and src2 are not the same array. Otherwise, sets EQ flag.");
}
}
auto branchToDoneLabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, doneLabel, TR::CC_EQ);
if (debugObj)
{
debugObj->addInstructionComment(branchToDoneLabelInstr, "Done if src1 and src2 are the same array or length is 0.");
}

TR::LabelSymbol *notEqual16Label = generateLabelSymbol(cg);

TR::LabelSymbol *done0Label = generateLabelSymbol(cg);
TR::LabelSymbol *lessThan16Label = generateLabelSymbol(cg);
TR::Register *data1Reg = srm->findOrCreateScratchRegister();
TR::Register *data2Reg = srm->findOrCreateScratchRegister();
TR::Register *data3Reg = srm->findOrCreateScratchRegister();
TR::Register *data4Reg = srm->findOrCreateScratchRegister();
if (!isLengthGreaterThan15)
{
generateCompareImmInstruction(cg, node, lengthReg, 16);
auto branchToLessThan16LabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, lessThan16Label, TR::CC_CC);
if (debugObj)
{
debugObj->addInstructionComment(branchToLessThan16LabelInstr, "Jumps to lessThan16Label if length < 16.");
}
}
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::subimmw, node, lengthReg, lengthReg, 16);

TR::LabelSymbol *loop16Label = generateLabelSymbol(cg);
{
auto loop16LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, loop16Label);
generateTrg2MemInstruction(cg, TR::InstOpCode::ldppostx, node, data1Reg, data3Reg, TR::MemoryReference::createWithDisplacement(cg, src1Reg, 16));
generateTrg2MemInstruction(cg, TR::InstOpCode::ldppostx, node, data2Reg, data4Reg, TR::MemoryReference::createWithDisplacement(cg, src2Reg, 16));
if (isArrayCmpLen)
{
generateTrg1Src2Instruction(cg, TR::InstOpCode::subsx, node, data1Reg, data1Reg, data2Reg);
}
else
{
generateCompareInstruction(cg, node, data1Reg, data2Reg, true);
}
generateConditionalCompareInstruction(cg, node, data3Reg, data4Reg, 0, TR::CC_EQ, true);
auto branchToNotEqual16LabelInstr2 = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, notEqual16Label, TR::CC_NE);
auto subtractLengthInstr = generateTrg1Src1ImmInstruction(cg, isLengthGreaterThan15 ? TR::InstOpCode::subsimmx : TR::InstOpCode::subsimmw, node, lengthReg, lengthReg, 16);
auto branchBacktoLoop16LabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, loop16Label, TR::CC_CS);
if (debugObj)
{
debugObj->addInstructionComment(loop16LabelInstr, "loop16Label");
debugObj->addInstructionComment(branchToNotEqual16LabelInstr2, "Jumps to notEqual16Label if mismatch is found in the 16-byte data");
debugObj->addInstructionComment(branchBacktoLoop16LabelInstr, "Jumps to loop16Label if the remaining length >= 16 and no mismatch is found so far.");
if (isLengthGreaterThan15)
{
debugObj->addInstructionComment(subtractLengthInstr, "Treats length reg as a 64-bit reg as it is used as the 2nd source reg for 64-bit add later.");
}
}
}
if (isLengthGreaterThan15)
{
generateCompareImmInstruction(cg, node, lengthReg, -16, true);
auto branchToDoneLabelInstr3 = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, isArrayCmpLen ? done0Label : doneLabel, TR::CC_EQ);
auto adjustSrc1RegInstr = generateTrg1Src2Instruction(cg, TR::InstOpCode::addx, node, src1Reg, src1Reg, lengthReg);
generateTrg1Src2Instruction(cg, TR::InstOpCode::addx, node, src2Reg, src2Reg, lengthReg);
loadConstant32(cg, node, 0, lengthReg);
auto branchBacktoLoop16LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::b, node, loop16Label);
if (debugObj)
{
if (isArrayCmpLen)
{
debugObj->addInstructionComment(branchToDoneLabelInstr3, "Jumps to done0Label if the remaining length is 0.");
}
else
{
debugObj->addInstructionComment(branchToDoneLabelInstr3, "Jumps to doneLabel if the remaining length is 0.");
}
debugObj->addInstructionComment(adjustSrc1RegInstr, "Adjusts src registers so that they point to 16bytes before the end");
debugObj->addInstructionComment(branchBacktoLoop16LabelInstr, "Jumps to loop16Label.");
}
}
else
{
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::addimmw, node, lengthReg, lengthReg, 16);
auto branchToDoneLabelInstr3 = generateCompareBranchInstruction(cg, TR::InstOpCode::cbzw, node, lengthReg, isArrayCmpLen ? done0Label : doneLabel);
auto branchToLessThan16Label2 = generateLabelInstruction(cg, TR::InstOpCode::b, node, lessThan16Label);

if (debugObj)
{
if (isArrayCmpLen)
{
debugObj->addInstructionComment(branchToDoneLabelInstr3, "Jumps to done0Label if the remaining length is 0.");
}
else
{
debugObj->addInstructionComment(branchToDoneLabelInstr3, "Jumps to doneLabel if the remaining length is 0.");
}
debugObj->addInstructionComment(branchToLessThan16Label2, "Jumps to lessThan16Label");
}
}

auto notEqual16LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, notEqual16Label);
if (debugObj)
{
debugObj->addInstructionComment(notEqual16LabelInstr, "notEqual16Label. src register points 16-byte ahead of the location where the data in the registers was read.");
}

if (isArrayCmpLen)
{
generateTrg1Src2Instruction(cg, TR::InstOpCode::eorx, node, data3Reg, data3Reg, data4Reg);
generateCompareImmInstruction(cg, node, data1Reg, 0, true);
generateCondTrg1Src2Instruction(cg, TR::InstOpCode::cselx, node, data1Reg, data1Reg, data3Reg, TR::CC_NE);
loadConstant32(cg, node, 1, data2Reg);
auto getOffsetInstr = generateCIncInstruction(cg, node, data2Reg, data2Reg, TR::CC_NE, true);
auto adjustingBaseAddrInstr = generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::subx, node, src1Reg, src1Reg, data2Reg, TR::SH_LSL, 3);
generateTrg1Src1Instruction(cg, TR::InstOpCode::rbitx, node, data1Reg, data1Reg);
auto getMismatchLocationInstr = generateTrg1Src1Instruction(cg, TR::InstOpCode::clzx, node, data1Reg, data1Reg);
generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::addx, node, src1Reg, src1Reg, data1Reg, TR::SH_LSR, 3);
if (debugObj)
{
debugObj->addInstructionComment(getOffsetInstr, "register has 1 if mismatch is in the first 8-byte data. Otherwise it has 2.");
debugObj->addInstructionComment(adjustingBaseAddrInstr, "Adjusts base register so that it points to the 8-byte boundary before the mismatch.");
debugObj->addInstructionComment(getMismatchLocationInstr, "Gets the bit position of the first mismatch.");
}
}
else
{
generateCompareInstruction(cg, node, data1Reg, data2Reg, true);
generateCondTrg1Src2Instruction(cg, TR::InstOpCode::cselx, node, data1Reg, data1Reg, data3Reg, TR::CC_NE);
generateCondTrg1Src2Instruction(cg, TR::InstOpCode::cselx, node, data2Reg, data2Reg, data4Reg, TR::CC_NE);
generateTrg1Src1Instruction(cg, TR::InstOpCode::revx, node, data1Reg, data1Reg);
generateTrg1Src1Instruction(cg, TR::InstOpCode::revx, node, data2Reg, data2Reg);
}
srm->reclaimScratchRegister(data3Reg);
srm->reclaimScratchRegister(data4Reg);

if (!isLengthGreaterThan15)
{
auto branchToDone0LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::b, node, done0Label);

auto lessThan16LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, lessThan16Label);
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::subsimmw, node, lengthReg, lengthReg, 1);
generateTrg1MemInstruction(cg, TR::InstOpCode::ldrbpost, node, data1Reg, TR::MemoryReference::createWithDisplacement(cg, src1Reg, 1));
generateTrg1MemInstruction(cg, TR::InstOpCode::ldrbpost, node, data2Reg, TR::MemoryReference::createWithDisplacement(cg, src2Reg, 1));
generateConditionalCompareInstruction(cg, node, data1Reg, data2Reg, 0, TR::CC_HI);
auto branchBacktoLessThan16LabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, lessThan16Label, TR::CC_EQ);
if (debugObj)
{
debugObj->addInstructionComment(branchToDone0LabelInstr, "Jumps to done0Label.");
debugObj->addInstructionComment(lessThan16LabelInstr, "lessThan16Label");
debugObj->addInstructionComment(branchBacktoLessThan16LabelInstr, "Jumps to lessThan16Label (byteloop) if the remaining length > 0 and no mismatch is found");
}
if (isArrayCmpLen)
{
generateCompareInstruction(cg, node, data1Reg, data2Reg);

TR::Register *offReg = srm->findOrCreateScratchRegister();
generateCSetInstruction(cg, node, offReg, TR::CC_NE);
auto adjustSrc1AddrInstr = generateTrg1Src2Instruction(cg, TR::InstOpCode::subx, node, src1Reg, src1Reg, offReg);
if (debugObj)
{
debugObj->addInstructionComment(adjustSrc1AddrInstr, "Subtracts 1 from src1 if the mismatch is found in the last byte of data.");
}
}
}

if (isArrayCmpLen)
{
auto done0LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, done0Label);
generateTrg1Src2Instruction(cg, TR::InstOpCode::subx, node, resultReg, src1Reg, savedSrc1Reg);
if (debugObj)
{
debugObj->addInstructionComment(done0LabelInstr, "done0Label");
}
}
else
{
auto done0LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, done0Label); /* Result: 0, 1 or 2 */
generateCompareInstruction(cg, node, data1Reg, data2Reg, true);
generateCSetInstruction(cg, node, resultReg, TR::CC_NE);
generateCIncInstruction(cg, node, resultReg, resultReg, TR::CC_HI, false);
if (debugObj)
{
debugObj->addInstructionComment(done0LabelInstr, "done0Label");
}
}

/* savedSrc1Reg, src2Reg, lengthReg, resultReg, and registers allocated through SRM */
TR::RegisterDependencyConditions *conditions = new (cg->trHeapMemory()) TR::RegisterDependencyConditions(0, 4 + srm->numAvailableRegisters(), cg->trMemory());
conditions->addPostCondition(savedSrc1Reg, TR::RealRegister::NoReg);
conditions->addPostCondition(src2Reg, TR::RealRegister::NoReg);
conditions->addPostCondition(lengthReg, TR::RealRegister::NoReg);
conditions->addPostCondition(resultReg, TR::RealRegister::NoReg);
srm->addScratchRegistersToDependencyList(conditions);

auto doneLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, doneLabel, conditions);
if (debugObj)
{
debugObj->addInstructionComment(doneLabelInstr, "doneLabel");
}
srm->stopUsingRegisters();
cg->stopUsingRegister(src2Reg);
cg->stopUsingRegister(lengthReg);

node->setRegister(resultReg);
cg->decReferenceCount(src1Node);
cg->decReferenceCount(src2Node);
cg->decReferenceCount(lengthNode);

return resultReg;
}

static void
inlineConstantLengthForwardArrayCopy(TR::Node *node, int64_t byteLen, TR::Register *srcReg, TR::Register *dstReg, TR::CodeGenerator *cg)
Expand Down

0 comments on commit 7c991a6

Please sign in to comment.