Skip to content

Commit

Permalink
add cut in gemmx datapath
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoling-yi committed Sep 25, 2024
1 parent f936698 commit 58b4b4a
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 6 deletions.
46 changes: 40 additions & 6 deletions hw/chisel_acc/src/main/scala/snax_acc/gemm/BlockGemm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package snax_acc.gemm

import chisel3._
import chisel3.util._
import snax_acc.utils._
import snax_acc.utils.DecoupledCut._

// The BlockGemm's control port declaration.
class BlockGemmCtrlIO(params: GemmParams) extends Bundle {
Expand Down Expand Up @@ -80,6 +82,7 @@ class BlockGemm(params: GemmParams) extends Module with RequireAsyncReset {
// control signals for the counter incremental
val accumulation = WireInit(0.B)
val a_b_data_valid = WireInit(0.B)
val a_b_data_ready = WireInit(0.B)

val gemm_a_b_input_fire = WireInit(0.B)
val gemm_output_fire = WireInit(0.B)
Expand All @@ -95,6 +98,31 @@ class BlockGemm(params: GemmParams) extends Module with RequireAsyncReset {

val compute_fire = WireInit(0.B)

// -----------------------------------
// resgiter insert
// -----------------------------------

def a_bits_len = params.meshRow * params.tileSize * params.dataWidthA
def b_bits_len = params.tileSize * params.meshCol * params.dataWidthB
def a_b_bits_len = a_bits_len + b_bits_len
val combined_a_b_bits = WireInit(0.U(a_b_bits_len.W))

val combined_decoupled_a_b_in = Wire(Decoupled(UInt(a_b_bits_len.W)))
val combined_decoupled_a_b_out = Wire(Decoupled(UInt(a_b_bits_len.W)))
val a_split_out = Wire(Decoupled(UInt(a_bits_len.W)))
val b_split_out = Wire(Decoupled(UInt(b_bits_len.W)))

val a_b_cat = Module(new DecoupledCat2to1(a_bits_len, b_bits_len))
val a_b_split = Module(new DecoupledSplit1to2(a_b_bits_len, a_bits_len, b_bits_len))

a_b_cat.io.in1 <> io.data.a_i
a_b_cat.io.in2 <> io.data.b_i
a_b_cat.io.out <> combined_decoupled_a_b_in
combined_decoupled_a_b_in -\\> combined_decoupled_a_b_out
a_b_split.io.in <> combined_decoupled_a_b_out
a_split_out <> a_b_split.io.out1
b_split_out <> a_b_split.io.out2

// State declaration
val sIDLE :: sBUSY :: Nil = Enum(2)
val cstate = RegInit(sIDLE)
Expand Down Expand Up @@ -158,10 +186,12 @@ class BlockGemm(params: GemmParams) extends Module with RequireAsyncReset {
}

// input data valid signal, when both a and b are valid, the input data is valid
a_b_data_valid := io.data.a_i.valid && io.data.b_i.valid && cstate === sBUSY
a_b_data_valid := a_split_out.valid && b_split_out.valid && cstate === sBUSY
a_b_data_ready := gemm_array.io.ctrl.a_b_c_ready_o && cstate === sBUSY

// gemm input fire signal, when both a and b are valid and gemm is ready for new input data
// stall the a b compute if add c
gemm_a_b_input_fire := gemm_array.io.ctrl.a_b_c_ready_o && a_b_data_valid && !add_c && !must_add_c
gemm_a_b_input_fire := a_b_data_ready && a_b_data_valid && !add_c && !must_add_c

// accumulation counter for generating the accumulation signal for Gemm Array
// value change according to gemm_a_b_input_fire and add_c_fire
Expand Down Expand Up @@ -244,13 +274,17 @@ class BlockGemm(params: GemmParams) extends Module with RequireAsyncReset {
gemm_array.io.ctrl.subtraction_b_i := subtraction_b

// data signals
gemm_array.io.data.a_i := io.data.a_i.bits
gemm_array.io.data.b_i := io.data.b_i.bits
gemm_array.io.data.a_i := a_split_out.bits
gemm_array.io.data.b_i := b_split_out.bits
gemm_array.io.data.c_i := io.data.c_i.bits

// ready for pop out the data from outside
io.data.a_i.ready := cstate === sBUSY && gemm_a_b_input_fire
io.data.b_i.ready := cstate === sBUSY && gemm_a_b_input_fire
// io.data.a_i.ready := cstate === sBUSY && gemm_a_b_input_fire
// io.data.b_i.ready := cstate === sBUSY && gemm_a_b_input_fire

a_split_out.ready := cstate === sBUSY && gemm_a_b_input_fire
b_split_out.ready := cstate === sBUSY && gemm_a_b_input_fire

io.data.c_i.ready := cstate === sBUSY && add_c_fire

// gemm output signals
Expand Down
125 changes: 125 additions & 0 deletions hw/chisel_acc/src/main/scala/snax_acc/utils/CustomOperators.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package snax_acc.utils

import chisel3._
import chisel3.util._

/** The definition of -|> / -||> / -|||> connector for decoupled signal it
* connects leftward Decoupled signal (Decoupled port) and rightward Decoupled
* signal (Flipped port); and insert one level of pipeline in between to avoid
* long combinatorial datapath
*/

class DataCut[T <: Data](gen: T, delay: Int) extends Module {
val io = IO(new Bundle {
val in = Flipped(Decoupled(gen))
val out = Decoupled(gen)
})

val in = Wire(ValidIO(gen))
val out = Wire(ValidIO(gen))
val shiftPermission = Wire(Bool())
val shiftSuggestion = Wire(Bool())
val shift =
shiftPermission && shiftSuggestion // shift is true when both shiftPermission and shiftSuggestion are true
in.bits := io.in.bits
in.valid := io.in.valid
io.in.ready := shiftPermission
io.out.valid := out.valid
io.out.bits := out.bits
out := ShiftRegister(in, delay, shift)

// shiftPermission is true when last item's valid is true and io.out.ready is true or last item's valid is false
shiftPermission := (out.valid && io.out.ready) || !out.valid

val dataInsideShiftRegister = Wire(Bool())

// shiftSuggestion is true when dataInsideShiftRegister is true or input.valid is true
shiftSuggestion := dataInsideShiftRegister || io.in.valid

// When the counter is abbout to overflow, data does not inside the shift register
val insideCounter = Counter(0 to delay, shift, io.in.valid)
dataInsideShiftRegister := insideCounter._1 =/= delay.U

}

object DecoupledCut {
implicit class BufferedDecoupledConnectionOp[T <: Data](
val left: DecoupledIO[T]
) {
// This class defines the implicit class for the new operand -|>,-||>, -|||> for DecoupleIO

def -|>(
right: DecoupledIO[T]
)(implicit sourceInfo: chisel3.experimental.SourceInfo): DecoupledIO[T] = {
val buffer = Module(
new Queue(chiselTypeOf(left.bits), entries = 1, pipe = false)
)
buffer.suggestName("fullCutHalfBandwidth")

left <> buffer.io.enq
buffer.io.deq <> right
right
}

def -||>(
right: DecoupledIO[T]
)(implicit sourceInfo: chisel3.experimental.SourceInfo): DecoupledIO[T] = {
val buffer = Module(
new Queue(chiselTypeOf(left.bits), entries = 2, pipe = false)
)
buffer.suggestName("fullCutFullBandwidth")
left <> buffer.io.enq
buffer.io.deq <> right
right
}

def -\>(
right: DecoupledIO[T]
)(implicit sourceInfo: chisel3.experimental.SourceInfo): DecoupledIO[T] = {
val buffer = Module(
new DataCut(chiselTypeOf(left.bits), delay = 1)
)
buffer.suggestName("dataCut1")

left <> buffer.io.in
buffer.io.out <> right
right
}

def -\\>(
right: DecoupledIO[T]
)(implicit sourceInfo: chisel3.experimental.SourceInfo): DecoupledIO[T] = {
val buffer = Module(
new DataCut(chiselTypeOf(left.bits), delay = 2)
)
buffer.suggestName("dataCut2")

left <> buffer.io.in
buffer.io.out <> right
right
}

def -\\\>(
right: DecoupledIO[T]
)(implicit sourceInfo: chisel3.experimental.SourceInfo): DecoupledIO[T] = {
val buffer = Module(
new DataCut(chiselTypeOf(left.bits), delay = 3)
)
buffer.suggestName("dataCut3")

left <> buffer.io.in
buffer.io.out <> right
right
}
}
}

object BitsConcat {
implicit class UIntConcatOp[T <: Bits](val left: T) {
// This class defines the implicit class for the new operand ++ for UInt
def ++(
right: T
)(implicit sourceInfo: chisel3.experimental.SourceInfo): T =
Cat(left, right).asInstanceOf[T]
}
}
44 changes: 44 additions & 0 deletions hw/chisel_acc/src/main/scala/snax_acc/utils/DecoupledCat.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package snax_acc.utils

import chisel3._
import chisel3.util._

class DecoupledCat2to1[T <: Data](aWidth: Int, bWidth: Int) extends Module{
val io = IO(new Bundle {
val in1 = Flipped(Decoupled(UInt(aWidth.W))) // First decoupled input interface
val in2 = Flipped(Decoupled(UInt(bWidth.W))) // Second decoupled input interface
val out = Decoupled(UInt((aWidth + bWidth).W)) // Decoupled output interface
})

// Combine the bits of in1 and in2, in1 in higher bits
io.out.bits := Cat(io.in1.bits, io.in2.bits)

// Output is valid only when both inputs are valid
io.out.valid := io.in1.valid && io.in2.valid

// Ready is asserted to inputs when the output is ready
io.in1.ready := io.out.ready && io.out.valid
io.in2.ready := io.out.ready && io.out.valid

}

class DecoupledSplit1to2(cWidth: Int, aWidth: Int, bWidth: Int) extends Module {
require(cWidth == aWidth + bWidth, "cWidth must be the sum of aWidth and bWidth")

val io = IO(new Bundle {
val in = Flipped(Decoupled(UInt(cWidth.W))) // Large decoupled input (c)
val out1 = Decoupled(UInt(aWidth.W)) // Smaller decoupled output (a)
val out2 = Decoupled(UInt(bWidth.W)) // Smaller decoupled output (b)
})

// Split the input bits into two parts
io.out1.bits := io.in.bits(cWidth - 1, bWidth) // Upper bits go to out1 (a)
io.out2.bits := io.in.bits(bWidth - 1, 0) // Lower bits go to out2 (b)

// Both outputs are valid when the input is valid
io.out1.valid := io.in.valid
io.out2.valid := io.in.valid

// Input is ready when both outputs are ready
io.in.ready := io.out1.ready && io.out2.ready
}

0 comments on commit 58b4b4a

Please sign in to comment.