Skip to content

Commit

Permalink
fix srt16 & SRTTable fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
wissygh committed Jun 11, 2022
1 parent 4d045c0 commit a086604
Show file tree
Hide file tree
Showing 11 changed files with 55 additions and 48 deletions.
2 changes: 1 addition & 1 deletion arithmetic/src/division/srt/SRTTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ case class SRTTable(
}.flatMap {
case (i, ps) =>
ps.map {
case (x, y) => (x.toDouble, y.toDouble * 16)
case (x, y) => (x.toDouble, y.toDouble * (1 << xTruncateWidth.toInt))
}
}.groupBy(_._1).toSeq.sortBy(_._1).map { case (x, y) => y.map { case (x, y) => y.toInt }.reverse }

Expand Down
2 changes: 1 addition & 1 deletion arithmetic/src/division/srt/srt16/QDS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[I
else (-x).toBinaryString
)
.toString
.U
.U(rWidth.W)
})
})

Expand Down
38 changes: 21 additions & 17 deletions arithmetic/src/division/srt/srt16/SRT16.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class SRT16(
val xLen: Int = dividendWidth + radixLog2 + 1
val wLen: Int = xLen + radixLog2
val ohWidth: Int = 2 * a + 1
val rWidth: Int = 1 + radixLog2 + rTruncateWidth

// IO
val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n))))
Expand Down Expand Up @@ -64,17 +65,19 @@ class SRT16(
case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider
case 2 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1)
})
val csaIn1 = leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2)
val csaIn2 = leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1)
val csa1 = addition.csa.c32(VecInit(csaIn1, csaIn2 ## false.B, dividerMap(0))) // -2
val csa2 = addition.csa.c32(VecInit(csaIn1, csaIn2 ## false.B, dividerMap(1))) // -1
val csa3 = addition.csa.c32(VecInit(csaIn1, csaIn2 ## false.B, dividerMap(2))) // 0
val csa4 = addition.csa.c32(VecInit(csaIn1, csaIn2 ## true.B, dividerMap(3))) // 1
val csa5 = addition.csa.c32(VecInit(csaIn1, csaIn2 ## true.B, dividerMap(4))) // 2
val csa0InWidth = rWidth + radixLog2 + 1
val csaIn1 = leftShift(partialReminderSum, radixLog2).head(csa0InWidth)
val csaIn2 = leftShift(partialReminderCarry, radixLog2).head(csa0InWidth)

val csa1 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(0).head(csa0InWidth))) // -2 csain 10bit
val csa2 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(1).head(csa0InWidth))) // -1
val csa3 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(2).head(csa0InWidth))) // 0
val csa4 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(3).head(csa0InWidth))) // 1
val csa5 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(4).head(csa0InWidth))) // 2

// qds
val rWidth: Int = 1 + radixLog2 + rTruncateWidth
val tables: Seq[Seq[Int]] = division.srt.SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS

val tables: Seq[Seq[Int]] = SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS
val partialDivider: UInt = dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0)
val qdsOH0: UInt =
QDS(rWidth, ohWidth, dTruncateWidth - 1, tables)(
Expand All @@ -97,13 +100,6 @@ class SRT16(
val qds4SelectedQuotientOH: UInt = qds(csa4) // 1
val qds5SelectedQuotientOH: UInt = qds(csa5) // 2

val csa0OutMap = VecInit((-2 to 2).map {
case -2 => csa1
case -1 => csa2
case 0 => csa3
case 1 => csa4
case 2 => csa5
})
val qds1SelectedQuotientOHMap = VecInit((-2 to 2).map {
case -2 => qds1SelectedQuotientOH
case -1 => qds2SelectedQuotientOH
Expand All @@ -113,8 +109,16 @@ class SRT16(
})

val qdsOH1 = Mux1H(qdsOH0, qds1SelectedQuotientOHMap) // q_j+2 oneHot
val qds0sign = qdsOH0(ohWidth - 1, ohWidth / 2 + 1).orR
val qds1sign = qdsOH1(ohWidth - 1, ohWidth / 2 + 1).orR
val csa0Out = Mux1H(qdsOH0, csa0OutMap)

val csa0Out = addition.csa.c32(
VecInit(
leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2),
leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qds0sign,
Mux1H(qdsOH0, dividerMap)
)
)
val csa1Out = addition.csa.c32(
VecInit(
leftShift(csa0Out(1), radixLog2).head(wLen - radixLog2),
Expand Down
13 changes: 6 additions & 7 deletions arithmetic/src/division/srt/srt4/QDS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ package division.srt.srt4

import chisel3._
import chisel3.util.BitPat
import chisel3.util.experimental.decode._
import division.srt.SRTTable
import chisel3.util.experimental.decode.{TruthTable}
import utils.extend

class QDSInput(rWidth: Int, partialDividerWidth: Int) extends Bundle {
Expand All @@ -16,7 +15,7 @@ class QDSOutput(ohWidth: Int) extends Bundle {
val selectedQuotientOH: UInt = UInt(ohWidth.W)
}

class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module {
class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]]) extends Module {
// IO
val input = IO(Input(new QDSInput(rWidth, partialDividerWidth)))
val output = IO(Output(new QDSOutput(ohWidth)))
Expand Down Expand Up @@ -44,7 +43,6 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module {
// )

// get from SRTTable.
val tables: Seq[Seq[Int]] = SRTTable(4, 2, 4, 4).tablesToQDS
lazy val selectRom = VecInit(tables.map {
case x =>
VecInit(x.map {
Expand All @@ -55,7 +53,7 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module {
else (-x).toBinaryString
)
.toString
.U
.U(rWidth.W)
})
})

Expand Down Expand Up @@ -87,12 +85,13 @@ object QDS {
def apply(
rWidth: Int,
ohWidth: Int,
partialDividerWidth: Int
partialDividerWidth: Int,
tables: Seq[Seq[Int]]
)(partialReminderSum: UInt,
partialReminderCarry: UInt,
partialDivider: UInt
): UInt = {
val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth))
val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth, tables))
m.input.partialReminderSum := partialReminderSum
m.input.partialReminderCarry := partialReminderCarry
m.input.partialDivider := partialDivider
Expand Down
5 changes: 3 additions & 2 deletions arithmetic/src/division/srt/srt4/SRT4.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import utils.leftShift
* dTruncateWidth = 4, rTruncateWidth = 8
* y^(xxx.xxxx), d^(0.1xxx)
* -44/16 < y^ < 42/16
* floor((-r*rho - 2^-t)_t) <= y^ <= floor((r*rho - ulp)_t)
*/

class SRT4(
Expand Down Expand Up @@ -69,9 +70,9 @@ class SRT4(

// qds
val rWidth: Int = 1 + radixLog2 + rTruncateWidth

val tables: Seq[Seq[Int]] = SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS
val selectedQuotientOH: UInt =
QDS(rWidth, ohWidth, dTruncateWidth - 1)(
QDS(rWidth, ohWidth, dTruncateWidth - 1, tables)(
leftShift(partialReminderSum, radixLog2).head(rWidth),
leftShift(partialReminderCarry, radixLog2).head(rWidth),
dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) //.1********* -> 1*** -> ***
Expand Down
13 changes: 6 additions & 7 deletions arithmetic/src/division/srt/srt8/QDS.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package division.srt.srt8

import chisel3._
import chisel3.util.{BitPat, ValidIO}
import chisel3.util.experimental.decode.{TruthTable, _}
import division.srt.SRTTable
import chisel3.util.{BitPat}
import chisel3.util.experimental.decode.{TruthTable}
import utils.extend

class QDSInput(rWidth: Int, partialDividerWidth: Int) extends Bundle {
Expand All @@ -16,14 +15,13 @@ class QDSOutput(ohWidth: Int) extends Bundle {
val selectedQuotientOH: UInt = UInt(ohWidth.W)
}

class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module {
class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]]) extends Module {
// IO
val input = IO(Input(new QDSInput(rWidth, partialDividerWidth)))
val output = IO(Output(new QDSOutput(ohWidth)))

val columnSelect = input.partialDivider
// Seq[Seq[Int]] => Vec[Vec[UInt]]
val tables: Seq[Seq[Int]] = SRTTable(8, 7, 4, 4).tablesToQDS
lazy val selectRom = VecInit(tables.map {
case x =>
VecInit(x.map {
Expand Down Expand Up @@ -75,12 +73,13 @@ object QDS {
def apply(
rWidth: Int,
ohWidth: Int,
partialDividerWidth: Int
partialDividerWidth: Int,
tables: Seq[Seq[Int]]
)(partialReminderSum: UInt,
partialReminderCarry: UInt,
partialDivider: UInt
): UInt = {
val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth))
val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth, tables))
m.input.partialReminderSum := partialReminderSum
m.input.partialReminderCarry := partialReminderCarry
m.input.partialDivider := partialDivider
Expand Down
15 changes: 9 additions & 6 deletions arithmetic/src/division/srt/srt8/SRT8.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import division.srt._
import division.srt.SRTTable
import chisel3._
import chisel3.util._
import utils.leftShift
import utils.{leftShift}

/** SRT8
* 1/2 <= d < 1, 1/2 < rho <=1, 0 < q < 2
Expand All @@ -13,6 +13,8 @@ import utils.leftShift
* dTruncateWidth = 4, rTruncateWidth = 4
* y^(xxxx.xxxx), d^(0.1xxx)
* table from SRTTable
* -129/16 < y^ < 127/16
* floor((-r*rho - 2^-t)_t) <= y^ <= floor((r*rho - ulp)_t)
*/

class SRT8(
Expand Down Expand Up @@ -68,8 +70,9 @@ class SRT8(

// qds
val rWidth: Int = 1 + radixLog2 + rTruncateWidth
val tables: Seq[Seq[Int]] = SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS
val selectedQuotientOH: UInt =
QDS(rWidth, ohWidth, dTruncateWidth - 1)(
QDS(rWidth, ohWidth, dTruncateWidth - 1, tables)(
leftShift(partialReminderSum, radixLog2).head(rWidth),
leftShift(partialReminderCarry, radixLog2).head(rWidth),
dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) //.1********* -> 1*** -> ***
Expand All @@ -81,14 +84,14 @@ class SRT8(
val qHigh: UInt = selectedQuotientOH(9, 5)
val qLow: UInt = selectedQuotientOH(4, 0)
// csa for SRT8 -> CSA32+CSA32
val divideMap0 = VecInit((-2 to 2).map {
val dividerMap0 = VecInit((-2 to 2).map {
case -2 => divider << 3 // -8
case -1 => divider << 2 // -4
case 0 => 0.U // 0
case 1 => Fill(2, 1.U(1.W)) ## ~(divider << 2) // 4
case 2 => Fill(1, 1.U(1.W)) ## ~(divider << 3) // 8
})
val divideMap1 = VecInit((-2 to 2).map {
val dividerMap1 = VecInit((-2 to 2).map {
case -2 => divider << 1 // -2
case -1 => divider // -1
case 0 => 0.U // 0
Expand All @@ -99,14 +102,14 @@ class SRT8(
VecInit(
leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2),
leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0,
Mux1H(qHigh, divideMap0)
Mux1H(qHigh, dividerMap0)
)
)
val csa1 = addition.csa.c32(
VecInit(
csa0(1).head(wLen - radixLog2),
leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1,
Mux1H(qLow, divideMap1)
Mux1H(qLow, dividerMap1)
)
)

Expand Down
4 changes: 2 additions & 2 deletions arithmetic/tests/src/division/srt/SRT16Test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ object SRT16Test extends TestSuite with ChiselUtestTester {
}
}

testcase(16)
testcase(64)
// for( i <- 1 to 50){
// testcase(128)
// testcase(64)
// }
}
}
Expand Down
4 changes: 2 additions & 2 deletions arithmetic/tests/src/division/srt/SRT4Test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ object SRT4Test extends TestSuite with ChiselUtestTester {
}
val zeroHeadDividend: Int = m - zeroCheck(dividend)
val zeroHeadDivider: Int = m - zeroCheck(divider)
val needComputerWidth: Int = zeroHeadDivider - zeroHeadDividend + 1 + 1
val needComputerWidth: Int = zeroHeadDivider - zeroHeadDividend + 1 + radixLog2 - 1
val noguard: Boolean = needComputerWidth % radixLog2 == 0
val counter: Int = (needComputerWidth + 1) / 2
if ((divider == 0) || (divider > dividend) || (needComputerWidth <= 0))
Expand Down Expand Up @@ -74,7 +74,7 @@ object SRT4Test extends TestSuite with ChiselUtestTester {
}

testcase(64)
// for( i <- 1 to 100){
// for( i <- 1 to 50){
// testcase(64)
// }
}
Expand Down
2 changes: 1 addition & 1 deletion arithmetic/tests/src/division/srt/SRT8Test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ object SRT8Test extends TestSuite with ChiselUtestTester {

testcase(64)
// for( i <- 1 to 50){
// testcase(128)
// testcase(64)
// }
}
}
Expand Down
5 changes: 3 additions & 2 deletions arithmetic/tests/src/division/srt/SRTSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import utils.extend
object SRTSpec extends TestSuite{
override def tests: Tests = Tests {
test("SRT should draw PD") {
val srt = SRTTable(8, 7, 4, 4)
val srt = SRTTable(4,2,4,4)
// println(srt.tables)
// println(srt.tablesToQDS)
srt.dumpGraph(srt.pd, os.root / "tmp" / "srt8-7-4-4.png")
srt.dumpGraph(srt.pd, os.root / "tmp" / "srt4-2-4-4.png")
}
}
}

0 comments on commit a086604

Please sign in to comment.