diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index e8d1ac1d3a9167..e3b76b95eb86ad 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -3419,13 +3419,18 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts, Known = KnownBits::mulhs(Known, Known2); break; } - case ISD::AVGCEILU: { + case ISD::AVGFLOORU: + case ISD::AVGCEILU: + case ISD::AVGFLOORS: + case ISD::AVGCEILS: { + bool IsCeil = Opcode == ISD::AVGCEILU || Opcode == ISD::AVGCEILS; + bool IsSigned = Opcode == ISD::AVGFLOORS || Opcode == ISD::AVGCEILS; Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1); Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1); - Known = Known.zext(BitWidth + 1); - Known2 = Known2.zext(BitWidth + 1); - KnownBits One = KnownBits::makeConstant(APInt(1, 1)); - Known = KnownBits::computeForAddCarry(Known, Known2, One); + Known = IsSigned ? Known.sext(BitWidth + 1) : Known.zext(BitWidth + 1); + Known2 = IsSigned ? Known2.sext(BitWidth + 1) : Known2.zext(BitWidth + 1); + KnownBits Carry = KnownBits::makeConstant(APInt(1, IsCeil ? 1 : 0)); + Known = KnownBits::computeForAddCarry(Known, Known2, Carry); Known = Known.extractBits(BitWidth, 1); break; } diff --git a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp index e0772684e3a954..27bcad7c24c4db 100644 --- a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp +++ b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp @@ -796,4 +796,52 @@ TEST_F(AArch64SelectionDAGTest, computeKnownBits_extload_knownnegative) { EXPECT_EQ(Known.One, APInt(32, 0xfffffff0)); } +TEST_F(AArch64SelectionDAGTest, + computeKnownBits_AVGFLOORU_AVGFLOORS_AVGCEILU_AVGCEILS) { + SDLoc Loc; + auto Int8VT = EVT::getIntegerVT(Context, 8); + auto Int16VT = EVT::getIntegerVT(Context, 16); + auto Int8Vec8VT = EVT::getVectorVT(Context, Int8VT, 8); + auto Int16Vec8VT = EVT::getVectorVT(Context, Int16VT, 8); + + SDValue UnknownOp0 = DAG->getRegister(0, Int8Vec8VT); + SDValue UnknownOp1 = DAG->getRegister(1, Int8Vec8VT); + + SDValue ZextOp0 = + DAG->getNode(ISD::ZERO_EXTEND, Loc, Int16Vec8VT, UnknownOp0); + SDValue ZextOp1 = + DAG->getNode(ISD::ZERO_EXTEND, Loc, Int16Vec8VT, UnknownOp1); + // ZextOp0 = 00000000???????? + // ZextOp1 = 00000000???????? + // => (for all AVG* instructions) + // Known.Zero = 1111111100000000 (0xFF00) + // Known.One = 0000000000000000 (0x0000) + auto Zeroes = APInt(16, 0xFF00); + auto Ones = APInt(16, 0x0000); + + SDValue AVGFLOORU = + DAG->getNode(ISD::AVGFLOORU, Loc, Int16Vec8VT, ZextOp0, ZextOp1); + KnownBits KnownAVGFLOORU = DAG->computeKnownBits(AVGFLOORU); + EXPECT_EQ(KnownAVGFLOORU.Zero, Zeroes); + EXPECT_EQ(KnownAVGFLOORU.One, Ones); + + SDValue AVGFLOORS = + DAG->getNode(ISD::AVGFLOORU, Loc, Int16Vec8VT, ZextOp0, ZextOp1); + KnownBits KnownAVGFLOORS = DAG->computeKnownBits(AVGFLOORS); + EXPECT_EQ(KnownAVGFLOORS.Zero, Zeroes); + EXPECT_EQ(KnownAVGFLOORS.One, Ones); + + SDValue AVGCEILU = + DAG->getNode(ISD::AVGCEILU, Loc, Int16Vec8VT, ZextOp0, ZextOp1); + KnownBits KnownAVGCEILU = DAG->computeKnownBits(AVGCEILU); + EXPECT_EQ(KnownAVGCEILU.Zero, Zeroes); + EXPECT_EQ(KnownAVGCEILU.One, Ones); + + SDValue AVGCEILS = + DAG->getNode(ISD::AVGCEILS, Loc, Int16Vec8VT, ZextOp0, ZextOp1); + KnownBits KnownAVGCEILS = DAG->computeKnownBits(AVGCEILS); + EXPECT_EQ(KnownAVGCEILS.Zero, Zeroes); + EXPECT_EQ(KnownAVGCEILS.One, Ones); +} + } // end namespace llvm