Skip to content

Commit

Permalink
add SRT4Test & fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wissygh committed Apr 22, 2022
1 parent 0cd81d8 commit 4821208
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 50 deletions.
10 changes: 5 additions & 5 deletions arithmetic/src/division/srt/OTF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module {
val qNext: UInt = Mux1H(
Seq(
input.selectedQuotientOH(0) -> "b110".U,
input.selectedQuotientOH(1) -> "b101".U,
input.selectedQuotientOH(1) -> "b111".U,
input.selectedQuotientOH(2) -> "b000".U,
input.selectedQuotientOH(3) -> "b001".U,
input.selectedQuotientOH(4) -> "b010".U
Expand All @@ -36,9 +36,9 @@ class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module {
val cShiftQ: Bool = input.selectedQuotientOH(ohWidth / 2, 0).orR
val cShiftQM: Bool = input.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR

val qIn: UInt = (Mux(cShiftQ, qNext, radix.U + qNext))(1, 0)
val qmIn: UInt = (Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext))(1, 0)
val qIn: UInt =Mux(cShiftQ, qNext, radix.U + qNext)(1, 0)
val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(1, 0)

output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne) ## qIn
output.quotientMinusOne := Mux(cShiftQM, input.quotientMinusOne, input.quotient) ## qmIn
output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne)(qWidth-2, 0) ## qIn
output.quotientMinusOne := Mux(!cShiftQM, input.quotient, input.quotientMinusOne)(qWidth-2, 0) ## qmIn
}
17 changes: 11 additions & 6 deletions arithmetic/src/division/srt/QDS.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package division.srt
import chisel3._
import chisel3.util.{log2Ceil, BitPat, RegEnable, Valid}
import chisel3.util.{BitPat, RegEnable, Valid}
import chisel3.util.experimental.decode._
import utils.extend

class QDSInput(rWidth: Int) extends Bundle {
val partialReminderCarry: UInt = UInt(rWidth.W)
Expand All @@ -12,11 +13,11 @@ class QDSOutput(ohWidth: Int) extends Bundle {
val selectedQuotientOH: UInt = UInt(ohWidth.W)
}

class QDS(rWidth: Int, ohWidth: Int) extends Module {
class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module {
// IO
val input = IO(Input(new QDSInput(rWidth)))
val output = IO(Output(new QDSOutput(ohWidth)))
val partialDivider = IO(Flipped(Valid(UInt(3.W))))
val partialDivider = IO(Flipped(Valid(UInt(partialDividerWidth.W))))

// State, in order to keep divider's value
val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid)
Expand All @@ -37,7 +38,7 @@ class QDS(rWidth: Int, ohWidth: Int) extends Module {
// Array(18, 6, -8, -20),
// Array(20, 6, -8, -20),
// Array(20, 8, -8, -22),
// Array(24, 8, -8, -24)
// Array(24, 8, -8, -24)/16
// )
val columnSelect = partialDividerLatch
val selectRom: Vec[Vec[UInt]] = VecInit(
Expand All @@ -50,10 +51,14 @@ class QDS(rWidth: Int, ohWidth: Int) extends Module {
VecInit("b110_1100".U, "b111_1000".U, "b000_1000".U, "b001_0110".U),
VecInit("b110_1000".U, "b111_1000".U, "b000_1000".U, "b001_1000".U)
)

val mkVec = selectRom(columnSelect)
val adderWidth = rWidth + 1
val selectPoints = VecInit(mkVec.map { mk =>
// maybe have a problem."+&" extend signed to avoid overflow. only for srt4, because -44/16 < y^ < 42/16.
(input.partialReminderCarry +& input.partialReminderSum + mk).head(1)
// extend signed to avoid overflow. only for srt4, because -44/16 < y^ < 42/16.
(extend(input.partialReminderCarry, adderWidth).asUInt
+ extend(input.partialReminderSum, adderWidth).asUInt
+ extend(mk, adderWidth).asUInt).head(1)
}).asUInt

// decoder or findFirstOne here, prefer decoder, the decoder only for srt4
Expand Down
80 changes: 43 additions & 37 deletions arithmetic/src/division/srt/SRT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ import scala.math.ceil
* -44/16 < y^ < 42/16
*/

// TODO: width
// TODO: counter & n
class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle {
val dividend = UInt(dividendWidth.W) //0.1**********
val divider = UInt(dividerWidth.W) //0.1**********
val counter = UInt(log2Ceil(n).W) //the width of quotient.
val dividend = UInt(dividendWidth.W) //0.1**********
val divider = UInt(dividerWidth.W) //0.1**********
val counter = UInt(log2Ceil(n).W) //the width of quotient.
}

class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle {
Expand All @@ -33,12 +33,14 @@ class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle {
class SRT(
dividendWidth: Int,
dividerWidth: Int,
n: Int, // the longest width
n: Int, // the longest width,
radixLog2: Int = 2,
a: Int = 2,
dTruncateWidth: Int = 4,
rTruncateWidth: Int = 4)
extends Module {

val xLen: Int = dividendWidth + radixLog2
val ohWidth: Int = 2 * a + 1

// IO
Expand All @@ -47,8 +49,8 @@ class SRT(

// State
// because we need a CSA to minimize the critical path
val partialReminderCarry = Reg(UInt((dividendWidth + radixLog2).W))
val partialReminderSum = Reg(UInt((dividendWidth + radixLog2).W))
val partialReminderCarry = Reg(UInt(xLen.W))
val partialReminderSum = Reg(UInt(xLen.W))
val divider = RegInit(input.bits.divider)
val quotient = Reg(UInt(n.W))
val quotientMinusOne = Reg(UInt(n.W))
Expand All @@ -59,67 +61,71 @@ class SRT(
qdsSign := qds.output.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR

// Datapath
val qds = Module(new QDS(rTruncateWidth, ohWidth))
qds.input.partialReminderSum := partialReminderSum.head(1 + radixLog2 + rTruncateWidth)
qds.input.partialReminderCarry := partialReminderCarry.head(1 + radixLog2 + rTruncateWidth)
qds.partialDivider.bits := input.bits.divider.head(1 + radixLog2 + rTruncateWidth)(dTruncateWidth - 2, 0)
val rWidth: Int = 1 + radixLog2 + rTruncateWidth
val qds = Module(new QDS(rWidth, ohWidth, dTruncateWidth - 1))
qds.input.partialReminderSum := partialReminderSum.head(rWidth)
qds.input.partialReminderCarry := partialReminderCarry.head(rWidth)
qds.partialDivider.bits := input.bits.divider.head(dTruncateWidth+1)(dTruncateWidth-2, 0) //0.1********** -> 0.1*** -> ***

counter := counter - radixLog2.U
counter := counter - 1.U
// if counter === 0.U && sz.output.sign, correct the quotient and reminder. valid = 1
// the output of srt
val sz = Module(new SZ(dividendWidth - 2))
sz.input.partialReminderSum := partialReminderSum(partialReminderSum.getWidth-3, 0)
sz.input.partialReminderCarry := partialReminderCarry(partialReminderSum.getWidth-3, 0)
output.valid := Mux(counter === 0.U, true.B, false.B)

// correcting maybe have problem
quotient := Mux(counter === 0.U && sz.output.sign, quotient - 1.U, quotient)
output.bits.reminder := Mux1H(
Map(
(counter === 0.U && sz.output.zero) -> 0.U,
(counter === 0.U && sz.output.sign) -> (sz.output.remainder + 1.U + divider),
(counter === 0.U && !sz.output.sign) -> (sz.output.remainder + 1.U)
)
)
// val sz = Module(new SZ(dividendWidth - 2))
// sz.input.partialReminderSum := partialReminderSum(partialReminderSum.getWidth-3, 0)
// sz.input.partialReminderCarry := partialReminderCarry(partialReminderSum.getWidth-3, 0)
// // correcting maybe have problem
// quotient := quotient - Mux(sz.output.sign, 1.U, 0.U)
// output.bits.reminder := sz.output.remainder + Mux(sz.output.sign, divider, 0.U)
// output.bits.quotient := quotient

// according two adders
val isLastCycle: Bool = !counter.orR
output.valid := Mux(isLastCycle, true.B, false.B)
val remainderNoCorrect: UInt = partialReminderSum(xLen-3, 0) + partialReminderCarry(xLen-3, 0)
val needCorrect: Bool = Mux(isLastCycle, remainderNoCorrect.head(1).asBool, false.B)
val remainderCorrect: UInt = partialReminderSum(xLen-3, 0) + partialReminderCarry(xLen-3, 0) + divider

quotient := quotient - needCorrect.asUInt
output.bits.reminder := Mux(needCorrect, remainderNoCorrect, remainderCorrect)
output.bits.quotient := quotient

// for SRT4 -> CSA32
// for SRT8 -> CSA32+CSA32
// for SRT16 -> CSA53+CSA32
// SRT16 <- SRT4 + SRT4*5
val csa = Module(new CarrySaveAdder(CSACompressor3_2, dividendWidth + radixLog2))
val csa = Module(new CarrySaveAdder(CSACompressor3_2, xLen))
csa.in(0) := partialReminderSum
csa.in(1) := (partialReminderCarry ## !qdsSign)
csa.in(1) := (partialReminderCarry(xLen, 1) ## !qdsSign)
csa.in(2) := Mux1H(
qds.output.selectedQuotientOH,
// TODO: this is for SRT4, for SRT8 or SRT16, this should be changed
//this is for SRT4, for SRT8 or SRT16, this should be changed
VecInit((-2 to 2).map {
case -2 => divider << 1
case -1 => divider
case 0 => 0.U
case 1 => extend(~divider, dividendWidth + radixLog2)
case 2 => extend((~divider) << 1, dividendWidth + radixLog2)
case 1 => extend(~divider, xLen)
case 2 => extend((~divider) << 1, xLen)
})
)

// TODO: sel maybe have a problem
partialReminderSum := Mux1H(
Map(
(counter === input.bits.counter) -> input.bits.dividend,
(counter > 0.U) -> (csa.out(0) << radixLog2),
(counter === 0.U) -> partialReminderSum
counter.orR -> (csa.out(0) << radixLog2)(xLen-1, 0),
isLastCycle -> partialReminderSum
)
)

partialReminderCarry := Mux1H(
Map(
(counter === input.bits.counter) -> 0.U,
(counter > 0.U) -> (csa.out(1) << (radixLog2 - 1)),
(counter === 0.U) -> partialReminderCarry
counter.orR -> (csa.out(1) << radixLog2)(xLen-1, 0),
isLastCycle -> partialReminderCarry
)
)

// On-The-Fly conversion
val otf = Module(new OTF((1 << radixLog2), n, ohWidth))
val otf = Module(new OTF(1 << radixLog2, n, ohWidth))
otf.input.quotient := quotient
otf.input.quotientMinusOne := quotientMinusOne
otf.input.selectedQuotientOH := qds.output.selectedQuotientOH
Expand Down
4 changes: 2 additions & 2 deletions arithmetic/src/division/srt/SZ.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ class SZInput(rWidth: Int) extends Bundle {
class SZOutput(rWidth: Int) extends Bundle {
val sign: Bool = Bool()
val zero: Bool = Bool()
val remainder: UInt = UInt((rWidth + 1).W)
val remainder: UInt = UInt((rWidth).W)
}

class SZ(rWidth: Int, prefixSum: PrefixSum = BrentKungSum) extends Module {
val input = IO(Input(new SZInput(rWidth)))
val output = IO(Output(new SZOutput(rWidth)))

//controlpath

//datapath
// csa(ws,wc,-2^-b) => Seq[(Bool,Bool)]
// drop signed bits
// prefixtree by group
val ws = input.partialReminderCarry.asBools
val wc = input.partialReminderSum.asBools
val psc: Seq[(Bool, Bool)] = ws.zip(wc).map { case (s, c) => (!(s ^ c), (s | c)) }
Expand Down
42 changes: 42 additions & 0 deletions arithmetic/tests/src/division/srt/SRT4Test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package division.srt

import chisel3._
import chisel3.tester.{ChiselUtestTester, testableClock, testableData}
import utest._

object SRT4Test extends TestSuite with ChiselUtestTester{
def tests: Tests = Tests {
test("SRT4 should pass") {
// parameters
val dividendWidth: Int = 4
val dividerWidth: Int = 3
val n: Int = 3
val dividend: Int = 7
val divider: Int = 3
val countr: Int = 2
val remainder: Int = dividend / divider
val quotient: Int = dividend % divider
//test
testCircuit(new SRT(dividendWidth, dividerWidth, n),
Seq(chiseltest.internal.NoThreadingAnnotation,
chiseltest.simulator.WriteVcdAnnotation)){
dut: SRT =>
dut.clock.setTimeout(0)
dut.input.valid.poke(true.B)
dut.input.bits.dividend.poke(dividend.U)
dut.input.bits.divider.poke(divider.U)
dut.input.bits.counter.poke(countr.U)
var flag = false
for(a <- 1 to 1000) {
dut.clock.step()
if(dut.output.valid.peek().litValue == 1) {
flag = true
utest.assert(dut.output.bits.quotient.peek().litValue == quotient)
utest.assert(dut.output.bits.reminder.peek().litValue == remainder)
}
}
utest.assert(flag)
}
}
}
}

0 comments on commit 4821208

Please sign in to comment.