Skip to content

Commit

Permalink
wip SRT
Browse files Browse the repository at this point in the history
  • Loading branch information
sequencer committed Apr 1, 2022
1 parent 0ef1d57 commit 79edb60
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 164 deletions.
27 changes: 27 additions & 0 deletions arithmetic/src/division/srt/QDS.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package division.srt
import chisel3._
import chisel3.util.{RegEnable, Valid}

class QDSInput extends Bundle {
val partialReminderCarry: UInt = ???
val partialReminderSum: UInt = ???
}

class QDSOutput extends Bundle {
val selectedQuotient: UInt = ???
}

class QDS extends Module {
val input = IO(Input(new QDSInput))
val output = IO(Output(new QDSOutput))
// used to select a column of SRT Table
val partialDivider = IO(Flipped(Valid(UInt())))
val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid)
// for the first cycle: use partialDivider on the IO
// for the reset of cycles: use partialDividerReg
// for synthesis: the constraint should be IO -> Output is a multi-cycle design
// Reg -> Output is single-cycle
// to avoid glitch, valid should be larger than raise time of partialDividerReg
val partialDividerLatch = Mux(partialDivider.valid, partialDivider.bits, partialDividerReg)

}
211 changes: 54 additions & 157 deletions arithmetic/src/division/srt/SRT.scala
Original file line number Diff line number Diff line change
@@ -1,169 +1,66 @@
package division.srt

import com.cibo.evilplot.colors.HTMLNamedColors
import com.cibo.evilplot.numeric.Bounds
import com.cibo.evilplot.plot._
import com.cibo.evilplot.plot.aesthetics.DefaultTheme._
import com.cibo.evilplot.plot.renderers.PointRenderer
import os.Path
import spire.implicits._
import spire.math._
import addition.csa.CarrySaveAdder
import addition.csa.common.CSACompressor3_2
import chisel3._
import chisel3.util.{Decoupled, DecoupledIO, Mux1H, log2Ceil}

/** Base SRT class.
*
* @param radix is the radix of SRT.
* It defined how many rounds can be calculate in one cycle.
* @note 5.2
*/
case class SRT(
radix: Algebraic,
a: Algebraic,
dTruncateWidth: Algebraic,
xTruncateWidth: Algebraic,
dMin: Algebraic = 0.5,
dMax: Algebraic = 1) {
require(a > 0)
lazy val xMin: Algebraic = -rho * dMax * radix
lazy val xMax: Algebraic = rho * dMax * radix
class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle {
val dividend = UInt(dividendWidth.W)
val divider = UInt(dividerWidth.W)
val counter = UInt(log2Ceil(???).W)
}

/** P-D Diagram
*
* @note Graph 5.17(b)
*/
lazy val pd: Plot = Overlay((aMin.toBigInt to aMax.toBigInt).flatMap { k: BigInt =>
Seq(
FunctionPlot.series(
_ * uRate(k.toInt).toDouble,
s"U($k)",
HTMLNamedColors.blue,
Some(Bounds(dMin.toDouble, dMax.toDouble)),
strokeWidth = Some(1)
),
FunctionPlot.series(
_ * lRate(k.toInt).toDouble,
s"L($k)",
HTMLNamedColors.red,
Some(Bounds(dMin.toDouble, dMax.toDouble)),
strokeWidth = Some(1)
)
) ++ qdsPoints :+ mesh
}: _*)
.title(s"P-D Graph of $this")
.xLabel("d")
.yLabel(s"${radix.toInt}ω[j]")
.rightLegend()
.standard()
lazy val aMax: Algebraic = a
lazy val aMin: Algebraic = -a
lazy val deltaD: Algebraic = pow(2, -dTruncateWidth.toDouble)
lazy val deltaX: Algebraic = pow(2, -xTruncateWidth.toDouble)
class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle {
val reminder = UInt(reminderWidth.W)
val quotient = UInt(quotientWidth.W)
}

/** redundancy factor
* @note 5.8
*/
lazy val rho: Algebraic = a / (radix - 1)
lazy val tables: Seq[(Int, Seq[(Algebraic, Seq[Algebraic])])] = {
(aMin.toInt to aMax.toInt).drop(1).map { k =>
k -> dSet.dropRight(1).map { d =>
val (floor, ceil) = xRange(k, d, d + deltaD)
val m: Seq[Algebraic] = xSet.filter { x: Algebraic => x <= ceil && x >= floor }
(d, m)
}
}
}
lazy val qdsPoints: Seq[Plot] = {
tables.map {
case (i, ps) =>
ScatterPlot(
ps.flatMap { case (d, xs) => xs.map(x => com.cibo.evilplot.numeric.Point(d.toDouble, x.toDouble)) },
Some(
PointRenderer
.default[com.cibo.evilplot.numeric.Point](pointSize = Some(1), color = Some(HTMLNamedColors.gold))
)
)
}
}
// only SRT4 currently
class SRT(
dividendWidth: Int,
dividerWidth: Int,
n: Int)
extends Module {
// IO
val input: DecoupledIO[SRTInput] = Flipped(Decoupled(new SRTInput(dividendWidth, dividerWidth, n)))
val output: DecoupledIO[SRTOutput] = Decoupled(new SRTOutput(dividerWidth, dividendWidth))

private val xStep = (xMax - xMin) / deltaX
// @note 5.7
require(a >= radix / 2)
private val xSet = Seq.tabulate((xStep + 1).toInt) { n => xMin + deltaX * n }
private val dStep: Algebraic = (dMax - dMin) / deltaD
assert((rho > 1 / 2) && (rho <= 1))
private val dSet = Seq.tabulate((dStep + 1).toInt) { n => dMin + deltaD * n }
private val mesh =
ScatterPlot(
xSet.flatMap { y =>
dSet.map { x =>
com.cibo.evilplot.numeric.Point(x.toDouble, y.toDouble)
}
},
Some(
PointRenderer
.default[com.cibo.evilplot.numeric.Point](pointSize = Some(0.5), color = Some(HTMLNamedColors.gray))
)
)

override def toString: String =
s"SRT${radix.toInt} with quotient set: from ${aMin.toInt} to ${aMax.toInt}"
// State
// because we need a CSA to minimize the critical path
val partialReminderCarry = Reg(UInt())
val partialReminderSum = Reg(UInt())
val divider = Reg(UInt())

/** Robertson Diagram
*
* @note Graph 5.17(a)
*/
def robertson(d: Algebraic): Plot = {
require(d > dMin && d < dMax)
Overlay((aMin.toBigInt to aMax.toBigInt).map { k: BigInt =>
FunctionPlot.series(
_ - (Algebraic(k) * d).toDouble,
s"$k",
HTMLNamedColors.black,
xbounds = Some(Bounds(((Algebraic(k) - rho) * d).toDouble, ((Algebraic(k) + rho) * d).toDouble))
)
}: _*)
.title(s"Robertson Graph of $this divisor: $d")
.xLabel("rω[j]")
.yLabel("ω[j+1]")
.xbounds((-radix * rho * dMax).toDouble, (radix * rho * dMax).toDouble)
.ybounds((-rho * d).toDouble, (rho * d).toDouble)
.rightLegend()
.standard()
}
val quotient = Reg(UInt())
val quotientMinusOne = Reg(UInt())

def dumpGraph(plot: Plot, path: Path) = {
javax.imageio.ImageIO.write(
plot.render().asBufferedImage,
"png",
path.wrapped.toFile
)
}
val state = Reg(UInt())
val counter = Reg(UInt())

/** for range `dLeft` to `dRight`, return the `rOmegaCeil` and `rOmegaFloor`
* this is used for constructing the rectangle where m_k(i) is located.
*/
private def xRange(k: Algebraic, dLeft: Algebraic, dRight: Algebraic): (Algebraic, Algebraic) = {
Seq(L(k, dLeft), L(k, dRight), U(k - 1, dLeft), U(k - 1, dRight))
// not safe
.sortBy(_.toDouble)
.drop(1)
.dropRight(1) match { case Seq(l, r) => (l, r) }
}
// Control
// sign of select quotient, true -> negative, false -> positive
val qdsSign: Bool = Wire(Bool())

/** find the intersection point between L`k` and `d` */
private def L(k: Algebraic, d: Algebraic): Algebraic = lRate(k) * d
// Datapath
val qds = new QDS()

/** slope factor of L_k
*
* @note 5.56
*/
private def lRate(k: Algebraic): Algebraic = k - rho

/** find the intersection point between U`k` and `d` */
private def U(k: Algebraic, d: Algebraic): Algebraic = uRate(k) * d

/** slope factor of U_k
*
* @note 5.56
*/
private def uRate(k: Algebraic): Algebraic = k + rho
}
val csa = new CarrySaveAdder(CSACompressor3_2, ???)
csa.in(0) := partialReminderSum
csa.in(1) := (partialReminderCarry ## !qdsSign)
csa.in(2) := Mux1H(Map(
??? -> ,
??? ->
))
partialReminderSum := Mux1H(Map(
??? -> input.bits.dividend,
??? -> (csa.out(0) << log2Ceil(n)),
??? -> partialReminderSum
))
partialReminderCarry := Mux1H(Map(
??? -> 0.U,
??? -> (csa.out(1) << log2Ceil(n)),
??? -> partialReminderCarry
))
}
Loading

0 comments on commit 79edb60

Please sign in to comment.