diff --git a/.gitignore b/.gitignore index c0b4ce2..beffdbc 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,5 @@ verdiLog *.out *.cmd *.log -*.json \ No newline at end of file +*.json +*.iml \ No newline at end of file diff --git a/arithmetic/src/addition/prefixadder/graph/PrefixGraph.scala b/arithmetic/src/addition/prefixadder/graph/PrefixGraph.scala index 38e63bd..ae72d72 100644 --- a/arithmetic/src/addition/prefixadder/graph/PrefixGraph.scala +++ b/arithmetic/src/addition/prefixadder/graph/PrefixGraph.scala @@ -53,7 +53,7 @@ object PrefixGraph { } object CommonSumByConsole extends HasPrefixSumWithGraphImp with CommonPrefixSum { - val filePath = Path(io.StdIn.readLine("Import your graph generated by `dot -Txdot_json`: "), pwd) + val filePath = Path(scala.io.StdIn.readLine("Import your graph generated by `dot -Txdot_json`: "), pwd) val fileName = filePath.baseName val prefixGraph: PrefixGraph = PrefixGraph(filePath) } diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala new file mode 100644 index 0000000..539afb4 --- /dev/null +++ b/arithmetic/src/division/srt/SRT.scala @@ -0,0 +1,68 @@ +package division.srt + +import division.srt.srt4._ +import division.srt.srt8._ +import division.srt.srt16._ +import chisel3._ +import chisel3.util.{DecoupledIO, ValidIO} + +class SRT( + dividendWidth: Int, + dividerWidth: Int, + n: Int, // the longest width + radixLog2: Int = 2, + a: Int = 2, + dTruncateWidth: Int = 4, + rTruncateWidth: Int = 4) + extends Module { +// val x = (radixLog2, a, dTruncateWidth) +// val tips = x match { +// case (2,2,4) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// case (2,2,5) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// case (2,2,6) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// +// case (3,4,6) => require(rTruncateWidth >= 7, "rTruncateWidth need >= 7") +// case (3,4,7) => require(rTruncateWidth >= 6, "rTruncateWidth need >= 6") +// +// case (3,5,5) => require(rTruncateWidth >= 5, "rTruncateWidth need >= 5") +// case (3,5,6) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// +// case (3,6,4) => require(rTruncateWidth >= 6, "rTruncateWidth need >= 6") +// case (3,6,5) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// +// case (3,7,4) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// case (3,7,5) => require(rTruncateWidth >= 3, "rTruncateWidth need >= 3") +// +// case (4,2,4) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// case (4,2,5) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// case (4,2,6) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// +// case _ => println("this srt is not supported") +// } + + val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) + val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth))) + +// select radix + if (radixLog2 == 2) { // SRT4 + val srt = Module(new SRT4(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth)) + srt.input <> input + output <> srt.output + } else if (radixLog2 == 3) { // SRT8 + val srt = Module(new SRT8(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth)) + srt.input <> input + output <> srt.output + } else if (radixLog2 == 4) { //SRT16 + val srt = Module(new SRT16(dividendWidth, dividerWidth, n, radixLog2 >> 1, a, dTruncateWidth, rTruncateWidth)) + srt.input <> input + output <> srt.output + } + +// val srt = radixLog2 match { +// case 2 => Module(new SRT4(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth)) +// case 3 => Module(new SRT8(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth)) +// case 4 => Module(new SRT16(dividendWidth, dividerWidth, n, radixLog2 >> 1, a, dTruncateWidth, rTruncateWidth)) +// } +// srt.input <> input +// output <> srt.output +} diff --git a/arithmetic/src/division/srt/SRTIO.scala b/arithmetic/src/division/srt/SRTIO.scala new file mode 100644 index 0000000..417aaa5 --- /dev/null +++ b/arithmetic/src/division/srt/SRTIO.scala @@ -0,0 +1,38 @@ +package division.srt + +import chisel3._ +import chisel3.util.log2Ceil +// SRTIO +class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { + val dividend = UInt(dividendWidth.W) //.*********** + val divider = UInt(dividerWidth.W) //.1********** + val counter = UInt(log2Ceil(n).W) //the width of quotient. +} + +class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { + val reminder = UInt(reminderWidth.W) + val quotient = UInt(quotientWidth.W) +} + +//OTFIO +class OTFInput(qWidth: Int, ohWidth: Int) extends Bundle { + val quotient = UInt(qWidth.W) + val quotientMinusOne = UInt(qWidth.W) + val selectedQuotientOH = UInt(ohWidth.W) +} + +class OTFOutput(qWidth: Int) extends Bundle { + val quotient = UInt(qWidth.W) + val quotientMinusOne = UInt(qWidth.W) +} + +// QDSIO +class QDSInput(rWidth: Int, partialDividerWidth: Int) extends Bundle { + val partialReminderCarry: UInt = UInt(rWidth.W) + val partialReminderSum: UInt = UInt(rWidth.W) + val partialDivider: UInt = UInt(partialDividerWidth.W) +} + +class QDSOutput(ohWidth: Int) extends Bundle { + val selectedQuotientOH: UInt = UInt(ohWidth.W) +} diff --git a/arithmetic/src/division/srt/SRTTable.scala b/arithmetic/src/division/srt/SRTTable.scala new file mode 100644 index 0000000..da843da --- /dev/null +++ b/arithmetic/src/division/srt/SRTTable.scala @@ -0,0 +1,194 @@ +package division.srt + +import com.cibo.evilplot.colors.HTMLNamedColors +import com.cibo.evilplot.numeric.Bounds +import com.cibo.evilplot.plot._ +import com.cibo.evilplot.plot.aesthetics.DefaultTheme._ +import com.cibo.evilplot.plot.renderers.PointRenderer +import os.Path +import spire.implicits._ +import spire.math._ + +/** Base SRT class. + * + * @param radix is the radix of SRT. + * It defined how many rounds can be calculate in one cycle. + * @note 5.2 + */ +case class SRTTable( + radix: Algebraic, + a: Algebraic, + dTruncateWidth: Algebraic, + xTruncateWidth: Algebraic, + dMin: Algebraic = 0.5, + dMax: Algebraic = 1) { + require(a > 0) + lazy val xMin: Algebraic = -rho * dMax * radix + lazy val xMax: Algebraic = rho * dMax * radix + + /** P-D Diagram + * + * @note Graph 5.17(b) + */ + lazy val pd: Plot = Overlay((aMin.toBigInt to aMax.toBigInt).flatMap { k: BigInt => + Seq( + FunctionPlot.series( + _ * uRate(k.toInt).toDouble, + s"U($k)", + HTMLNamedColors.blue, + Some(Bounds(dMin.toDouble, dMax.toDouble)), + strokeWidth = Some(1) + ), + FunctionPlot.series( + _ * lRate(k.toInt).toDouble, + s"L($k)", + HTMLNamedColors.red, + Some(Bounds(dMin.toDouble, dMax.toDouble)), + strokeWidth = Some(1) + ) + ) ++ qdsPoints :+ mesh + }: _*) + .title(s"P-D Graph of $this") + .xLabel("d") + .yLabel(s"${radix.toInt}ω[j]") + .rightLegend() + .standard() + + lazy val aMax: Algebraic = a + lazy val aMin: Algebraic = -a + lazy val deltaD: Algebraic = pow(2, -dTruncateWidth.toDouble) + lazy val deltaX: Algebraic = pow(2, -xTruncateWidth.toDouble) + + /** redundancy factor + * @note 5.8 + */ + lazy val rho: Algebraic = a / (radix - 1) + // k d m xSet + lazy val tables: Seq[(Int, Seq[(Algebraic, Seq[Algebraic])])] = { + (aMin.toInt to aMax.toInt).drop(1).map { k => + k -> dSet.dropRight(1).map { d => + val (floor, ceil) = xRange(k, d, d + deltaD) + val m: Seq[Algebraic] = xSet.filter { x: Algebraic => x <= (ceil - deltaX) && x >= floor } + (d, m) + } + } + } + lazy val qdsPoints: Seq[Plot] = { + tables.map { + case (i, ps) => + ScatterPlot( + ps.flatMap { case (d, xs) => xs.map(x => com.cibo.evilplot.numeric.Point(d.toDouble, x.toDouble)) }, + Some( + PointRenderer + .default[com.cibo.evilplot.numeric.Point](pointSize = Some(1), color = Some(HTMLNamedColors.gold)) + ) + ) + } + } + + // TODO: select a Constant from each m, then offer the table to QDS. + // todo: ? select rule: symmetry and draw a line parallel to the X-axis, how define the rule + lazy val tablesToQDS: Seq[Seq[Int]] = { + (aMin.toInt to aMax.toInt).drop(1).map { k => + k -> dSet.dropRight(1).map { d => + val (floor, ceil) = xRange(k, d, d + deltaD) + val m: Seq[Algebraic] = xSet.filter { x: Algebraic => x <= (ceil - deltaX) && x >= floor } + (d, m.head) + } + } + }.flatMap { + case (i, ps) => + ps.map { + case (x, y) => (x.toDouble, y.toDouble * (1 << xTruncateWidth.toInt)) + } + }.groupBy(_._1).toSeq.sortBy(_._1).map { case (x, y) => y.map { case (x, y) => y.toInt }.reverse } + + private val xStep = (xMax - xMin) / deltaX + // @note 5.7 + require(a >= radix / 2) + private val xSet = Seq.tabulate((xStep / 2 + 1).toInt) { n => deltaX * n } ++ Seq.tabulate((xStep / 2 + 1).toInt) { + n => -deltaX * n + } + + private val dStep: Algebraic = (dMax - dMin) / deltaD + assert((rho > 1 / 2) && (rho <= 1)) + private val dSet = Seq.tabulate((dStep + 1).toInt) { n => dMin + deltaD * n } + + private val mesh = + ScatterPlot( + xSet.flatMap { y => + dSet.map { x => + com.cibo.evilplot.numeric.Point(x.toDouble, y.toDouble) + } + }, + Some( + PointRenderer + .default[com.cibo.evilplot.numeric.Point](pointSize = Some(0.5), color = Some(HTMLNamedColors.gray)) + ) + ) + + override def toString: String = + s"SRT${radix.toInt} with quotient set: from ${aMin.toInt} to ${aMax.toInt}" + + /** Robertson Diagram + * + * @note Graph 5.17(a) + */ + def robertson(d: Algebraic): Plot = { + require(d > dMin && d < dMax) + Overlay((aMin.toBigInt to aMax.toBigInt).map { k: BigInt => + FunctionPlot.series( + _ - (Algebraic(k) * d).toDouble, + s"$k", + HTMLNamedColors.black, + xbounds = Some(Bounds(((Algebraic(k) - rho) * d).toDouble, ((Algebraic(k) + rho) * d).toDouble)) + ) + }: _*) + .title(s"Robertson Graph of $this divisor: $d") + .xLabel("rω[j]") + .yLabel("ω[j+1]") + .xbounds((-radix * rho * dMax).toDouble, (radix * rho * dMax).toDouble) + .ybounds((-rho * d).toDouble, (rho * d).toDouble) + .rightLegend() + .standard() + } + + def dumpGraph(plot: Plot, path: Path) = { + javax.imageio.ImageIO.write( + plot.render().asBufferedImage, + "png", + path.wrapped.toFile + ) + } + + // select four points, then drop the first and the last one. + /** for range `dLeft` to `dRight`, return the `rOmegaCeil` and `rOmegaFloor` + * this is used for constructing the rectangle where m_k(i) is located. + */ + private def xRange(k: Algebraic, dLeft: Algebraic, dRight: Algebraic): (Algebraic, Algebraic) = { + Seq(L(k, dLeft), L(k, dRight), U(k - 1, dLeft), U(k - 1, dRight)) + // not safe + .sortBy(_.toDouble) + .drop(1) + .dropRight(1) match { case Seq(l, r) => (l, r) } + } + + // U_k = (k + rho) * d, L_k = (k - rho) * d + /** find the intersection point between L`k` and `d` */ + private def L(k: Algebraic, d: Algebraic): Algebraic = lRate(k) * d + + /** slope factor of L_k + * + * @note 5.56 + */ + private def lRate(k: Algebraic): Algebraic = k - rho + + /** find the intersection point between U`k` and `d` */ + private def U(k: Algebraic, d: Algebraic): Algebraic = uRate(k) * d + + /** slope factor of U_k + * + * @note 5.56 + */ + private def uRate(k: Algebraic): Algebraic = k + rho +} diff --git a/arithmetic/src/division/srt/srt16/OTF.scala b/arithmetic/src/division/srt/srt16/OTF.scala new file mode 100644 index 0000000..d9d6d8b --- /dev/null +++ b/arithmetic/src/division/srt/srt16/OTF.scala @@ -0,0 +1,50 @@ +package division.srt.srt16 + +import division.srt._ +import chisel3._ +import chisel3.util.Mux1H + +class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int) extends Module { + val input = IO(Input(new OTFInput(qWidth, ohWidth))) + val output = IO(Output(new OTFOutput(qWidth))) + + val radix: Int = 1 << radixLog2 + // datapath + // q_j+1 in this circle, only for srt4 + val qNext: UInt = Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b110".U, + input.selectedQuotientOH(1) -> "b111".U, + input.selectedQuotientOH(2) -> "b000".U, + input.selectedQuotientOH(3) -> "b001".U, + input.selectedQuotientOH(4) -> "b010".U + ) + ) + + // val cShiftQ: Bool = qNext >= 0.U + // val cShiftQM: Bool = qNext <= 0.U + val cShiftQ: Bool = input.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR + val cShiftQM: Bool = input.selectedQuotientOH(ohWidth / 2, 0).orR + val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(radixLog2 - 1, 0) + val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(radixLog2 - 1, 0) + + output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qIn + output.quotientMinusOne := Mux(!cShiftQM, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qmIn +} + +object OTF { + def apply( + radixLog2: Int, + qWidth: Int, + ohWidth: Int + )(quotient: UInt, + quotientMinusOne: UInt, + selectedQuotientOH: UInt + ): Vec[UInt] = { + val m = Module(new OTF(radixLog2, qWidth, ohWidth)) + m.input.quotient := quotient + m.input.quotientMinusOne := quotientMinusOne + m.input.selectedQuotientOH := selectedQuotientOH + VecInit(m.output.quotient, m.output.quotientMinusOne) + } +} diff --git a/arithmetic/src/division/srt/srt16/QDS.scala b/arithmetic/src/division/srt/srt16/QDS.scala new file mode 100644 index 0000000..bca9e6e --- /dev/null +++ b/arithmetic/src/division/srt/srt16/QDS.scala @@ -0,0 +1,63 @@ +package division.srt.srt16 + +import division.srt._ +import chisel3._ +import chisel3.util.BitPat +import chisel3.util.BitPat.bitPatToUInt +import chisel3.util.experimental.decode._ +import utils.{extend, sIntToBitPat} + +class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]]) extends Module { + // IO + val input = IO(Input(new QDSInput(rWidth, partialDividerWidth))) + val output = IO(Output(new QDSOutput(ohWidth))) + + // get from SRTTable. + lazy val selectRom = VecInit(tables.map { + case x => + VecInit(x.map { + case x => bitPatToUInt(sIntToBitPat(-x, rWidth)) + }) + }) + + val columnSelect = input.partialDivider + val adderWidth = rWidth + 1 + val yTruncate: UInt = input.partialReminderCarry + input.partialReminderSum + val mkVec = selectRom(columnSelect) + val selectPoints = VecInit(mkVec.map { mk => + (extend(yTruncate, adderWidth).asUInt + + extend(mk, adderWidth).asUInt).head(1) + }).asUInt + + // decoder or findFirstOne here, prefer decoder, the decoder only for srt4 + output.selectedQuotientOH := chisel3.util.experimental.decode.decoder( + selectPoints, + TruthTable( + Seq( + BitPat("b???0") -> BitPat("b10000"), //2 + BitPat("b??01") -> BitPat("b01000"), //1 + BitPat("b?011") -> BitPat("b00100"), //0 + BitPat("b0111") -> BitPat("b00010") //-1 + ), + BitPat("b00001") //-2 + ) + ) +} + +object QDS { + def apply( + rWidth: Int, + ohWidth: Int, + partialDividerWidth: Int, + tables: Seq[Seq[Int]] + )(partialReminderSum: UInt, + partialReminderCarry: UInt, + partialDivider: UInt + ): UInt = { + val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth, tables)) + m.input.partialReminderSum := partialReminderSum + m.input.partialReminderCarry := partialReminderCarry + m.input.partialDivider := partialDivider + m.output.selectedQuotientOH + } +} diff --git a/arithmetic/src/division/srt/srt16/SRT16.scala b/arithmetic/src/division/srt/srt16/SRT16.scala new file mode 100644 index 0000000..7ea12f1 --- /dev/null +++ b/arithmetic/src/division/srt/srt16/SRT16.scala @@ -0,0 +1,140 @@ +package division.srt.srt16 + +import division.srt._ +import chisel3._ +import chisel3.util.{log2Ceil, DecoupledIO, Fill, Mux1H, RegEnable, ValidIO} +import utils.leftShift + +/** RSRT16 with Two SRT4 Overlapped Stages + * n>=7 + * Reuse parameters, OTF and QDS of srt4 + */ +class SRT16( + dividendWidth: Int, + dividerWidth: Int, + n: Int, // the longest width + radixLog2: Int = 2, + a: Int = 2, + dTruncateWidth: Int = 4, + rTruncateWidth: Int = 4) + extends Module { + + val xLen: Int = dividendWidth + radixLog2 + 1 + val wLen: Int = xLen + radixLog2 + val ohWidth: Int = 2 * a + 1 + val rWidth: Int = 1 + radixLog2 + rTruncateWidth + + // IO + val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) + val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth))) + + val partialReminderCarryNext, partialReminderSumNext = Wire(UInt(wLen.W)) + val dividerNext = Wire(UInt(dividerWidth.W)) + val counterNext = Wire(UInt(log2Ceil(n).W)) + val quotientNext, quotientMinusOneNext = Wire(UInt(n.W)) + + // Control + val isLastCycle, enable: Bool = Wire(Bool()) + // State + // because we need a CSA to minimize the critical path + val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), enable) + val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(wLen.W), enable) + val divider = RegEnable(dividerNext, 0.U(dividerWidth.W), enable) + val quotient = RegEnable(quotientNext, 0.U(n.W), enable) + val quotientMinusOne = RegEnable(quotientMinusOneNext, 0.U(n.W), enable) + val counter = RegEnable(counterNext, 0.U(log2Ceil(n).W), enable) + + // Datapath + isLastCycle := !counter.orR + output.valid := isLastCycle + input.ready := isLastCycle + enable := input.fire || !isLastCycle + + val remainderNoCorrect: UInt = partialReminderSum + partialReminderCarry + val remainderCorrect: UInt = + partialReminderSum + partialReminderCarry + (divider << radixLog2) + val needCorrect: Bool = remainderNoCorrect(wLen - 3).asBool + output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 4, radixLog2) + output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient) + + // 5*CSA32 SRT16 <- SRT4 + SRT4*5 /SRT16 -> CSA53+CSA32 + val dividerMap = VecInit((-2 to 2).map { + case -2 => divider << 1 + case -1 => divider + case 0 => 0.U + case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider + case 2 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) + }) + val csa0InWidth = rWidth + radixLog2 + 1 + val csaIn1 = leftShift(partialReminderSum, radixLog2).head(csa0InWidth) + val csaIn2 = leftShift(partialReminderCarry, radixLog2).head(csa0InWidth) + + val csa1 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(0).head(csa0InWidth))) // -2 csain 10bit + val csa2 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(1).head(csa0InWidth))) // -1 + val csa3 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(2).head(csa0InWidth))) // 0 + val csa4 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(3).head(csa0InWidth))) // 1 + val csa5 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(4).head(csa0InWidth))) // 2 + + // qds + val tables: Seq[Seq[Int]] = SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS + val partialDivider: UInt = dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) + val qdsOH0: UInt = + QDS(rWidth, ohWidth, dTruncateWidth - 1, tables)( + leftShift(partialReminderSum, radixLog2).head(rWidth), + leftShift(partialReminderCarry, radixLog2).head(rWidth), + partialDivider + ) // q_j+1 oneHot + + def qds(a: Vec[UInt]): UInt = { + QDS(rWidth, ohWidth, dTruncateWidth - 1, tables)( + leftShift(a(1), radixLog2).head(rWidth), + leftShift(a(0), radixLog2 + 1).head(rWidth), + partialDivider + ) + } + // q_j+2 oneHot precompute + val qds1SelectedQuotientOH: UInt = qds(csa1) // -2 + val qds2SelectedQuotientOH: UInt = qds(csa2) // -1 + val qds3SelectedQuotientOH: UInt = qds(csa3) // 0 + val qds4SelectedQuotientOH: UInt = qds(csa4) // 1 + val qds5SelectedQuotientOH: UInt = qds(csa5) // 2 + + val qds1SelectedQuotientOHMap = VecInit((-2 to 2).map { + case -2 => qds1SelectedQuotientOH + case -1 => qds2SelectedQuotientOH + case 0 => qds3SelectedQuotientOH + case 1 => qds4SelectedQuotientOH + case 2 => qds5SelectedQuotientOH + }) + + val qdsOH1 = Mux1H(qdsOH0, qds1SelectedQuotientOHMap) // q_j+2 oneHot + val qds0sign = qdsOH0(ohWidth - 1, ohWidth / 2 + 1).orR + val qds1sign = qdsOH1(ohWidth - 1, ohWidth / 2 + 1).orR + + val csa0Out = addition.csa.c32( + VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qds0sign, + Mux1H(qdsOH0, dividerMap) + ) + ) + val csa1Out = addition.csa.c32( + VecInit( + leftShift(csa0Out(1), radixLog2).head(wLen - radixLog2), + leftShift(csa0Out(0), radixLog2 + 1).head(wLen - radixLog2 - 1) ## qds1sign, + Mux1H(qdsOH1, dividerMap) + ) + ) + + // On-The-Fly conversion + // todo?: OTF input: Q, QM1, (q1 << 2 + q2) output: Q,QM1 + val otf0 = OTF(radixLog2, n, ohWidth)(quotient, quotientMinusOne, qdsOH0) + val otf1 = OTF(radixLog2, n, ohWidth)(otf0(0), otf0(1), qdsOH1) + + dividerNext := Mux(input.fire, input.bits.divider, divider) + counterNext := Mux(input.fire, input.bits.counter, counter - 1.U) + quotientNext := Mux(input.fire, 0.U, otf1(0)) + quotientMinusOneNext := Mux(input.fire, 0.U, otf1(1)) + partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa1Out(1) << radixLog2) + partialReminderCarryNext := Mux(input.fire, 0.U, csa1Out(0) << radixLog2 + 1) +} diff --git a/arithmetic/src/division/srt/srt4/OTF.scala b/arithmetic/src/division/srt/srt4/OTF.scala new file mode 100644 index 0000000..f89106a --- /dev/null +++ b/arithmetic/src/division/srt/srt4/OTF.scala @@ -0,0 +1,74 @@ +package division.srt.srt4 + +import division.srt._ +import chisel3._ +import chisel3.util.Mux1H + +class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int, a: Int) extends Module { + val input = IO(Input(new OTFInput(qWidth, ohWidth))) + val output = IO(Output(new OTFOutput(qWidth))) + + val radix: Int = 1 << radixLog2 + // datapath + // q_j+1 in this circle, only for srt4 + // val cShiftQ: Bool = qNext >= 0.U + // val cShiftQM: Bool = qNext <= 0.U + val qNext: UInt = Wire(UInt(3.W)) + val cShiftQ, cShiftQM = Wire(Bool()) + + if (a == 2) { + qNext := Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b110".U, //-2 + input.selectedQuotientOH(1) -> "b111".U, //-1 + input.selectedQuotientOH(2) -> "b000".U, // 0 + input.selectedQuotientOH(3) -> "b001".U, // 1 + input.selectedQuotientOH(4) -> "b010".U // 2 + ) + ) + cShiftQ := input.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR + cShiftQM := input.selectedQuotientOH(ohWidth / 2, 0).orR + } else if (a == 3) { + qNext := Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b111".U, //-1 + input.selectedQuotientOH(1) -> "b000".U, // 0 + input.selectedQuotientOH(2) -> "b001".U // 1 + ) + ) + Mux1H( + Seq( + input.selectedQuotientOH(3) -> "b110".U, // -2 + input.selectedQuotientOH(4) -> "b000".U, // 0 + input.selectedQuotientOH(5) -> "b010".U // 2 + ) + ) + cShiftQ := input.selectedQuotientOH(5) || + (input.selectedQuotientOH(4) && input.selectedQuotientOH(2, 1).orR) + cShiftQM := input.selectedQuotientOH(3) || + (input.selectedQuotientOH(4) && input.selectedQuotientOH(1, 0).orR) + } + + val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(radixLog2 - 1, 0) + val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(radixLog2 - 1, 0) + + output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qIn + output.quotientMinusOne := Mux(!cShiftQM, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qmIn +} + +object OTF { + def apply( + radixLog2: Int, + qWidth: Int, + ohWidth: Int, + a: Int + )(quotient: UInt, + quotientMinusOne: UInt, + selectedQuotientOH: UInt + ): Vec[UInt] = { + val m = Module(new OTF(radixLog2, qWidth, ohWidth, a)) + m.input.quotient := quotient + m.input.quotientMinusOne := quotientMinusOne + m.input.selectedQuotientOH := selectedQuotientOH + VecInit(m.output.quotient, m.output.quotientMinusOne) + } +} diff --git a/arithmetic/src/division/srt/srt4/QDS.scala b/arithmetic/src/division/srt/srt4/QDS.scala new file mode 100644 index 0000000..fa40757 --- /dev/null +++ b/arithmetic/src/division/srt/srt4/QDS.scala @@ -0,0 +1,101 @@ +package division.srt.srt4 + +import division.srt._ +import chisel3._ +import chisel3.util.BitPat +import chisel3.util.BitPat.bitPatToUInt +import chisel3.util.experimental.decode.TruthTable +import utils.{extend, sIntToBitPat} + +class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]], a: Int) extends Module { + // IO + val input = IO(Input(new QDSInput(rWidth, partialDividerWidth))) + val output = IO(Output(new QDSOutput(ohWidth))) + + // from P269 in : /16, should have got from SRTTable. + // val qSelTable = Array( + // Array(12, 4, -4, -13), + // Array(14, 4, -6, -15), + // Array(15, 4, -6, -16), + // Array(16, 4, -6, -18), + // Array(18, 6, -8, -20), + // Array(20, 6, -8, -20), + // Array(20, 8, -8, -22), + // Array(24, 8, -8, -24)/16 + // ) + // val selectRom: Vec[Vec[UInt]] = VecInit( + // VecInit("b111_0100".U, "b111_1100".U, "b000_0100".U, "b000_1101".U), + // VecInit("b111_0010".U, "b111_1100".U, "b000_0110".U, "b000_1111".U), + // VecInit("b111_0001".U, "b111_1100".U, "b000_0110".U, "b001_0000".U), + // VecInit("b111_0000".U, "b111_1100".U, "b000_0110".U, "b001_0010".U), + // VecInit("b110_1110".U, "b111_1010".U, "b000_1000".U, "b001_0100".U), + // VecInit("b110_1100".U, "b111_1010".U, "b000_1000".U, "b001_0100".U), + // VecInit("b110_1100".U, "b111_1000".U, "b000_1000".U, "b001_0110".U), + // VecInit("b110_1000".U, "b111_1000".U, "b000_1000".U, "b001_1000".U) + // ) + + // get from SRTTable. + lazy val selectRom = VecInit(tables.map { + case x => + VecInit(x.map { + case x => bitPatToUInt(sIntToBitPat(-x, rWidth)) + }) + }) + + val columnSelect = input.partialDivider + val adderWidth = rWidth + 1 + val yTruncate: UInt = input.partialReminderCarry + input.partialReminderSum + val mkVec = selectRom(columnSelect) + val selectPoints = VecInit(mkVec.map { mk => + (extend(yTruncate, adderWidth).asUInt + + extend(mk, adderWidth).asUInt).head(1) + }).asUInt + + // decoder or findFirstOne here, prefer decoder, the decoder only for srt4 + output.selectedQuotientOH := chisel3.util.experimental.decode.decoder( + selectPoints, + a match { + case 2 => + TruthTable( + Seq( + BitPat("b???0") -> BitPat("b10000"), //2 + BitPat("b??01") -> BitPat("b01000"), //1 + BitPat("b?011") -> BitPat("b00100"), //0 + BitPat("b0111") -> BitPat("b00010") //-1 + ), + BitPat("b00001") //-2 + ) + case 3 => + TruthTable( + Seq( // 2 0 -2 1 0 -1 + BitPat("b??_???0") -> BitPat("b100_100"), //3 = 2 + 1 + BitPat("b??_??01") -> BitPat("b100_010"), //2 = 2 + 0 + BitPat("b??_?011") -> BitPat("b010_100"), //1 = 0 + 1 + BitPat("b??_0111") -> BitPat("b010_010"), //0 = 0 + 0 + BitPat("b?0_1111") -> BitPat("b010_001"), //-1 = 0 + -1 + BitPat("b01_1111") -> BitPat("b001_010") //-2 = -2 + 0 + ), + BitPat("b001_001") //-3 = -2 + -1 + ) + } + ) +} + +object QDS { + def apply( + rWidth: Int, + ohWidth: Int, + partialDividerWidth: Int, + tables: Seq[Seq[Int]], + a: Int + )(partialReminderSum: UInt, + partialReminderCarry: UInt, + partialDivider: UInt + ): UInt = { + val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth, tables, a)) + m.input.partialReminderSum := partialReminderSum + m.input.partialReminderCarry := partialReminderCarry + m.input.partialDivider := partialDivider + m.output.selectedQuotientOH + } +} diff --git a/arithmetic/src/division/srt/srt4/SRT4.scala b/arithmetic/src/division/srt/srt4/SRT4.scala new file mode 100644 index 0000000..f10dfaa --- /dev/null +++ b/arithmetic/src/division/srt/srt4/SRT4.scala @@ -0,0 +1,141 @@ +package division.srt.srt4 + +import division.srt._ +import addition.csa.CarrySaveAdder +import addition.csa.common.CSACompressor3_2 +import chisel3._ +import chisel3.util._ +import spire.math +import utils.leftShift + +/** SRT4 + * 1/2 <= d < 1, 1/2 < rho <=1, 0 < q < 2 + * radix = 4 + * a = 2, {-2, -1, 0, 1, -2}, + * dTruncateWidth = 4, rTruncateWidth = 8 + * y^(xxx.xxxx), d^(0.1xxx) + * -44/16 < y^ < 42/16 + * floor((-r*rho - 2^-t)_t) <= y^ <= floor((r*rho - ulp)_t) + */ + +class SRT4( + dividendWidth: Int, + dividerWidth: Int, + n: Int, // the longest width + radixLog2: Int = 2, + a: Int = 2, + dTruncateWidth: Int = 4, + rTruncateWidth: Int = 4) + extends Module { + val xLen: Int = dividendWidth + radixLog2 + 1 + val wLen: Int = xLen + radixLog2 + // IO + val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) + val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth))) + + val partialReminderCarryNext, partialReminderSumNext = Wire(UInt(wLen.W)) + val quotientNext, quotientMinusOneNext = Wire(UInt(n.W)) + val dividerNext = Wire(UInt(dividerWidth.W)) + val counterNext = Wire(UInt(log2Ceil(n).W)) + + // Control + // sign of Cycle, true -> (counter === 0.U) + val isLastCycle, enable: Bool = Wire(Bool()) + + // State + // because we need a CSA to minimize the critical path + val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), enable) + val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(wLen.W), enable) + val divider = RegEnable(dividerNext, 0.U(dividerWidth.W), enable) + val quotient = RegEnable(quotientNext, 0.U(n.W), enable) + val quotientMinusOne = RegEnable(quotientMinusOneNext, 0.U(n.W), enable) + val counter = RegEnable(counterNext, 0.U(log2Ceil(n).W), enable) + + // Datapath + // according two adders + isLastCycle := !counter.orR + output.valid := isLastCycle + input.ready := isLastCycle + enable := input.fire || !isLastCycle + + val remainderNoCorrect: UInt = partialReminderSum + partialReminderCarry + val remainderCorrect: UInt = + partialReminderSum + partialReminderCarry + (divider << radixLog2) + val needCorrect: Bool = remainderNoCorrect(wLen - 3).asBool + output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 4, radixLog2) + output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient) + + val rWidth: Int = 1 + radixLog2 + rTruncateWidth + val tables: Seq[Seq[Int]] = SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS + val ohWidth: Int = a match { + case 2 => 2 * a + 1 + case 3 => 6 + } + //qds + val selectedQuotientOH: UInt = + QDS(rWidth, ohWidth, dTruncateWidth - 1, tables, a)( + leftShift(partialReminderSum, radixLog2).head(rWidth), + leftShift(partialReminderCarry, radixLog2).head(rWidth), + dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) //.1********* -> 1*** -> *** + ) + // On-The-Fly conversion + val otf = OTF(radixLog2, n, ohWidth, a)(quotient, quotientMinusOne, selectedQuotientOH) + + val csa: Vec[UInt] = + if (a == 2) { // a == 2 + //csa + val dividerMap = VecInit((-2 to 2).map { + case -2 => divider << 1 + case -1 => divider + case 0 => 0.U + case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider + case 2 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) + }) + val qdsSign = selectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR + addition.csa.c32( + VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign, + Mux1H(selectedQuotientOH, dividerMap) + ) + ) + } else { // a==3 + val qHigh = selectedQuotientOH(5, 3) + val qLow = selectedQuotientOH(2, 0) + val qds0Sign = qHigh.head(1) + val qds1Sign = qLow.head(1) + + // csa + val dividerHMap = VecInit((-1 to 1).map { + case -1 => divider << 1 // -2 + case 0 => 0.U // 0 + case 1 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) // 2 + }) + val dividerLMap = VecInit((-1 to 1).map { + case -1 => divider // -1 + case 0 => 0.U // 0 + case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider // 1 + }) + val csa0 = addition.csa.c32( + VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qds0Sign, + Mux1H(qHigh, dividerHMap) + ) + ) + addition.csa.c32( + VecInit( + csa0(1).head(wLen - radixLog2), + leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qds1Sign, + Mux1H(qLow, dividerLMap) + ) + ) + } + + dividerNext := Mux(input.fire, input.bits.divider, divider) + counterNext := Mux(input.fire, input.bits.counter, counter - 1.U) + quotientNext := Mux(input.fire, 0.U, otf(0)) + quotientMinusOneNext := Mux(input.fire, 0.U, otf(1)) + partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa(1) << radixLog2) + partialReminderCarryNext := Mux(input.fire, 0.U, csa(0) << 1 + radixLog2) +} diff --git a/arithmetic/src/division/srt/srt8/OTF.scala b/arithmetic/src/division/srt/srt8/OTF.scala new file mode 100644 index 0000000..640a67f --- /dev/null +++ b/arithmetic/src/division/srt/srt8/OTF.scala @@ -0,0 +1,126 @@ +package division.srt.srt8 + +import division.srt._ +import chisel3._ +import chisel3.util.Mux1H + +class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int, a: Int) extends Module { + val input = IO(Input(new OTFInput(qWidth, ohWidth))) + val output = IO(Output(new OTFOutput(qWidth))) + + val radix: Int = 1 << radixLog2 + // datapath + // q_j+1 in this circle + // val cShiftQ: Bool = qNext >= 0.U + // val cShiftQM: Bool = qNext <= 0.U + val qNext: UInt = Wire(UInt(5.W)) + val cShiftQ, cShiftQM: Bool = Wire(Bool()) + + if (a == 7) { + qNext := Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b11110".U, // -2 + input.selectedQuotientOH(1) -> "b11111".U, // -1 + input.selectedQuotientOH(2) -> "b00000".U, // 0 + input.selectedQuotientOH(3) -> "b00001".U, // 1 + input.selectedQuotientOH(4) -> "b00010".U // 2 + ) + ) + Mux1H( + Seq( + input.selectedQuotientOH(5) -> "b11000".U, // -8 + input.selectedQuotientOH(6) -> "b11100".U, // -4 + input.selectedQuotientOH(7) -> "b00000".U, // 0 + input.selectedQuotientOH(8) -> "b00100".U, // 4 + input.selectedQuotientOH(9) -> "b01000".U // 8 + ) + ) + cShiftQ := input.selectedQuotientOH(9, 8).orR || + (input.selectedQuotientOH(7) && input.selectedQuotientOH(4, 2).orR) + cShiftQM := input.selectedQuotientOH(6, 5).orR || + (input.selectedQuotientOH(7) && input.selectedQuotientOH(2, 0).orR) + } else if (a == 6) { + qNext := Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b11110".U, // -2 + input.selectedQuotientOH(1) -> "b11111".U, // -1 + input.selectedQuotientOH(2) -> "b00000".U, // 0 + input.selectedQuotientOH(3) -> "b00001".U, // 1 + input.selectedQuotientOH(4) -> "b00010".U // 2 + ) + ) + Mux1H( + Seq( + input.selectedQuotientOH(5) -> "b11100".U, // -4 + input.selectedQuotientOH(6) -> "b00000".U, // 0 + input.selectedQuotientOH(7) -> "b00100".U // 4 + ) + ) + cShiftQ := input.selectedQuotientOH(7) || + (input.selectedQuotientOH(6) && input.selectedQuotientOH(4, 2).orR) + cShiftQM := input.selectedQuotientOH(5) || + (input.selectedQuotientOH(6) && input.selectedQuotientOH(2, 0).orR) + } else if (a == 5) { + qNext := Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b11110".U, // -2 + input.selectedQuotientOH(1) -> "b11111".U, // -1 + input.selectedQuotientOH(2) -> "b00000".U, // 0 + input.selectedQuotientOH(3) -> "b00001".U, // 1 + input.selectedQuotientOH(4) -> "b00010".U // 2 + ) + ) + Mux1H( + Seq( + input.selectedQuotientOH(5) -> "b11100".U, // -4 + input.selectedQuotientOH(6) -> "b00000".U, // 0 + input.selectedQuotientOH(7) -> "b00100".U // 4 + ) + ) + cShiftQ := input.selectedQuotientOH(7) || + (input.selectedQuotientOH(6) && input.selectedQuotientOH(4, 2).orR) + cShiftQM := input.selectedQuotientOH(5) || + (input.selectedQuotientOH(6) && input.selectedQuotientOH(2, 0).orR) + } else if (a == 4) { + qNext := Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b11110".U, // -2 + input.selectedQuotientOH(1) -> "b11111".U, // -1 + input.selectedQuotientOH(2) -> "b00000".U, // 0 + input.selectedQuotientOH(3) -> "b00001".U, // 1 + input.selectedQuotientOH(4) -> "b00010".U // 2 + ) + ) + Mux1H( + Seq( + input.selectedQuotientOH(5) -> "b11110".U, // -2 + input.selectedQuotientOH(6) -> "b00000".U, // 0 + input.selectedQuotientOH(7) -> "b00010".U // 2 + ) + ) + cShiftQ := input.selectedQuotientOH(7) || + (input.selectedQuotientOH(6) && input.selectedQuotientOH(3, 2).orR) + cShiftQM := input.selectedQuotientOH(5) || + (input.selectedQuotientOH(6) && input.selectedQuotientOH(2, 1).orR) + } + + val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(radixLog2 - 1, 0) + val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(radixLog2 - 1, 0) + + output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qIn + output.quotientMinusOne := Mux(!cShiftQM, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qmIn +} + +object OTF { + def apply( + radixLog2: Int, + qWidth: Int, + ohWidth: Int, + a: Int + )(quotient: UInt, + quotientMinusOne: UInt, + selectedQuotientOH: UInt + ): Vec[UInt] = { + val m = Module(new OTF(radixLog2, qWidth, ohWidth, a)) + m.input.quotient := quotient + m.input.quotientMinusOne := quotientMinusOne + m.input.selectedQuotientOH := selectedQuotientOH + VecInit(m.output.quotient, m.output.quotientMinusOne) + } +} diff --git a/arithmetic/src/division/srt/srt8/QDS.scala b/arithmetic/src/division/srt/srt8/QDS.scala new file mode 100644 index 0000000..e9adc83 --- /dev/null +++ b/arithmetic/src/division/srt/srt8/QDS.scala @@ -0,0 +1,137 @@ +package division.srt.srt8 + +import division.srt._ +import chisel3._ +import chisel3.util.BitPat +import chisel3.util.BitPat.bitPatToUInt +import chisel3.util.experimental.decode.TruthTable +import utils.{extend, sIntToBitPat} + +class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]], a: Int) extends Module { + // IO + val input = IO(Input(new QDSInput(rWidth, partialDividerWidth))) + val output = IO(Output(new QDSOutput(ohWidth))) + + val columnSelect = input.partialDivider + // Seq[Seq[Int]] => Vec[Vec[UInt]] + lazy val selectRom = VecInit(tables.map { + case x => + VecInit(x.map { + case x => bitPatToUInt(sIntToBitPat(-x, rWidth)) + }) + }) + + val adderWidth = rWidth + 1 + val yTruncate: UInt = input.partialReminderCarry + input.partialReminderSum + val mkVec = selectRom(columnSelect) + val selectPoints = VecInit(mkVec.map { mk => + (extend(yTruncate, adderWidth).asUInt + + extend(mk, adderWidth).asUInt).head(1) + }).asUInt + + output.selectedQuotientOH := chisel3.util.experimental.decode.decoder( + selectPoints, + a match { + case 7 => + TruthTable( + Seq( // 8 4 0 -4 -8__2 1 0 -1 -2 + BitPat("b??_????_????_???0") -> BitPat("b10000_00010"), // 7 = +8 + (-1) + BitPat("b??_????_????_??01") -> BitPat("b01000_10000"), // 6 = +4 + (+2) + BitPat("b??_????_????_?011") -> BitPat("b01000_01000"), // 5 = +4 + (+1) + BitPat("b??_????_????_0111") -> BitPat("b01000_00100"), // 4 = +4 + ( 0) + BitPat("b??_????_???0_1111") -> BitPat("b01000_00010"), // 3 = +4 + (-1) + BitPat("b??_????_??01_1111") -> BitPat("b00100_10000"), // 2 = 0 + (+2) + BitPat("b??_????_?011_1111") -> BitPat("b00100_01000"), // 1 = 0 + (+1) + BitPat("b??_????_0111_1111") -> BitPat("b00100_00100"), // 0 = 0 + ( 0) + BitPat("b??_???0_1111_1111") -> BitPat("b00100_00010"), //-1 = 0 + (-1) + BitPat("b??_??01_1111_1111") -> BitPat("b00100_00001"), //-2 = 0 + (-2) + BitPat("b??_?011_1111_1111") -> BitPat("b00010_01000"), //-3 = -4 + ( 1) + BitPat("b??_0111_1111_1111") -> BitPat("b00010_00100"), //-4 = -4 + ( 0) + BitPat("b?0_1111_1111_1111") -> BitPat("b00010_00010"), //-5 = -4 + (-1) + BitPat("b01_1111_1111_1111") -> BitPat("b00010_00001") // -6 = -4 + (-2) + ), + BitPat("b00001_01000") //-7 = -8 + (+1) + ) + case 6 => + TruthTable( + Seq( // 4 0 -4__2 1 0 -1 -2 + BitPat("b????_????_???0") -> BitPat("b100_10000"), // 6 = +4 + (+2) + BitPat("b????_????_??01") -> BitPat("b100_01000"), // 5 = +4 + (+1) + BitPat("b????_????_?011") -> BitPat("b100_00100"), // 4 = +4 + ( 0) + BitPat("b????_????_0111") -> BitPat("b100_00010"), // 3 = +4 + (-1) + BitPat("b????_???0_1111") -> BitPat("b010_10000"), // 2 = 0 + (+2) + BitPat("b????_??01_1111") -> BitPat("b010_01000"), // 1 = 0 + (+1) + BitPat("b????_?011_1111") -> BitPat("b010_00100"), // 0 = 0 + ( 0) + BitPat("b????_0111_1111") -> BitPat("b010_00010"), //-1 = 0 + (-1) + BitPat("b???0_1111_1111") -> BitPat("b010_00001"), //-2 = 0 + (-2) + BitPat("b??01_1111_1111") -> BitPat("b001_01000"), //-3 = -4 + ( 1) + BitPat("b?011_1111_1111") -> BitPat("b001_00100"), //-4 = -4 + ( 0) + BitPat("b0111_1111_1111") -> BitPat("b001_00010") // -5 = -4 + (-1) + ), + BitPat("b001_00001") //-6 = -4 + (-2) + ) + case 5 => + TruthTable( + Seq( // 4 0 -4__2 1 0 -1 -2 + BitPat("b??_????_???0") -> BitPat("b100_01000"), // 5 = +4 + (+1) + BitPat("b??_????_??01") -> BitPat("b100_00100"), // 4 = +4 + ( 0) + BitPat("b??_????_?011") -> BitPat("b100_00010"), // 3 = +4 + (-1) + BitPat("b??_????_0111") -> BitPat("b010_10000"), // 2 = 0 + (+2) + BitPat("b??_???0_1111") -> BitPat("b010_01000"), // 1 = 0 + (+1) + BitPat("b??_??01_1111") -> BitPat("b010_00100"), // 0 = 0 + ( 0) + BitPat("b??_?011_1111") -> BitPat("b010_00010"), //-1 = 0 + (-1) + BitPat("b??_0111_1111") -> BitPat("b010_00001"), //-2 = 0 + (-2) + BitPat("b?0_1111_1111") -> BitPat("b001_01000"), //-3 = -4 + ( 1) + BitPat("b01_1111_1111") -> BitPat("b001_00100") // -4 = -4 + ( 0) + ), + BitPat("b001_00010") //-5 = -4 + (-1) + ) + case 4 => + TruthTable( + Seq( // 2 0 -2__2 1 0 -1 -2 + BitPat("b????_???0") -> BitPat("b100_10000"), // 4 = +2 + ( 2) + BitPat("b????_??01") -> BitPat("b100_01000"), // 3 = +2 + ( 1) + BitPat("b????_?011") -> BitPat("b100_00100"), // 2 = 2 + ( 0) + BitPat("b????_0111") -> BitPat("b010_01000"), // 1 = 0 + (+1) + BitPat("b???0_1111") -> BitPat("b010_00100"), // 0 = 0 + ( 0) + BitPat("b??01_1111") -> BitPat("b010_00010"), //-1 = 0 + (-1) + BitPat("b?011_1111") -> BitPat("b001_00100"), //-2 = -2 + ( 0) + BitPat("b0111_1111") -> BitPat("b001_00010") // -3 = -2 + (-1) + ), + BitPat("b001_00001") //-4 = -2 + (-2) + ) + // TruthTable( + // Seq( // 4 0 -4__2 1 0 -1 -2 + // BitPat("b????_???0") -> BitPat("b100_00100"), // 4 = +4 + ( 0) + // BitPat("b????_??01") -> BitPat("b100_00010"), // 3 = +4 + (-1) + // BitPat("b????_?011") -> BitPat("b010_10000"), // 2 = 0 + (+2) + // BitPat("b????_0111") -> BitPat("b010_01000"), // 1 = 0 + (+1) + // BitPat("b???0_1111") -> BitPat("b010_00100"), // 0 = 0 + ( 0) + // BitPat("b??01_1111") -> BitPat("b010_00010"), //-1 = 0 + (-1) + // BitPat("b?011_1111") -> BitPat("b010_00001"), //-2 = 0 + (-2) + // BitPat("b0111_1111") -> BitPat("b001_01000") //-3 = -4 + ( 1) + // ), + // BitPat("b001_00100") //-4 = -4 + ( 0) + // ) + } + ) +} + +object QDS { + def apply( + rWidth: Int, + ohWidth: Int, + partialDividerWidth: Int, + tables: Seq[Seq[Int]], + a: Int + )(partialReminderSum: UInt, + partialReminderCarry: UInt, + partialDivider: UInt + ): UInt = { + val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth, tables, a)) + m.input.partialReminderSum := partialReminderSum + m.input.partialReminderCarry := partialReminderCarry + m.input.partialDivider := partialDivider + m.output.selectedQuotientOH + } +} diff --git a/arithmetic/src/division/srt/srt8/SRT8.scala b/arithmetic/src/division/srt/srt8/SRT8.scala new file mode 100644 index 0000000..f45196d --- /dev/null +++ b/arithmetic/src/division/srt/srt8/SRT8.scala @@ -0,0 +1,216 @@ +package division.srt.srt8 + +import division.srt._ +import division.srt.SRTTable +import chisel3._ +import chisel3.util._ +import utils.leftShift + +/** SRT8 + * 1/2 <= d < 1, 1/2 < rho <=1, 0 < q < 2 + * radix = 8 + * a = 7, {-7, ... ,-2, -1, 0, 1, 2, ... 7}, + * dTruncateWidth = 4, rTruncateWidth = 4 + * y^(xxxx.xxxx), d^(0.1xxx) + * table from SRTTable + * -129/16 < y^ < 127/16 + * floor((-r*rho - 2^-t)_t) <= y^ <= floor((r*rho - ulp)_t) + */ + +class SRT8( + dividendWidth: Int, + dividerWidth: Int, + n: Int, // the longest width + radixLog2: Int = 3, + a: Int = 7, + dTruncateWidth: Int = 4, + rTruncateWidth: Int = 4) + extends Module { + + val xLen: Int = dividendWidth + radixLog2 + 1 + val wLen: Int = xLen + radixLog2 + + // IO + val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) + val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth))) + + val partialReminderCarryNext, partialReminderSumNext = Wire(UInt(wLen.W)) + val quotientNext, quotientMinusOneNext = Wire(UInt(n.W)) + val dividerNext = Wire(UInt(dividerWidth.W)) + val counterNext = Wire(UInt(log2Ceil(n).W)) + + // Control + // sign of select quotient, true -> negative, false -> positive + // sign of Cycle, true -> (counter === 0.U) + val isLastCycle, enable: Bool = Wire(Bool()) + + // State + // because we need a CSA to minimize the critical path + val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), enable) + val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(wLen.W), enable) + val divider = RegEnable(dividerNext, 0.U(dividerWidth.W), enable) + val quotient = RegEnable(quotientNext, 0.U(n.W), enable) + val quotientMinusOne = RegEnable(quotientMinusOneNext, 0.U(n.W), enable) + val counter = RegEnable(counterNext, 0.U(log2Ceil(n).W), enable) + + // Datapath + // according two adders + isLastCycle := !counter.orR + output.valid := isLastCycle + input.ready := isLastCycle + enable := input.fire || !isLastCycle + + val remainderNoCorrect: UInt = partialReminderSum + partialReminderCarry + val remainderCorrect: UInt = + partialReminderSum + partialReminderCarry + (divider << radixLog2) + val needCorrect: Bool = remainderNoCorrect(wLen - 4).asBool + output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 5, radixLog2) + output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient) + + val rWidth: Int = 1 + radixLog2 + rTruncateWidth + val tables: Seq[Seq[Int]] = SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS + + val ohWidth: Int = a match { + case 7 => 10 + case 6 => 8 + case 5 => 8 + case 4 => 8 + } + // qds + val selectedQuotientOH: UInt = + QDS(rWidth, ohWidth, dTruncateWidth - 1, tables, a)( + leftShift(partialReminderSum, radixLog2).head(rWidth), + leftShift(partialReminderCarry, radixLog2).head(rWidth), + dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) //.1********* -> 1*** -> *** + ) + // On-The-Fly conversion + val otf = OTF(radixLog2, n, ohWidth, a)(quotient, quotientMinusOne, selectedQuotientOH) + + val dividerLMap = VecInit((-2 to 2).map { + case -2 => divider << 1 // -2 + case -1 => divider // -1 + case 0 => 0.U // 0 + case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider // 1 + case 2 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) // 2 + }) + + if (a == 7) { + val qHigh: UInt = selectedQuotientOH(9, 5) + val qLow: UInt = selectedQuotientOH(4, 0) + val qdsSign0: Bool = qHigh.head(2).orR + val qdsSign1: Bool = qLow.head(2).orR + // csa for SRT8 -> CSA32+CSA32 + val dividerHMap = VecInit((-2 to 2).map { + case -2 => divider << 3 // -8 + case -1 => divider << 2 // -4 + case 0 => 0.U // 0 + case 1 => Fill(2, 1.U(1.W)) ## ~(divider << 2) // 4 + case 2 => Fill(1, 1.U(1.W)) ## ~(divider << 3) // 8 + }) + val csa0 = addition.csa.c32( + VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0, + Mux1H(qHigh, dividerHMap) + ) + ) + val csa1 = addition.csa.c32( + VecInit( + csa0(1).head(wLen - radixLog2), + leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1, + Mux1H(qLow, dividerLMap) + ) + ) + partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa1(1) << radixLog2) + partialReminderCarryNext := Mux(input.fire, 0.U, csa1(0) << 1 + radixLog2) + } else if (a == 6) { + val qHigh: UInt = selectedQuotientOH(7, 5) + val qLow: UInt = selectedQuotientOH(4, 0) + val qdsSign0: Bool = qHigh.head(1).asBool + val qdsSign1: Bool = qLow.head(2).orR + + // csa for SRT8 -> CSA32+CSA32 + val dividerHMap = VecInit((-1 to 1).map { + case -1 => divider << 2 // -4 + case 0 => 0.U // 0 + case 1 => Fill(2, 1.U(1.W)) ## ~(divider << 2) // 4 + }) + val csa0 = addition.csa.c32( + VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0, + Mux1H(qHigh, dividerHMap) + ) + ) + val csa1 = addition.csa.c32( + VecInit( + csa0(1).head(wLen - radixLog2), + leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1, + Mux1H(qLow, dividerLMap) + ) + ) + partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa1(1) << radixLog2) + partialReminderCarryNext := Mux(input.fire, 0.U, csa1(0) << 1 + radixLog2) + } else if (a == 5) { + val qHigh: UInt = selectedQuotientOH(7, 5) + val qLow: UInt = selectedQuotientOH(4, 0) + val qdsSign0: Bool = qHigh.head(1).asBool + val qdsSign1: Bool = qLow.head(2).orR + + // csa for SRT8 -> CSA32+CSA32 + val dividerHMap = VecInit((-1 to 1).map { + case -1 => divider << 2 // -4 + case 0 => 0.U // 0 + case 1 => Fill(2, 1.U(1.W)) ## ~(divider << 2) // 4 + }) + val csa0 = addition.csa.c32( + VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0, + Mux1H(qHigh, dividerHMap) + ) + ) + val csa1 = addition.csa.c32( + VecInit( + csa0(1).head(wLen - radixLog2), + leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1, + Mux1H(qLow, dividerLMap) + ) + ) + partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa1(1) << radixLog2) + partialReminderCarryNext := Mux(input.fire, 0.U, csa1(0) << 1 + radixLog2) + } else if (a == 4) { + val qHigh: UInt = selectedQuotientOH(7, 5) + val qLow: UInt = selectedQuotientOH(4, 0) + val qdsSign0: Bool = qHigh.head(1).asBool + val qdsSign1: Bool = qLow.head(2).orR + + // csa for SRT8 -> CSA32+CSA32 + val dividerHMap = VecInit((-1 to 1).map { + case -1 => divider << 1 // -2 + case 0 => 0.U // 0 + case 1 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) // 2 + }) + val csa0 = addition.csa.c32( + VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0, + Mux1H(qHigh, dividerHMap) + ) + ) + val csa1 = addition.csa.c32( + VecInit( + csa0(1).head(wLen - radixLog2), + leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1, + Mux1H(qLow, dividerLMap) + ) + ) + partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa1(1) << radixLog2) + partialReminderCarryNext := Mux(input.fire, 0.U, csa1(0) << 1 + radixLog2) + } + + dividerNext := Mux(input.fire, input.bits.divider, divider) + counterNext := Mux(input.fire, input.bits.counter, counter - 1.U) + quotientNext := Mux(input.fire, 0.U, otf(0)) + quotientMinusOneNext := Mux(input.fire, 0.U, otf(1)) +} diff --git a/arithmetic/src/utils/package.scala b/arithmetic/src/utils/package.scala index 6b884a1..2edf1ab 100644 --- a/arithmetic/src/utils/package.scala +++ b/arithmetic/src/utils/package.scala @@ -55,4 +55,10 @@ package object utils { else BitPat((x + (1 << w)).U(w.W)) } + + // left shift and keep the width of Bits + def leftShift(x: Bits, n: Int): UInt = { + val length: Int = x.getWidth + (x << n)(length - 1, 0) + } } diff --git a/arithmetic/tests/src/division/srt/SRT16Test.scala b/arithmetic/tests/src/division/srt/SRT16Test.scala new file mode 100644 index 0000000..47ffa9b --- /dev/null +++ b/arithmetic/tests/src/division/srt/SRT16Test.scala @@ -0,0 +1,78 @@ +package division.srt.srt16 + +import chisel3._ +import chisel3.tester.{ChiselUtestTester, testableClock, testableData} +import utest._ + +import scala.util.Random + +object SRT16Test extends TestSuite with ChiselUtestTester { + def tests: Tests = Tests { + test("SRT16 should pass") { + def testcase(width: Int): Unit ={ + // parameters + val radixLog2: Int = 4 + val n: Int = width + val m: Int = n - 1 + val p: Int = Random.nextInt(m - radixLog2 +1) //order to offer guardwidth + val q: Int = Random.nextInt(m - radixLog2 +1) + val dividend: BigInt = BigInt(p, Random) + val divider: BigInt = BigInt(q, Random) +// val dividend: BigInt = BigInt("65") +// val divider: BigInt = BigInt("1") + def zeroCheck(x: BigInt): Int = { + var flag = false + var a: Int = m + while (!flag && (a >= -1)) { + flag = ((BigInt(1) << a) & x) != 0 + a = a - 1 + } + a + 1 + } + val zeroHeadDividend: Int = m - zeroCheck(dividend) + val zeroHeadDivider: Int = m - zeroCheck(divider) + val needComputerWidth: Int = zeroHeadDivider - zeroHeadDividend + 1 + 1 + val noguard: Boolean = needComputerWidth % radixLog2 == 0 + val guardWidth: Int = if (noguard) 0 else 4 - needComputerWidth % 4 + val counter: Int = (needComputerWidth + guardWidth) / radixLog2 + if ((divider == 0) || (divider > dividend) || (needComputerWidth <= 0)) + return + val quotient: BigInt = dividend / divider + val remainder: BigInt = dividend % divider + val leftShiftWidthDividend: Int = zeroHeadDividend - guardWidth + val leftShiftWidthDivider: Int = zeroHeadDivider + // test + testCircuit(new SRT16(n, n, n), + Seq(chiseltest.internal.NoThreadingAnnotation, + chiseltest.simulator.WriteVcdAnnotation)) { + dut: SRT16 => + dut.clock.setTimeout(0) + dut.input.valid.poke(true.B) + dut.input.bits.dividend.poke((dividend << leftShiftWidthDividend).U) + dut.input.bits.divider.poke((divider << leftShiftWidthDivider).U) + dut.input.bits.counter.poke(counter.U) + dut.clock.step() + dut.input.valid.poke(false.B) + var flag = false + for (a <- 1 to 1000 if !flag) { + if (dut.output.valid.peek().litValue == 1) { + flag = true + println(dut.output.bits.quotient.peek().litValue) + println(dut.output.bits.reminder.peek().litValue) + utest.assert(dut.output.bits.quotient.peek().litValue == quotient) + utest.assert(dut.output.bits.reminder.peek().litValue >> zeroHeadDivider == remainder) + } + dut.clock.step() + } + utest.assert(flag) + dut.clock.step(scala.util.Random.nextInt(10)) + } + } + + testcase(64) +// for( i <- 1 to 50){ +// testcase(64) +// } + } + } +} \ No newline at end of file diff --git a/arithmetic/tests/src/division/srt/SRT4Test.scala b/arithmetic/tests/src/division/srt/SRT4Test.scala new file mode 100644 index 0000000..efb42dd --- /dev/null +++ b/arithmetic/tests/src/division/srt/SRT4Test.scala @@ -0,0 +1,82 @@ +package division.srt.srt4 + +import chisel3._ +import chisel3.tester.{ChiselUtestTester, testableClock, testableData} +import utest._ +import scala.util.{Random} + +object SRT4Test extends TestSuite with ChiselUtestTester { + def tests: Tests = Tests { + test("SRT4 should pass") { + def testcase(width: Int): Unit ={ + // parameters + val radixLog2: Int = 2 + val n: Int = width + val m: Int = n - 1 + val p: Int = Random.nextInt(m) + val q: Int = Random.nextInt(m) + val dividend: BigInt = BigInt(p, Random) + val divider: BigInt = BigInt(q, Random) +// val dividend: BigInt = BigInt("65") +// val divider: BigInt = BigInt("1") + def zeroCheck(x: BigInt): Int = { + var flag = false + var a: Int = m + while (!flag && (a >= -1)) { + flag = ((BigInt(1) << a) & x) != 0 + a = a - 1 + } + a + 1 + } + val zeroHeadDividend: Int = m - zeroCheck(dividend) + val zeroHeadDivider: Int = m - zeroCheck(divider) + val needComputerWidth: Int = zeroHeadDivider - zeroHeadDividend + 1 + radixLog2 - 1 + val noguard: Boolean = needComputerWidth % radixLog2 == 0 + val counter: Int = (needComputerWidth + 1) / 2 + if ((divider == 0) || (divider > dividend) || (needComputerWidth <= 0)) + return + val quotient: BigInt = dividend / divider + val remainder: BigInt = dividend % divider + val leftShiftWidthDividend: Int = zeroHeadDividend - (if (noguard) 0 else 1) + val leftShiftWidthDivider: Int = zeroHeadDivider +// println("dividend = %8x, dividend = %d ".format(dividend, dividend)) +// println("divider = %8x, divider = %d".format(divider, divider)) +// println("zeroHeadDividend = %d, dividend << zeroHeadDividend = %d".format(zeroHeadDividend, dividend << leftShiftWidthDividend)) +// println("zeroHeadDivider = %d, divider << zeroHeadDivider = %d".format(zeroHeadDivider, divider << leftShiftWidthDivider)) +// println("quotient = %d, remainder = %d".format(quotient, remainder)) +// println("counter = %d, needComputerWidth = %d".format(counter, needComputerWidth)) + // test + testCircuit(new SRT4(n, n, n), + Seq(chiseltest.internal.NoThreadingAnnotation, + chiseltest.simulator.WriteVcdAnnotation)) { + dut: SRT4 => + dut.clock.setTimeout(0) + dut.input.valid.poke(true.B) + dut.input.bits.dividend.poke((dividend << leftShiftWidthDividend).U) + dut.input.bits.divider.poke((divider << leftShiftWidthDivider).U) + dut.input.bits.counter.poke(counter.U) + dut.clock.step() + dut.input.valid.poke(false.B) + var flag = false + for (a <- 1 to 1000 if !flag) { + if (dut.output.valid.peek().litValue == 1) { + flag = true + println(dut.output.bits.quotient.peek().litValue) + println(dut.output.bits.reminder.peek().litValue) + utest.assert(dut.output.bits.quotient.peek().litValue == quotient) + utest.assert(dut.output.bits.reminder.peek().litValue >> zeroHeadDivider == remainder) + } + dut.clock.step() + } + utest.assert(flag) + dut.clock.step(scala.util.Random.nextInt(10)) + } + } + + testcase(64) +// for( i <- 1 to 50){ +// testcase(64) +// } + } + } +} \ No newline at end of file diff --git a/arithmetic/tests/src/division/srt/SRT8Test.scala b/arithmetic/tests/src/division/srt/SRT8Test.scala new file mode 100644 index 0000000..317b8e0 --- /dev/null +++ b/arithmetic/tests/src/division/srt/SRT8Test.scala @@ -0,0 +1,77 @@ +package division.srt.srt8 + +import chisel3._ +import chisel3.tester.{ChiselUtestTester, testableClock, testableData} +import utest._ + +import scala.util.Random + +object SRT8Test extends TestSuite with ChiselUtestTester { + def tests: Tests = Tests { + test("SRT8 should pass") { + def testcase(width: Int): Unit ={ + // parameters + val radixLog2: Int = 3 + val n: Int = width + val m: Int = n - 1 + val p: Int = Random.nextInt(m - radixLog2 +1) //order to offer guardwidth + val q: Int = Random.nextInt(m - radixLog2 +1) + val dividend: BigInt = BigInt(p, Random) + val divider: BigInt = BigInt(q, Random) +// val dividend: BigInt = BigInt("65") +// val divider: BigInt = BigInt("1") + def zeroCheck(x: BigInt): Int = { + var flag = false + var a: Int = m + while (!flag && (a >= -1)) { + flag = ((BigInt(1) << a) & x) != 0 + a = a - 1 + } + a + 1 + } + val zeroHeadDividend: Int = m - zeroCheck(dividend) + val zeroHeadDivider: Int = m - zeroCheck(divider) + val needComputerWidth: Int = zeroHeadDivider - zeroHeadDividend + 1 + radixLog2 -1 + val noguard: Boolean = needComputerWidth % radixLog2 == 0 + val guardWidth: Int = if (noguard) 0 else 3 - needComputerWidth % 3 + val counter: Int = (needComputerWidth + guardWidth) / radixLog2 + if ((divider == 0) || (divider > dividend) || (needComputerWidth <= 0)) + return + val quotient: BigInt = dividend / divider + val remainder: BigInt = dividend % divider + val leftShiftWidthDividend: Int = zeroHeadDividend - guardWidth + val leftShiftWidthDivider: Int = zeroHeadDivider + testCircuit(new SRT8(n, n, n), + Seq(chiseltest.internal.NoThreadingAnnotation, + chiseltest.simulator.WriteVcdAnnotation)) { + dut: SRT8 => + dut.clock.setTimeout(0) + dut.input.valid.poke(true.B) + dut.input.bits.dividend.poke((dividend << leftShiftWidthDividend).U) + dut.input.bits.divider.poke((divider << leftShiftWidthDivider).U) + dut.input.bits.counter.poke(counter.U) + dut.clock.step() + dut.input.valid.poke(false.B) + var flag = false + for (a <- 1 to 1000 if !flag) { + if (dut.output.valid.peek().litValue == 1) { + flag = true + println(dut.output.bits.quotient.peek().litValue) + println(dut.output.bits.reminder.peek().litValue) + utest.assert(dut.output.bits.quotient.peek().litValue == quotient) + utest.assert(dut.output.bits.reminder.peek().litValue >> zeroHeadDivider == remainder) + } + dut.clock.step() + } + utest.assert(flag) + dut.clock.step(scala.util.Random.nextInt(10)) + } + } + + testcase(64) +// for( i <- 1 to 50){ +// testcase(64) +// } + } + } +} \ No newline at end of file diff --git a/arithmetic/tests/src/division/srt/SRTSpec.scala b/arithmetic/tests/src/division/srt/SRTSpec.scala new file mode 100644 index 0000000..2d9c960 --- /dev/null +++ b/arithmetic/tests/src/division/srt/SRTSpec.scala @@ -0,0 +1,18 @@ +package division.srt + +import utest._ +import chisel3._ +import utils.extend + + + +object SRTSpec extends TestSuite{ + override def tests: Tests = Tests { + test("SRT should draw PD") { + val srt = SRTTable(8,5,5,5) +// println(srt.tables) +// println(srt.tablesToQDS) + srt.dumpGraph(srt.pd, os.root / "tmp" / "srt8-5-5-5.png") + } + } +} diff --git a/arithmetic/tests/src/division/srt/SRTTest.scala b/arithmetic/tests/src/division/srt/SRTTest.scala new file mode 100644 index 0000000..5c73846 --- /dev/null +++ b/arithmetic/tests/src/division/srt/SRTTest.scala @@ -0,0 +1,88 @@ +package division.srt + +import chisel3._ +import chisel3.tester.{ChiselUtestTester, testableClock, testableData} +import utest._ + +import scala.util.Random + +object SRTTest extends TestSuite with ChiselUtestTester { + def tests: Tests = Tests { + test("SRT should pass") { + def testcase(n: Int = 64, + radixLog2: Int = 4, + a: Int = 2, + dTruncateWidth: Int = 4, + rTruncateWidth: Int = 4): Unit ={ + //tips + println("SRT%d(width = %d, a = %d, dTruncateWidth = %d, rTruncateWidth = %d) should pass ".format( + 1 << radixLog2 , n , a, dTruncateWidth, rTruncateWidth)) + // parameters + val m: Int = n - 1 + val p: Int = Random.nextInt(m - radixLog2 +1) //order to offer guardwidth + val q: Int = Random.nextInt(m - radixLog2 +1) + val dividend: BigInt = BigInt(p, Random) + val divider: BigInt = BigInt(q, Random) + // val dividend: BigInt = BigInt("65") + // val divider: BigInt = BigInt("1") + def zeroCheck(x: BigInt): Int = { + var flag = false + var k: Int = m + while (!flag && (k >= -1)) { + flag = ((BigInt(1) << k) & x) != 0 + k = k - 1 + } + k + 1 + } + val zeroHeadDividend: Int = m - zeroCheck(dividend) + val zeroHeadDivider: Int = m - zeroCheck(divider) + val needComputerWidth: Int = zeroHeadDivider - zeroHeadDividend + 1 + (if(radixLog2 == 4) 2 else radixLog2) -1 + val noguard: Boolean = needComputerWidth % radixLog2 == 0 + val guardWidth: Int = if (noguard) 0 else radixLog2 - needComputerWidth % radixLog2 + val counter: Int = (needComputerWidth + guardWidth) / radixLog2 + if ((divider == 0) || (divider > dividend) || (needComputerWidth <= 0)) + return + val quotient: BigInt = dividend / divider + val remainder: BigInt = dividend % divider + val leftShiftWidthDividend: Int = zeroHeadDividend - guardWidth + val leftShiftWidthDivider: Int = zeroHeadDivider +// println("dividend = %8x, dividend = %d ".format(dividend, dividend)) +// println("divider = %8x, divider = %d".format(divider, divider)) +// println("zeroHeadDividend = %d, dividend << zeroHeadDividend = %d".format(zeroHeadDividend, dividend << leftShiftWidthDividend)) +// println("zeroHeadDivider = %d, divider << zeroHeadDivider = %d".format(zeroHeadDivider, divider << leftShiftWidthDivider)) +// println("quotient = %d, remainder = %d".format(quotient, remainder)) +// println("counter = %d, needComputerWidth = %d".format(counter, needComputerWidth)) + // test + testCircuit(new SRT(n, n, n, radixLog2, a, dTruncateWidth, rTruncateWidth), + Seq(chiseltest.internal.NoThreadingAnnotation, + chiseltest.simulator.WriteVcdAnnotation)) { + dut: SRT => + dut.clock.setTimeout(0) + dut.input.valid.poke(true.B) + dut.input.bits.dividend.poke((dividend << leftShiftWidthDividend).U) + dut.input.bits.divider.poke((divider << leftShiftWidthDivider).U) + dut.input.bits.counter.poke(counter.U) + dut.clock.step() + dut.input.valid.poke(false.B) + var flag = false + for (a <- 1 to 1000 if !flag) { + if (dut.output.valid.peek().litValue == 1) { + flag = true + println(dut.output.bits.quotient.peek().litValue) + println(dut.output.bits.reminder.peek().litValue >> zeroHeadDivider) + utest.assert(dut.output.bits.quotient.peek().litValue == quotient) + utest.assert(dut.output.bits.reminder.peek().litValue >> zeroHeadDivider == remainder) + } + dut.clock.step() + } + utest.assert(flag) + dut.clock.step(scala.util.Random.nextInt(5)) + } + } +// testcase(64) + for( i <- 1 to 50){ + testcase(n = 64, radixLog2 = 3, a = 7, dTruncateWidth = 4, rTruncateWidth = 4) + } + } + } +} \ No newline at end of file diff --git a/build.sc b/build.sc index 202124d..46a78e1 100644 --- a/build.sc +++ b/build.sc @@ -15,6 +15,11 @@ object v { val utest = ivy"com.lihaoyi::utest:latest.integration" val upickle = ivy"com.lihaoyi::upickle:latest.integration" val osLib = ivy"com.lihaoyi::os-lib:latest.integration" +// val breeze = ivy"com.github.ktakagaki.breeze::breeze:2.0" +// val breezeNatives = ivy"com.github.ktakagaki.breeze::breeze-natives:2.0" +// val breezeViz = ivy"org.scalanlp::breeze-viz:2.0" + val spire = ivy"org.typelevel::spire:0.17.0" + val evilplot = ivy"io.github.cibotech::evilplot:0.8.1" // val prime = ivy"org.apache.commons:commons-math3:3.6.1" } @@ -38,6 +43,11 @@ class arithmetic extends ScalaModule with ScalafmtModule with PublishModule { m v.chiseltest, v.upickle, v.osLib, +// v.breeze, +// v.breezeViz, +// v.breezeNatives, + v.spire, + v.evilplot ) object tests extends Tests with Utest {