From f317f52635ccacfa1cfe7a786076cc58bbb4dd52 Mon Sep 17 00:00:00 2001 From: Yanqi Yang Date: Thu, 14 Dec 2023 10:40:35 +0800 Subject: [PATCH] [rtl] fix divSqrtMux selectng signal --- arithmetic/src/float/DivSqrtMerge.scala | 52 ++++++++++++------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/arithmetic/src/float/DivSqrtMerge.scala b/arithmetic/src/float/DivSqrtMerge.scala index 351de6a..9df5dce 100644 --- a/arithmetic/src/float/DivSqrtMerge.scala +++ b/arithmetic/src/float/DivSqrtMerge.scala @@ -71,13 +71,13 @@ class DivSqrtMerge(expWidth: Int, sigWidth: Int) extends Module { rawA.isNaN || sqrtInvalidCases, rawA.isNaN || rawB.isNaN || divInvalidCases ) - val isInf = Mux(input.bits.sqrt, rawA.isInf, rawA.isInf || rawB.isZero) + val isInf = Mux(input.bits.sqrt, rawA.isInf, rawA.isInf || rawB.isZero) val isZero = Mux(input.bits.sqrt, rawA.isZero, rawA.isZero || rawB.isInf) val isNVorDZReg = RegEnable(isNVorDZ, false.B, input.fire) - val isNaNReg = RegEnable(isNaN, false.B, input.fire) - val isInfReg = RegEnable(isInf, false.B, input.fire) - val isZeroReg = RegEnable(isZero, false.B, input.fire) + val isNaNReg = RegEnable(isNaN, false.B, input.fire) + val isInfReg = RegEnable(isInf, false.B, input.fire) + val isZeroReg = RegEnable(isZero, false.B, input.fire) /** invalid operation flag */ val invalidExec = isNVorDZReg && isNaNReg @@ -126,18 +126,18 @@ class DivSqrtMerge(expWidth: Int, sigWidth: Int) extends Module { // build DIV Input val fractDividendIn = Wire(UInt((fpWidth).W)) - val fractDivisorIn = Wire(UInt((fpWidth).W)) + val fractDivisorIn = Wire(UInt((fpWidth).W)) fractDividendIn := Cat(1.U(1.W), rawA.sig(sigWidth - 2, 0), 0.U(expWidth.W)) - fractDivisorIn := Cat(1.U(1.W), rawB.sig(sigWidth - 2, 0), 0.U(expWidth.W)) + fractDivisorIn := Cat(1.U(1.W), rawB.sig(sigWidth - 2, 0), 0.U(expWidth.W)) val sqrtIter = Module(new SqrtIter(2, 2, sqrtIterWidth, sigWidth + 2)) - val divIter = Module(new SRT16Iter(fpWidth, fpWidth, fpWidth, 2, 2, 4, 4)) + val divIter = Module(new SRT16Iter(fpWidth, fpWidth, fpWidth, 2, 2, 4, 4)) val sqrtMuxIn = Wire(new IterMuxIO(expWidth, sigWidth, fpWidth, ohWidth, iterWidth)) val divMuxIn = Wire(new IterMuxIO(expWidth, sigWidth, fpWidth, ohWidth, iterWidth)) val divSqrtMuxOut = Wire(new IterMuxIO(expWidth, sigWidth, fpWidth, ohWidth, iterWidth)) - divSqrtMuxOut := Mux(opSqrtReg, sqrtMuxIn, divMuxIn) + divSqrtMuxOut := Mux(opSqrtReg || (input.bits.sqrt && input.fire), sqrtMuxIn, divMuxIn) val divValid = input.valid && !input.bits.sqrt && normalCaseDiv val divReady = divIter.input.ready @@ -163,20 +163,20 @@ class DivSqrtMerge(expWidth: Int, sigWidth: Int) extends Module { divSqrtMuxOut.quotientMinusOne, divSqrtMuxOut.selectedQuotientOH) - sqrtIter.input.valid := sqrtValid + sqrtIter.input.valid := sqrtValid sqrtIter.input.bits.partialCarry := partialCarry - sqrtIter.input.bits.partialSum := partialSum + sqrtIter.input.bits.partialSum := partialSum - divIter.input.valid := divValid - divIter.input.bits.partialSum := partialSum + divIter.input.valid := divValid + divIter.input.bits.partialSum := partialSum divIter.input.bits.partialCarry := partialCarry - divIter.input.bits.divider := fractDivisorIn - divIter.input.bits.counter := 8.U + divIter.input.bits.divider := fractDivisorIn + divIter.input.bits.counter := 8.U - sqrtIter.respOTF.quotient := otf(0) + sqrtIter.respOTF.quotient := otf(0) sqrtIter.respOTF.quotientMinusOne := otf(1) - divIter.respOTF.quotient := otf(0) - divIter.respOTF.quotientMinusOne := otf(1) + divIter.respOTF.quotient := otf(0) + divIter.respOTF.quotientMinusOne := otf(1) /** collect div result * @@ -235,28 +235,29 @@ class DivSqrtMerge(expWidth: Int, sigWidth: Int) extends Module { isZeroReg) sqrtMuxIn.enable := (sqrtValid && sqrtReady) || !sqrtIter.output.isLastCycle - divMuxIn.enable := (divValid && divReady) || !divIter.output.isLastCycle sqrtMuxIn.partialSumInit := Cat("b11".U, sqrtFractIn) - divMuxIn.partialSumInit := fractDividendIn sqrtMuxIn.partialSumNext := sqrtIter.output.partialSum sqrtMuxIn.partialCarryNext := sqrtIter.output.partialCarry - divMuxIn.partialSumNext := divIter.output.partialSum - divMuxIn.partialCarryNext := divIter.output.partialCarry sqrtMuxIn.quotient := sqrtIter.reqOTF.quotient sqrtMuxIn.quotientMinusOne := sqrtIter.reqOTF.quotientMinusOne sqrtMuxIn.selectedQuotientOH := sqrtIter.reqOTF.selectedQuotientOH + sqrtMuxIn.sigToRound := sigPlusSqrt + sqrtMuxIn.expToRound := expStore + + divMuxIn.enable := (divValid && divReady) || !divIter.output.isLastCycle + divMuxIn.partialSumInit := fractDividendIn + divMuxIn.partialSumNext := divIter.output.partialSum + divMuxIn.partialCarryNext := divIter.output.partialCarry divMuxIn.quotient := divIter.reqOTF.quotient divMuxIn.quotientMinusOne := divIter.reqOTF.quotientMinusOne divMuxIn.selectedQuotientOH := divIter.reqOTF.selectedQuotientOH - sqrtMuxIn.sigToRound := sigPlusSqrt - sqrtMuxIn.expToRound := expStore divMuxIn.sigToRound := sigPlusDiv divMuxIn.expToRound := expStore - needRightShift - output.bits.result := roundresult(0) + output.bits.result := roundresult(0) output.bits.exceptionFlags := roundresult(1) - input.ready := divReady && sqrtReady + input.ready := divReady && sqrtReady output.valid := divIter.resultOutput.valid || sqrtIter.resultOutput.valid || fastValid } @@ -277,6 +278,5 @@ class IterMuxIO(expWidth: Int, sigWidth: Int, qWidth: Int, ohWidth: Int, iterWid // collect output val expToRound = UInt((expWidth + 2).W) val sigToRound = UInt((sigWidth + 2).W) - }