Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SRT class #3

Draft
wants to merge 34 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
c83d329
Implement SRT class
sequencer Apr 16, 2021
d570098
Use spire to increase precision to arbitrary precision.
sequencer Apr 16, 2021
a9960a9
rewrite SRT.
sequencer Apr 22, 2021
cb2caed
bug fix and add test.
sequencer Apr 23, 2021
0ef1d57
fix
sequencer Apr 1, 2022
79edb60
wip SRT
sequencer Apr 1, 2022
1108845
srt fix
wissygh Apr 4, 2022
c502ab5
wip
sequencer Apr 5, 2022
c3c3ed1
wip
sequencer Apr 7, 2022
65808e2
srt fetch
wissygh Apr 8, 2022
26888fe
Coding OTF
wissygh Apr 10, 2022
0b2a14e
using table from SRTTable
wissygh Apr 11, 2022
70c19e7
using table from XS
wissygh Apr 13, 2022
a3930e2
SZ fix
wissygh Apr 14, 2022
0cd81d8
rm srt4test
wissygh Apr 18, 2022
c9c87ca
add SRT4Test & fix test
wissygh Apr 25, 2022
ea32e50
SRT4Test fix
wissygh May 9, 2022
db72ed0
SRT fix
wissygh May 9, 2022
44655b4
Merge branch 'srt' into srt
wissygh May 10, 2022
1660f14
fix reformat
wissygh May 10, 2022
8749cf8
Merge remote-tracking branch 'origin/srt' into srt
wissygh May 10, 2022
b2e8911
reformat
wissygh May 10, 2022
a92fa59
Merge pull request #25 from wissygh/srt
wissygh May 10, 2022
3f33c99
srt4 fix
wissygh May 10, 2022
7018ab8
using RegEnable
wissygh May 15, 2022
f0fc979
srt4 debug
wissygh May 23, 2022
ddadec6
srt4test fixed
wissygh Jun 2, 2022
9aa7dc5
srt4 fix & add srt16
wissygh Jun 3, 2022
bb63956
fix srt4 & naive srt16 implement
wissygh Jun 4, 2022
6dc8605
fix SRTTable
wissygh Jun 5, 2022
4d045c0
srt8 fixed & get tables from SRTTable
wissygh Jun 10, 2022
bf3eec4
fix srt16 & SRTTable fixed
wissygh Jun 11, 2022
ffdb455
fix selectRom & add selection of Radix
wissygh Jun 12, 2022
2220115
add selection of a & fix srt8
wissygh Jun 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ verdiLog
*.out
*.cmd
*.log
*.json
*.json
*.iml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ object PrefixGraph {
}

object CommonSumByConsole extends HasPrefixSumWithGraphImp with CommonPrefixSum {
val filePath = Path(io.StdIn.readLine("Import your graph generated by `dot -Txdot_json`: "), pwd)
val filePath = Path(scala.io.StdIn.readLine("Import your graph generated by `dot -Txdot_json`: "), pwd)
val fileName = filePath.baseName
val prefixGraph: PrefixGraph = PrefixGraph(filePath)
}
Expand Down
68 changes: 68 additions & 0 deletions arithmetic/src/division/srt/SRT.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package division.srt

import division.srt.srt4._
import division.srt.srt8._
import division.srt.srt16._
import chisel3._
import chisel3.util.{DecoupledIO, ValidIO}

class SRT(
dividendWidth: Int,
dividerWidth: Int,
n: Int, // the longest width
radixLog2: Int = 2,
a: Int = 2,
dTruncateWidth: Int = 4,
rTruncateWidth: Int = 4)
extends Module {
// val x = (radixLog2, a, dTruncateWidth)
// val tips = x match {
// case (2,2,4) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
// case (2,2,5) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
// case (2,2,6) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
//
// case (3,4,6) => require(rTruncateWidth >= 7, "rTruncateWidth need >= 7")
// case (3,4,7) => require(rTruncateWidth >= 6, "rTruncateWidth need >= 6")
//
// case (3,5,5) => require(rTruncateWidth >= 5, "rTruncateWidth need >= 5")
// case (3,5,6) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
//
// case (3,6,4) => require(rTruncateWidth >= 6, "rTruncateWidth need >= 6")
// case (3,6,5) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
//
// case (3,7,4) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
// case (3,7,5) => require(rTruncateWidth >= 3, "rTruncateWidth need >= 3")
//
// case (4,2,4) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
// case (4,2,5) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
// case (4,2,6) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
//
// case _ => println("this srt is not supported")
// }

val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n))))
val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth)))

// select radix
if (radixLog2 == 2) { // SRT4
val srt = Module(new SRT4(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth))
srt.input <> input
output <> srt.output
} else if (radixLog2 == 3) { // SRT8
val srt = Module(new SRT8(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth))
srt.input <> input
output <> srt.output
} else if (radixLog2 == 4) { //SRT16
val srt = Module(new SRT16(dividendWidth, dividerWidth, n, radixLog2 >> 1, a, dTruncateWidth, rTruncateWidth))
srt.input <> input
output <> srt.output
}

// val srt = radixLog2 match {
// case 2 => Module(new SRT4(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth))
// case 3 => Module(new SRT8(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth))
// case 4 => Module(new SRT16(dividendWidth, dividerWidth, n, radixLog2 >> 1, a, dTruncateWidth, rTruncateWidth))
// }
// srt.input <> input
// output <> srt.output
}
38 changes: 38 additions & 0 deletions arithmetic/src/division/srt/SRTIO.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package division.srt

import chisel3._
import chisel3.util.log2Ceil
// SRTIO
class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle {
val dividend = UInt(dividendWidth.W) //.***********
val divider = UInt(dividerWidth.W) //.1**********
val counter = UInt(log2Ceil(n).W) //the width of quotient.
}

class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle {
val reminder = UInt(reminderWidth.W)
val quotient = UInt(quotientWidth.W)
}

//OTFIO
class OTFInput(qWidth: Int, ohWidth: Int) extends Bundle {
val quotient = UInt(qWidth.W)
val quotientMinusOne = UInt(qWidth.W)
val selectedQuotientOH = UInt(ohWidth.W)
}

class OTFOutput(qWidth: Int) extends Bundle {
val quotient = UInt(qWidth.W)
val quotientMinusOne = UInt(qWidth.W)
}

// QDSIO
class QDSInput(rWidth: Int, partialDividerWidth: Int) extends Bundle {
val partialReminderCarry: UInt = UInt(rWidth.W)
val partialReminderSum: UInt = UInt(rWidth.W)
val partialDivider: UInt = UInt(partialDividerWidth.W)
}

class QDSOutput(ohWidth: Int) extends Bundle {
val selectedQuotientOH: UInt = UInt(ohWidth.W)
}
194 changes: 194 additions & 0 deletions arithmetic/src/division/srt/SRTTable.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package division.srt

import com.cibo.evilplot.colors.HTMLNamedColors
import com.cibo.evilplot.numeric.Bounds
import com.cibo.evilplot.plot._
import com.cibo.evilplot.plot.aesthetics.DefaultTheme._
import com.cibo.evilplot.plot.renderers.PointRenderer
import os.Path
import spire.implicits._
import spire.math._

/** Base SRT class.
*
* @param radix is the radix of SRT.
* It defined how many rounds can be calculate in one cycle.
* @note 5.2
*/
case class SRTTable(
radix: Algebraic,
a: Algebraic,
dTruncateWidth: Algebraic,
xTruncateWidth: Algebraic,
dMin: Algebraic = 0.5,
dMax: Algebraic = 1) {
require(a > 0)
lazy val xMin: Algebraic = -rho * dMax * radix
lazy val xMax: Algebraic = rho * dMax * radix

/** P-D Diagram
*
* @note Graph 5.17(b)
*/
lazy val pd: Plot = Overlay((aMin.toBigInt to aMax.toBigInt).flatMap { k: BigInt =>
Seq(
FunctionPlot.series(
_ * uRate(k.toInt).toDouble,
s"U($k)",
HTMLNamedColors.blue,
Some(Bounds(dMin.toDouble, dMax.toDouble)),
strokeWidth = Some(1)
),
FunctionPlot.series(
_ * lRate(k.toInt).toDouble,
s"L($k)",
HTMLNamedColors.red,
Some(Bounds(dMin.toDouble, dMax.toDouble)),
strokeWidth = Some(1)
)
) ++ qdsPoints :+ mesh
}: _*)
.title(s"P-D Graph of $this")
.xLabel("d")
.yLabel(s"${radix.toInt}ω[j]")
.rightLegend()
.standard()

lazy val aMax: Algebraic = a
lazy val aMin: Algebraic = -a
lazy val deltaD: Algebraic = pow(2, -dTruncateWidth.toDouble)
lazy val deltaX: Algebraic = pow(2, -xTruncateWidth.toDouble)

/** redundancy factor
* @note 5.8
*/
lazy val rho: Algebraic = a / (radix - 1)
// k d m xSet
lazy val tables: Seq[(Int, Seq[(Algebraic, Seq[Algebraic])])] = {
(aMin.toInt to aMax.toInt).drop(1).map { k =>
k -> dSet.dropRight(1).map { d =>
val (floor, ceil) = xRange(k, d, d + deltaD)
val m: Seq[Algebraic] = xSet.filter { x: Algebraic => x <= (ceil - deltaX) && x >= floor }
(d, m)
}
}
}
lazy val qdsPoints: Seq[Plot] = {
tables.map {
case (i, ps) =>
ScatterPlot(
ps.flatMap { case (d, xs) => xs.map(x => com.cibo.evilplot.numeric.Point(d.toDouble, x.toDouble)) },
Some(
PointRenderer
.default[com.cibo.evilplot.numeric.Point](pointSize = Some(1), color = Some(HTMLNamedColors.gold))
)
)
}
}

// TODO: select a Constant from each m, then offer the table to QDS.
// todo: ? select rule: symmetry and draw a line parallel to the X-axis, how define the rule
lazy val tablesToQDS: Seq[Seq[Int]] = {
(aMin.toInt to aMax.toInt).drop(1).map { k =>
k -> dSet.dropRight(1).map { d =>
val (floor, ceil) = xRange(k, d, d + deltaD)
val m: Seq[Algebraic] = xSet.filter { x: Algebraic => x <= (ceil - deltaX) && x >= floor }
(d, m.head)
}
}
}.flatMap {
case (i, ps) =>
ps.map {
case (x, y) => (x.toDouble, y.toDouble * (1 << xTruncateWidth.toInt))
}
}.groupBy(_._1).toSeq.sortBy(_._1).map { case (x, y) => y.map { case (x, y) => y.toInt }.reverse }

private val xStep = (xMax - xMin) / deltaX
// @note 5.7
require(a >= radix / 2)
private val xSet = Seq.tabulate((xStep / 2 + 1).toInt) { n => deltaX * n } ++ Seq.tabulate((xStep / 2 + 1).toInt) {
n => -deltaX * n
}

private val dStep: Algebraic = (dMax - dMin) / deltaD
assert((rho > 1 / 2) && (rho <= 1))
private val dSet = Seq.tabulate((dStep + 1).toInt) { n => dMin + deltaD * n }

private val mesh =
ScatterPlot(
xSet.flatMap { y =>
dSet.map { x =>
com.cibo.evilplot.numeric.Point(x.toDouble, y.toDouble)
}
},
Some(
PointRenderer
.default[com.cibo.evilplot.numeric.Point](pointSize = Some(0.5), color = Some(HTMLNamedColors.gray))
)
)

override def toString: String =
s"SRT${radix.toInt} with quotient set: from ${aMin.toInt} to ${aMax.toInt}"

/** Robertson Diagram
*
* @note Graph 5.17(a)
*/
def robertson(d: Algebraic): Plot = {
require(d > dMin && d < dMax)
Overlay((aMin.toBigInt to aMax.toBigInt).map { k: BigInt =>
FunctionPlot.series(
_ - (Algebraic(k) * d).toDouble,
s"$k",
HTMLNamedColors.black,
xbounds = Some(Bounds(((Algebraic(k) - rho) * d).toDouble, ((Algebraic(k) + rho) * d).toDouble))
)
}: _*)
.title(s"Robertson Graph of $this divisor: $d")
.xLabel("rω[j]")
.yLabel("ω[j+1]")
.xbounds((-radix * rho * dMax).toDouble, (radix * rho * dMax).toDouble)
.ybounds((-rho * d).toDouble, (rho * d).toDouble)
.rightLegend()
.standard()
}

def dumpGraph(plot: Plot, path: Path) = {
javax.imageio.ImageIO.write(
plot.render().asBufferedImage,
"png",
path.wrapped.toFile
)
}

// select four points, then drop the first and the last one.
/** for range `dLeft` to `dRight`, return the `rOmegaCeil` and `rOmegaFloor`
* this is used for constructing the rectangle where m_k(i) is located.
*/
private def xRange(k: Algebraic, dLeft: Algebraic, dRight: Algebraic): (Algebraic, Algebraic) = {
Seq(L(k, dLeft), L(k, dRight), U(k - 1, dLeft), U(k - 1, dRight))
// not safe
.sortBy(_.toDouble)
.drop(1)
.dropRight(1) match { case Seq(l, r) => (l, r) }
}

// U_k = (k + rho) * d, L_k = (k - rho) * d
/** find the intersection point between L`k` and `d` */
private def L(k: Algebraic, d: Algebraic): Algebraic = lRate(k) * d

/** slope factor of L_k
*
* @note 5.56
*/
private def lRate(k: Algebraic): Algebraic = k - rho

/** find the intersection point between U`k` and `d` */
private def U(k: Algebraic, d: Algebraic): Algebraic = uRate(k) * d

/** slope factor of U_k
*
* @note 5.56
*/
private def uRate(k: Algebraic): Algebraic = k + rho
}
50 changes: 50 additions & 0 deletions arithmetic/src/division/srt/srt16/OTF.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package division.srt.srt16

import division.srt._
import chisel3._
import chisel3.util.Mux1H

class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int) extends Module {
val input = IO(Input(new OTFInput(qWidth, ohWidth)))
val output = IO(Output(new OTFOutput(qWidth)))

val radix: Int = 1 << radixLog2
// datapath
// q_j+1 in this circle, only for srt4
val qNext: UInt = Mux1H(
Seq(
input.selectedQuotientOH(0) -> "b110".U,
input.selectedQuotientOH(1) -> "b111".U,
input.selectedQuotientOH(2) -> "b000".U,
input.selectedQuotientOH(3) -> "b001".U,
input.selectedQuotientOH(4) -> "b010".U
)
)

// val cShiftQ: Bool = qNext >= 0.U
// val cShiftQM: Bool = qNext <= 0.U
val cShiftQ: Bool = input.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR
val cShiftQM: Bool = input.selectedQuotientOH(ohWidth / 2, 0).orR
val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(radixLog2 - 1, 0)
val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(radixLog2 - 1, 0)

output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qIn
output.quotientMinusOne := Mux(!cShiftQM, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qmIn
}

object OTF {
def apply(
radixLog2: Int,
qWidth: Int,
ohWidth: Int
)(quotient: UInt,
quotientMinusOne: UInt,
selectedQuotientOH: UInt
): Vec[UInt] = {
val m = Module(new OTF(radixLog2, qWidth, ohWidth))
m.input.quotient := quotient
m.input.quotientMinusOne := quotientMinusOne
m.input.selectedQuotientOH := selectedQuotientOH
VecInit(m.output.quotient, m.output.quotientMinusOne)
}
}
Loading