From be0dff56ade2c7b96efbda81c7d0ee59bd2d221d Mon Sep 17 00:00:00 2001 From: Xiaoling Yi <143962462+xiaoling-yi@users.noreply.github.com> Date: Thu, 26 Sep 2024 11:08:00 +0200 Subject: [PATCH] add cut in gemmx datapath (#346) * add cut in gemmx datapath * rm comments --- .../main/scala/snax_acc/gemm/BlockGemm.scala | 43 +++++- .../snax_acc/utils/CustomOperators.scala | 125 ++++++++++++++++++ .../scala/snax_acc/utils/DecoupledCat.scala | 44 ++++++ 3 files changed, 205 insertions(+), 7 deletions(-) create mode 100644 hw/chisel_acc/src/main/scala/snax_acc/utils/CustomOperators.scala create mode 100644 hw/chisel_acc/src/main/scala/snax_acc/utils/DecoupledCat.scala diff --git a/hw/chisel_acc/src/main/scala/snax_acc/gemm/BlockGemm.scala b/hw/chisel_acc/src/main/scala/snax_acc/gemm/BlockGemm.scala index d861c325e..ca668359a 100644 --- a/hw/chisel_acc/src/main/scala/snax_acc/gemm/BlockGemm.scala +++ b/hw/chisel_acc/src/main/scala/snax_acc/gemm/BlockGemm.scala @@ -2,6 +2,8 @@ package snax_acc.gemm import chisel3._ import chisel3.util._ +import snax_acc.utils._ +import snax_acc.utils.DecoupledCut._ // The BlockGemm's control port declaration. 
class BlockGemmCtrlIO(params: GemmParams) extends Bundle { @@ -80,6 +82,7 @@ class BlockGemm(params: GemmParams) extends Module with RequireAsyncReset { // control signals for the counter incremental val accumulation = WireInit(0.B) val a_b_data_valid = WireInit(0.B) + val a_b_data_ready = WireInit(0.B) val gemm_a_b_input_fire = WireInit(0.B) val gemm_output_fire = WireInit(0.B) @@ -95,6 +98,30 @@ class BlockGemm(params: GemmParams) extends Module with RequireAsyncReset { val compute_fire = WireInit(0.B) + // ----------------------------------- + // register insert + // ----------------------------------- + + def a_bits_len = params.meshRow * params.tileSize * params.dataWidthA + def b_bits_len = params.tileSize * params.meshCol * params.dataWidthB + def a_b_bits_len = a_bits_len + b_bits_len + + val combined_decoupled_a_b_in = Wire(Decoupled(UInt(a_b_bits_len.W))) + val combined_decoupled_a_b_out = Wire(Decoupled(UInt(a_b_bits_len.W))) + val a_split_out = Wire(Decoupled(UInt(a_bits_len.W))) + val b_split_out = Wire(Decoupled(UInt(b_bits_len.W))) + + val a_b_cat = Module(new DecoupledCat2to1(a_bits_len, b_bits_len)) + val a_b_split = Module(new DecoupledSplit1to2(a_b_bits_len, a_bits_len, b_bits_len)) + + a_b_cat.io.in1 <> io.data.a_i + a_b_cat.io.in2 <> io.data.b_i + a_b_cat.io.out <> combined_decoupled_a_b_in + combined_decoupled_a_b_in -\\> combined_decoupled_a_b_out + a_b_split.io.in <> combined_decoupled_a_b_out + a_split_out <> a_b_split.io.out1 + b_split_out <> a_b_split.io.out2 + // State declaration val sIDLE :: sBUSY :: Nil = Enum(2) val cstate = RegInit(sIDLE) @@ -158,10 +185,12 @@ class BlockGemm(params: GemmParams) extends Module with RequireAsyncReset { } // input data valid signal, when both a and b are valid, the input data is valid - a_b_data_valid := io.data.a_i.valid && io.data.b_i.valid && cstate === sBUSY + a_b_data_valid := a_split_out.valid && b_split_out.valid && cstate === sBUSY + a_b_data_ready := gemm_array.io.ctrl.a_b_c_ready_o && cstate 
=== sBUSY + // gemm input fire signal, when both a and b are valid and gemm is ready for new input data // stall the a b compute if add c - gemm_a_b_input_fire := gemm_array.io.ctrl.a_b_c_ready_o && a_b_data_valid && !add_c && !must_add_c + gemm_a_b_input_fire := a_b_data_ready && a_b_data_valid && !add_c && !must_add_c // accumulation counter for generating the accumulation signal for Gemm Array // value change according to gemm_a_b_input_fire and add_c_fire @@ -244,13 +273,13 @@ class BlockGemm(params: GemmParams) extends Module with RequireAsyncReset { gemm_array.io.ctrl.subtraction_b_i := subtraction_b // data signals - gemm_array.io.data.a_i := io.data.a_i.bits - gemm_array.io.data.b_i := io.data.b_i.bits + gemm_array.io.data.a_i := a_split_out.bits + gemm_array.io.data.b_i := b_split_out.bits gemm_array.io.data.c_i := io.data.c_i.bits - // ready for pop out the data from outside - io.data.a_i.ready := cstate === sBUSY && gemm_a_b_input_fire - io.data.b_i.ready := cstate === sBUSY && gemm_a_b_input_fire + a_split_out.ready := cstate === sBUSY && gemm_a_b_input_fire + b_split_out.ready := cstate === sBUSY && gemm_a_b_input_fire + io.data.c_i.ready := cstate === sBUSY && add_c_fire // gemm output signals diff --git a/hw/chisel_acc/src/main/scala/snax_acc/utils/CustomOperators.scala b/hw/chisel_acc/src/main/scala/snax_acc/utils/CustomOperators.scala new file mode 100644 index 000000000..acff513ff --- /dev/null +++ b/hw/chisel_acc/src/main/scala/snax_acc/utils/CustomOperators.scala @@ -0,0 +1,125 @@ +package snax_acc.utils + +import chisel3._ +import chisel3.util._ + +/** The definition of the -|> / -||> / -\> / -\\> / -\\\> connectors for decoupled signals. Each + * connects a leftward Decoupled signal (Decoupled port) to a rightward Decoupled + * signal (Flipped port) and inserts a Queue or DataCut stage in between to avoid + * long combinatorial datapath + */ + +class DataCut[T <: Data](gen: T, delay: Int) extends Module { + val io = IO(new Bundle { + val in = Flipped(Decoupled(gen)) + 
val out = Decoupled(gen) + }) + + val in = Wire(ValidIO(gen)) + val out = Wire(ValidIO(gen)) + val shiftPermission = Wire(Bool()) + val shiftSuggestion = Wire(Bool()) + val shift = + shiftPermission && shiftSuggestion // shift is true when both shiftPermission and shiftSuggestion are true + in.bits := io.in.bits + in.valid := io.in.valid + io.in.ready := shiftPermission + io.out.valid := out.valid + io.out.bits := out.bits + out := ShiftRegister(in, delay, shift) + + // shiftPermission is true when last item's valid is true and io.out.ready is true or last item's valid is false + shiftPermission := (out.valid && io.out.ready) || !out.valid + + val dataInsideShiftRegister = Wire(Bool()) + + // shiftSuggestion is true when dataInsideShiftRegister is true or input.valid is true + shiftSuggestion := dataInsideShiftRegister || io.in.valid + + // When the counter has reached delay (about to wrap), no data is treated as being inside the shift register + val insideCounter = Counter(0 to delay, shift, io.in.valid) + dataInsideShiftRegister := insideCounter._1 =/= delay.U + +} + +object DecoupledCut { + implicit class BufferedDecoupledConnectionOp[T <: Data]( + val left: DecoupledIO[T] + ) { + // This implicit class defines the new operators -|>, -||>, -\>, -\\>, -\\\> for DecoupledIO + + def -|>( + right: DecoupledIO[T] + )(implicit sourceInfo: chisel3.experimental.SourceInfo): DecoupledIO[T] = { + val buffer = Module( + new Queue(chiselTypeOf(left.bits), entries = 1, pipe = false) + ) + buffer.suggestName("fullCutHalfBandwidth") + + left <> buffer.io.enq + buffer.io.deq <> right + right + } + + def -||>( + right: DecoupledIO[T] + )(implicit sourceInfo: chisel3.experimental.SourceInfo): DecoupledIO[T] = { + val buffer = Module( + new Queue(chiselTypeOf(left.bits), entries = 2, pipe = false) + ) + buffer.suggestName("fullCutFullBandwidth") + left <> buffer.io.enq + buffer.io.deq <> right + right + } + + def -\>( + right: DecoupledIO[T] + )(implicit sourceInfo: chisel3.experimental.SourceInfo): 
DecoupledIO[T] = { + val buffer = Module( + new DataCut(chiselTypeOf(left.bits), delay = 1) + ) + buffer.suggestName("dataCut1") + + left <> buffer.io.in + buffer.io.out <> right + right + } + + def -\\>( + right: DecoupledIO[T] + )(implicit sourceInfo: chisel3.experimental.SourceInfo): DecoupledIO[T] = { + val buffer = Module( + new DataCut(chiselTypeOf(left.bits), delay = 2) + ) + buffer.suggestName("dataCut2") + + left <> buffer.io.in + buffer.io.out <> right + right + } + + def -\\\>( + right: DecoupledIO[T] + )(implicit sourceInfo: chisel3.experimental.SourceInfo): DecoupledIO[T] = { + val buffer = Module( + new DataCut(chiselTypeOf(left.bits), delay = 3) + ) + buffer.suggestName("dataCut3") + + left <> buffer.io.in + buffer.io.out <> right + right + } + } +} + +object BitsConcat { + implicit class UIntConcatOp[T <: Bits](val left: T) { + // This implicit class defines the new operator ++ for Bits subtypes (implemented with Cat) + def ++( + right: T + )(implicit sourceInfo: chisel3.experimental.SourceInfo): T = + Cat(left, right).asInstanceOf[T] + } +} diff --git a/hw/chisel_acc/src/main/scala/snax_acc/utils/DecoupledCat.scala b/hw/chisel_acc/src/main/scala/snax_acc/utils/DecoupledCat.scala new file mode 100644 index 000000000..61aa41edc --- /dev/null +++ b/hw/chisel_acc/src/main/scala/snax_acc/utils/DecoupledCat.scala @@ -0,0 +1,44 @@ +package snax_acc.utils + +import chisel3._ +import chisel3.util._ + +class DecoupledCat2to1[T <: Data](aWidth: Int, bWidth: Int) extends Module{ + val io = IO(new Bundle { + val in1 = Flipped(Decoupled(UInt(aWidth.W))) // First decoupled input interface + val in2 = Flipped(Decoupled(UInt(bWidth.W))) // Second decoupled input interface + val out = Decoupled(UInt((aWidth + bWidth).W)) // Decoupled output interface + }) + + // Combine the bits of in1 and in2, in1 in higher bits + io.out.bits := Cat(io.in1.bits, io.in2.bits) + + // Output is valid only when both inputs are valid + io.out.valid := io.in1.valid && io.in2.valid + + // Ready is asserted 
to inputs only when the output is ready and valid (i.e. both inputs are valid) + io.in1.ready := io.out.ready && io.out.valid + io.in2.ready := io.out.ready && io.out.valid + +} + +class DecoupledSplit1to2(cWidth: Int, aWidth: Int, bWidth: Int) extends Module { + require(cWidth == aWidth + bWidth, "cWidth must be the sum of aWidth and bWidth") + + val io = IO(new Bundle { + val in = Flipped(Decoupled(UInt(cWidth.W))) // Large decoupled input (c) + val out1 = Decoupled(UInt(aWidth.W)) // Smaller decoupled output (a) + val out2 = Decoupled(UInt(bWidth.W)) // Smaller decoupled output (b) + }) + + // Split the input bits into two parts + io.out1.bits := io.in.bits(cWidth - 1, bWidth) // Upper bits go to out1 (a) + io.out2.bits := io.in.bits(bWidth - 1, 0) // Lower bits go to out2 (b) + + // Both outputs are valid when the input is valid + io.out1.valid := io.in.valid + io.out2.valid := io.in.valid + + // Input is ready when both outputs are ready + io.in.ready := io.out1.ready && io.out2.ready +}