Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Div sqrt #37

Merged
merged 5 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ project/project
project/target
target/
# test
test_run_dir/
verdiLog

.bsp/
Expand Down
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
[submodule "dependencies/chisel"]
path = dependencies/chisel
url = [email protected]:chipsalliance/chisel.git
1 change: 0 additions & 1 deletion .mill-version

This file was deleted.

20 changes: 20 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@


init:
git submodule update --init

compile:
mill -i -j 0 arithmetic[5.0.0].compile

run:
mill -i -j 0 arithmetic[5.0.0].run

test:
mill -i -j 0 test[5.0.0].test

bsp:
mill -i mill.bsp.BSP/install

clean:
git clean -fd

27 changes: 13 additions & 14 deletions arithmetic/src/division/srt/srt16/SRT16.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@ class SRT16(
extends Module {
val guardBitWidth = 3
val xLen: Int = dividendWidth + radixLog2 + 1 + guardBitWidth
val wLen: Int = xLen + radixLog2
val ohWidth: Int = 2 * a + 1
val rWidth: Int = 1 + radixLog2 + rTruncateWidth

// IO
val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n, 4))))
val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth)))

val partialReminderCarryNext, partialReminderSumNext = Wire(UInt(wLen.W))
val partialReminderCarryNext, partialReminderSumNext = Wire(UInt(xLen.W))
val dividerNext = Wire(UInt(dividerWidth.W))
val counterNext = Wire(UInt(log2Ceil(n).W))
val quotientNext, quotientMinusOneNext = Wire(UInt(n.W))
Expand All @@ -37,8 +36,8 @@ class SRT16(
val isLastCycle, enable: Bool = Wire(Bool())
// State
// because we need a CSA to minimize the critical path
val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), enable)
val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(wLen.W), enable)
val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(xLen.W), enable)
val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(xLen.W), enable)
val divider = RegEnable(dividerNext, 0.U(dividerWidth.W), enable)
val quotient = RegEnable(quotientNext, 0.U(n.W), enable)
val quotientMinusOne = RegEnable(quotientMinusOneNext, 0.U(n.W), enable)
Expand All @@ -59,9 +58,9 @@ class SRT16(
val remainderNoCorrect: UInt = partialReminderSum + partialReminderCarry
val remainderCorrect: UInt =
partialReminderSum + partialReminderCarry + (divisorExtended << radixLog2)
val needCorrect: Bool = remainderNoCorrect(wLen - 3).asBool
val needCorrect: Bool = remainderNoCorrect(xLen - 1).asBool

output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 4, radixLog2 + guardBitWidth)
output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(xLen - 2, radixLog2 + guardBitWidth)
output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient)

// 5*CSA32 SRT16 <- SRT4 + SRT4*5 /SRT16 -> CSA53+CSA32
Expand All @@ -73,8 +72,8 @@ class SRT16(
case 2 => Fill(radixLog2, 1.U(1.W)) ## ~(divisorExtended << 1)
})
val csa0InWidth = rWidth + radixLog2 + 1
val csaIn1 = leftShift(partialReminderSum, radixLog2).head(csa0InWidth)
val csaIn2 = leftShift(partialReminderCarry, radixLog2).head(csa0InWidth)
val csaIn1 = partialReminderSum.head(csa0InWidth)
val csaIn2 = partialReminderCarry.head(csa0InWidth)

val csa1 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(0).head(csa0InWidth))) // -2 csain 10bit
val csa2 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(1).head(csa0InWidth))) // -1
Expand All @@ -87,8 +86,8 @@ class SRT16(
val partialDivider: UInt = dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0)
val qdsOH0: UInt =
QDS(rWidth, ohWidth, dTruncateWidth - 1, tables)(
leftShift(partialReminderSum, radixLog2).head(rWidth),
leftShift(partialReminderCarry, radixLog2).head(rWidth),
partialReminderSum.head(rWidth),
partialReminderCarry.head(rWidth),
partialDivider
) // q_j+1 oneHot

Expand Down Expand Up @@ -120,15 +119,15 @@ class SRT16(

val csa0Out = addition.csa.c32(
VecInit(
leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2),
leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qds0sign,
partialReminderSum.head(xLen),
partialReminderCarry.head(xLen - 1) ## qds0sign,
Mux1H(qdsOH0, dividerMap)
)
)
val csa1Out = addition.csa.c32(
VecInit(
leftShift(csa0Out(1), radixLog2).head(wLen - radixLog2),
leftShift(csa0Out(0), radixLog2 + 1).head(wLen - radixLog2 - 1) ## qds1sign,
leftShift(csa0Out(1), radixLog2).head(xLen),
leftShift(csa0Out(0), radixLog2 + 1).head(xLen - 1) ## qds1sign,
Mux1H(qdsOH1, dividerMap)
)
)
Expand Down
43 changes: 21 additions & 22 deletions arithmetic/src/division/srt/srt4/SRT4.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,24 @@ import utils.leftShift
* @param rTruncateWidth TruncateWidth for residual fractional part
*/
class SRT4(
dividendWidth: Int,
dividerWidth: Int,
n: Int, // the longest width
radixLog2: Int = 2,
a: Int = 2,
dTruncateWidth: Int = 4,
rTruncateWidth: Int = 4)
extends Module {
dividendWidth: Int,
dividerWidth: Int,
n: Int, // the longest width
radixLog2: Int = 2,
a: Int = 2,
dTruncateWidth: Int = 4,
rTruncateWidth: Int = 4)
extends Module {
val guardBitWidth = 1

/** width for csa */
val xLen: Int = dividendWidth + radixLog2 + 1 + guardBitWidth
val wLen: Int = xLen + radixLog2
// IO
val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n, 2))))
val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth)))

//rW[j]
val partialReminderCarryNext, partialReminderSumNext = Wire(UInt(wLen.W))
val partialReminderCarryNext, partialReminderSumNext = Wire(UInt(xLen.W))
val quotientNext, quotientMinusOneNext = Wire(UInt(n.W))
val dividerNext = Wire(UInt(dividerWidth.W))
val counterNext = Wire(UInt(log2Ceil(n).W))
Expand All @@ -55,8 +54,8 @@ class SRT4(

// State
// because we need a CSA to minimize the critical path
val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), enable)
val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(wLen.W), enable)
val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(xLen.W), enable)
val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(xLen.W), enable)
val divider = RegEnable(dividerNext, 0.U(dividerWidth.W), enable)
val quotient = RegEnable(quotientNext, 0.U(n.W), enable)
val quotientMinusOne = RegEnable(quotientMinusOneNext, 0.U(n.W), enable)
Expand All @@ -80,9 +79,9 @@ class SRT4(
/** partialReminderSum is r*W[j], so remainderCorrect = remainderNoCorrect + r*divisor */
val remainderCorrect: UInt =
partialReminderSum + partialReminderCarry + (divisorExtended << radixLog2)
val needCorrect: Bool = remainderNoCorrect(wLen - 3).asBool
val needCorrect: Bool = remainderNoCorrect(xLen - 1).asBool

output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 4, radixLog2 + guardBitWidth)
output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(xLen - 2, radixLog2 + guardBitWidth)
output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient)

/** width for truncated y */
Expand All @@ -97,8 +96,8 @@ class SRT4(
/** QDS module whose output needs to be decoded */
val selectedQuotientOH: UInt =
QDS(rWidth, ohWidth, dTruncateWidth - 1, tables, a)(
leftShift(partialReminderSum, radixLog2).head(rWidth),
leftShift(partialReminderCarry, radixLog2).head(rWidth),
partialReminderSum.head(rWidth),
partialReminderCarry.head(rWidth),
dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) //.1********* -> 1*** -> ***
)
// On-The-Fly conversion
Expand All @@ -120,8 +119,8 @@ class SRT4(

addition.csa.c32(
VecInit(
leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2),
leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign,
partialReminderSum.head(xLen),
partialReminderCarry.head(xLen - 1) ## qdsSign,
Mux1H(selectedQuotientOH, dividerMap)
)
)
Expand All @@ -144,15 +143,15 @@ class SRT4(
})
val csa0 = addition.csa.c32(
VecInit(
leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2),
leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qds0Sign,
partialReminderSum.head(xLen),
partialReminderCarry.head(xLen - 1) ## qds0Sign,
Mux1H(qHigh, dividerHMap)
)
)
addition.csa.c32(
VecInit(
csa0(1).head(wLen - radixLog2),
leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qds1Sign,
csa0(1).head(xLen),
leftShift(csa0(0), 1).head(xLen - 1) ## qds1Sign,
Mux1H(qLow, dividerLMap)
)
)
Expand Down
47 changes: 23 additions & 24 deletions arithmetic/src/division/srt/srt8/SRT8.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,12 @@ class SRT8(

val guardBitWidth = 2
val xLen: Int = dividendWidth + radixLog2 + 1 + guardBitWidth
val wLen: Int = xLen + radixLog2

// IO
val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n, 3))))
val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth)))

val partialReminderCarryNext, partialReminderSumNext = Wire(UInt(wLen.W))
val partialReminderCarryNext, partialReminderSumNext = Wire(UInt(xLen.W))
val quotientNext, quotientMinusOneNext = Wire(UInt(n.W))
val dividerNext = Wire(UInt(dividerWidth.W))
val counterNext = Wire(UInt(log2Ceil(n).W))
Expand All @@ -47,8 +46,8 @@ class SRT8(

// State
// because we need a CSA to minimize the critical path
val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), enable)
val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(wLen.W), enable)
val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(xLen.W), enable)
val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(xLen.W), enable)
val divider = RegEnable(dividerNext, 0.U(dividerWidth.W), enable)
val quotient = RegEnable(quotientNext, 0.U(n.W), enable)
val quotientMinusOne = RegEnable(quotientMinusOneNext, 0.U(n.W), enable)
Expand All @@ -69,8 +68,8 @@ class SRT8(
val remainderNoCorrect: UInt = partialReminderSum + partialReminderCarry
val remainderCorrect: UInt =
partialReminderSum + partialReminderCarry + (divisorExtended << radixLog2)
val needCorrect: Bool = remainderNoCorrect(wLen - 4).asBool
output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 5, radixLog2 + guardBitWidth)
val needCorrect: Bool = remainderNoCorrect(xLen - 1).asBool
output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(xLen - 2, radixLog2 + guardBitWidth)
output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient)

val rWidth: Int = 1 + radixLog2 + rTruncateWidth
Expand All @@ -85,8 +84,8 @@ class SRT8(
// qds
val selectedQuotientOH: UInt =
QDS(rWidth, ohWidth, dTruncateWidth - 1, tables, a)(
leftShift(partialReminderSum, radixLog2).head(rWidth),
leftShift(partialReminderCarry, radixLog2).head(rWidth),
partialReminderSum.head(rWidth),
partialReminderCarry.head(rWidth),
dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) //.1********* -> 1*** -> ***
)
// On-The-Fly conversion
Expand Down Expand Up @@ -115,15 +114,15 @@ class SRT8(
})
val csa0 = addition.csa.c32(
VecInit(
leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2),
leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0,
partialReminderSum.head(xLen),
partialReminderCarry.head(xLen - 1) ## qdsSign0,
Mux1H(qHigh, dividerHMap)
)
)
val csa1 = addition.csa.c32(
VecInit(
csa0(1).head(wLen - radixLog2),
leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1,
csa0(1).head(xLen),
leftShift(csa0(0), 1).head(xLen - 1) ## qdsSign1,
Mux1H(qLow, dividerLMap)
)
)
Expand All @@ -143,15 +142,15 @@ class SRT8(
})
val csa0 = addition.csa.c32(
VecInit(
leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2),
leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0,
partialReminderSum.head(xLen),
partialReminderCarry.head(xLen - 1) ## qdsSign0,
Mux1H(qHigh, dividerHMap)
)
)
val csa1 = addition.csa.c32(
VecInit(
csa0(1).head(wLen - radixLog2),
leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1,
csa0(1).head(xLen),
leftShift(csa0(0), 1).head(xLen - 1) ## qdsSign1,
Mux1H(qLow, dividerLMap)
)
)
Expand All @@ -171,15 +170,15 @@ class SRT8(
})
val csa0 = addition.csa.c32(
VecInit(
leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2),
leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0,
partialReminderSum.head(xLen),
partialReminderCarry.head(xLen - 1) ## qdsSign0,
Mux1H(qHigh, dividerHMap)
)
)
val csa1 = addition.csa.c32(
VecInit(
csa0(1).head(wLen - radixLog2),
leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1,
csa0(1).head(xLen),
leftShift(csa0(0), 1).head(xLen - 1) ## qdsSign1,
Mux1H(qLow, dividerLMap)
)
)
Expand All @@ -199,15 +198,15 @@ class SRT8(
})
val csa0 = addition.csa.c32(
VecInit(
leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2),
leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0,
partialReminderSum.head(xLen),
partialReminderCarry.head(xLen - 1) ## qdsSign0,
Mux1H(qHigh, dividerHMap)
)
)
val csa1 = addition.csa.c32(
VecInit(
csa0(1).head(wLen - radixLog2),
leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1,
csa0(1).head(xLen),
leftShift(csa0(0), 1).head(xLen - 1) ## qdsSign1,
Mux1H(qLow, dividerLMap)
)
)
Expand Down
Loading
Loading