Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gemm cp #392

Merged
merged 1 commit into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 49 additions & 25 deletions hw/chisel_acc/src/main/scala/snax_acc/gemm/BlockGemm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,29 +98,6 @@ class BlockGemm(params: GemmParams) extends Module with RequireAsyncReset {

val compute_fire = WireInit(0.B)

// -----------------------------------
// register insert
// -----------------------------------

def a_bits_len = params.meshRow * params.tileSize * params.dataWidthA
def b_bits_len = params.tileSize * params.meshCol * params.dataWidthB
def a_b_bits_len = a_bits_len + b_bits_len

val combined_decoupled_a_b_in = Wire(Decoupled(UInt(a_b_bits_len.W)))
val combined_decoupled_a_b_out = Wire(Decoupled(UInt(a_b_bits_len.W)))
val a_split_out = Wire(UInt(a_bits_len.W))
val b_split_out = Wire(UInt(b_bits_len.W))

val a_b_cat = Module(new DecoupledCat2to1(a_bits_len, b_bits_len))

a_b_cat.io.in1 <> io.data.a_i
a_b_cat.io.in2 <> io.data.b_i
a_b_cat.io.out <> combined_decoupled_a_b_in
combined_decoupled_a_b_in -\\> combined_decoupled_a_b_out
a_split_out := combined_decoupled_a_b_out.bits(a_b_bits_len - 1, b_bits_len)
b_split_out := combined_decoupled_a_b_out.bits(b_bits_len - 1, 0)
// combined_decoupled_a_b_out will be connected to further control signals

// State declaration
val sIDLE :: sBUSY :: Nil = Enum(2)
val cstate = RegInit(sIDLE)
Expand Down Expand Up @@ -176,6 +153,53 @@ class BlockGemm(params: GemmParams) extends Module with RequireAsyncReset {
// write all the results out means the operation is done
computation_finish := d_output_counter === (M * N - 1.U) && io.data.d_o.fire && cstate === sBUSY

// -----------------------------------
// register insert
// -----------------------------------

def a_bits_len = params.meshRow * params.tileSize * params.dataWidthA
def b_bits_len = params.tileSize * params.meshCol * params.dataWidthB
def sa_bits_len = params.dataWidthA
def sb_bits_len = params.dataWidthB

val combined_decoupled_a_b_in = Wire(
Decoupled(new CutBundle(a_bits_len, b_bits_len, sa_bits_len, sb_bits_len))
)
val combined_decoupled_a_b_out = Wire(
Decoupled(new CutBundle(a_bits_len, b_bits_len, sa_bits_len, sb_bits_len))
)
val a_split_out = Wire(UInt(a_bits_len.W))
val b_split_out = Wire(UInt(b_bits_len.W))
val subtraction_a_split_out = Wire(UInt(sa_bits_len.W))
val subtraction_b_split_out = Wire(UInt(sb_bits_len.W))

val decoupled_subtraction_a = Wire(Decoupled(UInt(sa_bits_len.W)))
val decoupled_subtraction_b = Wire(Decoupled(UInt(sb_bits_len.W)))

val a_b_sa_sb_cat = Module(
new DecoupledCat4to1(a_bits_len, b_bits_len, sa_bits_len, sb_bits_len)
)

// cat several decoupled signals into one for synchronization
a_b_sa_sb_cat.io.in1 <> io.data.a_i
a_b_sa_sb_cat.io.in2 <> io.data.b_i
a_b_sa_sb_cat.io.in3 <> decoupled_subtraction_a
a_b_sa_sb_cat.io.in4 <> decoupled_subtraction_b
a_b_sa_sb_cat.io.out <> combined_decoupled_a_b_in

// insert registers
combined_decoupled_a_b_in -\\> combined_decoupled_a_b_out
a_split_out := combined_decoupled_a_b_out.bits.a
b_split_out := combined_decoupled_a_b_out.bits.b
subtraction_a_split_out := combined_decoupled_a_b_out.bits.c
subtraction_b_split_out := combined_decoupled_a_b_out.bits.d
// combined_decoupled_a_b_out will be connected to further control signals

decoupled_subtraction_a.valid := cstate === sBUSY
decoupled_subtraction_a.bits := subtraction_a
decoupled_subtraction_b.valid := cstate === sBUSY
decoupled_subtraction_b.bits := subtraction_b

// write counter increment according to output data fire
when(io.data.d_o.fire) {
d_output_counter := d_output_counter + 1.U
Expand Down Expand Up @@ -268,8 +292,8 @@ class BlockGemm(params: GemmParams) extends Module with RequireAsyncReset {
// or when don't need to output d
gemm_array.io.ctrl.d_ready_i := Mux(io.data.d_o.valid, io.data.d_o.ready, 1.B)

gemm_array.io.ctrl.subtraction_a_i := subtraction_a
gemm_array.io.ctrl.subtraction_b_i := subtraction_b
gemm_array.io.ctrl.subtraction_a_i := subtraction_a_split_out
gemm_array.io.ctrl.subtraction_b_i := subtraction_b_split_out

// data signals
gemm_array.io.data.a_i := a_split_out
Expand Down
32 changes: 28 additions & 4 deletions hw/chisel_acc/src/main/scala/snax_acc/utils/DecoupledCat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,48 @@ package snax_acc.utils
import chisel3._
import chisel3.util._

class DecoupledCat2to1[T <: Data](aWidth: Int, bWidth: Int) extends Module {
/** Payload bundle made of four independently sized unsigned fields.
  *
  * @param aWidth bit width of field `a`
  * @param bWidth bit width of field `b`
  * @param cWidth bit width of field `c`
  * @param dWidth bit width of field `d`
  */
class CutBundle(aWidth: Int, bWidth: Int, cWidth: Int, dWidth: Int) extends Bundle {
  val a: UInt = UInt(aWidth.W)
  val b: UInt = UInt(bWidth.W)
  val c: UInt = UInt(cWidth.W)
  val d: UInt = UInt(dWidth.W)
}

/** Joins four decoupled UInt inputs into one decoupled [[CutBundle]] output.
  *
  * The output is valid only when all four inputs are valid in the same cycle.
  * Each input sees `ready` exactly when the joined output is both valid and
  * ready, so the four inputs are always accepted together.
  *
  * @tparam T not referenced by the body; kept so existing instantiations still compile
  * @param aWidth bit width of `in1` and of `out.bits.a`
  * @param bWidth bit width of `in2` and of `out.bits.b`
  * @param cWidth bit width of `in3` and of `out.bits.c`
  * @param dWith  bit width of `in4` and of `out.bits.d` (name kept unchanged
  *               for callers that pass it by name)
  */
class DecoupledCat4to1[T <: Data](
    aWidth: Int,
    bWidth: Int,
    cWidth: Int,
    dWith: Int
) extends Module {
  val io = IO(new Bundle {
    val in1 =
      Flipped(Decoupled(UInt(aWidth.W))) // First decoupled input interface
    val in2 =
      Flipped(Decoupled(UInt(bWidth.W))) // Second decoupled input interface
    val in3 =
      Flipped(Decoupled(UInt(cWidth.W))) // Third decoupled input interface
    val in4 =
      Flipped(Decoupled(UInt(dWith.W))) // Fourth decoupled input interface
    val out = Decoupled(
      new CutBundle(aWidth, bWidth, cWidth, dWith)
    ) // Decoupled output interface
  })

  // Forward each input payload to its own field of the output bundle
  io.out.bits.a := io.in1.bits
  io.out.bits.b := io.in2.bits
  io.out.bits.c := io.in3.bits
  io.out.bits.d := io.in4.bits

  // Output is valid only when all four inputs are valid
  io.out.valid := io.in1.valid && io.in2.valid && io.in3.valid && io.in4.valid

  // Ready is asserted to every input only when the joined output is both
  // valid and ready, so no input is accepted without the others
  io.in1.ready := io.out.ready && io.out.valid
  io.in2.ready := io.out.ready && io.out.valid
  io.in3.ready := io.out.ready && io.out.valid
  io.in4.ready := io.out.ready && io.out.valid

}

Expand Down
Loading