Skip to content

Commit

Permalink
add transpose feature for reader (#279)
Browse files Browse the repository at this point in the history
* add transpose feature for reader

* rm trans in r&w

* fix wire connect for no trans

* scalafmr
  • Loading branch information
xiaoling-yi committed Aug 26, 2024
1 parent 37d2736 commit 1e16c05
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 25 deletions.
38 changes: 37 additions & 1 deletion hw/chisel/src/main/scala/snax/streamer/DataReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class DataReaderIO(
"fifoWidth should match with TCDM datawidth for now!"
)

val ifTranspose = if (params.hasTranspose) Some(Input(Bool())) else None

}

/** This class is data reader module,.It is responsible for sending read request
Expand Down Expand Up @@ -120,8 +122,42 @@ class DataReader(
)
}

// transpose the data
// !!!attention: only works for 8x8 matrix!!!
val data_fifo_input_concat = Cat(data_fifo_input.reverse)
val data_fifo_input_transpose = Wire(Vec(8, Vec(8, UInt(8.W))))
if (params.hasTranspose) {
require(
params.tcdmDataWidth == 64,
"transposeInWidth must be tcdmDataWidth = 64 for now"
)
require(
params.tcdmPortsNum == 8,
"transposeOutWidth must be params.tcdmPortsNum = 8 for now"
)
}

// gether all the response data
io.data_fifo_o.bits := Cat(data_fifo_input.reverse)
if (params.hasTranspose) {
for (i <- 0 until 8) {
for (j <- 0 until 8) {
data_fifo_input_transpose(i)(j) := data_fifo_input_concat(
i * 8 + j * 8 * 8 + 7,
i * 8 + j * 8 * 8 + 0
)
}
}
when(io.ifTranspose.get === true.B) {
io.data_fifo_o.bits := data_fifo_input_transpose.asUInt
}.otherwise {
io.data_fifo_o.bits := Cat(data_fifo_input.reverse)
}
} else {
io.data_fifo_o.bits := Cat(data_fifo_input.reverse)
data_fifo_input_transpose := VecInit(
Seq.fill(8)(VecInit(Seq.fill(8)(0.U(8.W))))
)
}

// ************************************************************
// ********** Logic for handling fifo handshake ***************
Expand Down
6 changes: 5 additions & 1 deletion hw/chisel/src/main/scala/snax/streamer/Parameters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ case class SpatialAddrGenUnitParams(
case class DataMoverParams(
tcdmPortsNum: Int,
addrWidth: Int,
hasTranspose: Boolean = false,
spatialBounds: Seq[Int],
spatialDim: Int,
elementWidth: Int,
Expand Down Expand Up @@ -116,6 +117,8 @@ trait HasStreamerCoreParams {
val ifShareTempAddrGenLoopBounds: Boolean

val addrWidth: Int

val hasTranspose: Boolean
}

/** trait for Streamer inferred parameters
Expand Down Expand Up @@ -206,7 +209,8 @@ case class StreamerParams(
readOnlyCsrNum: Int = 2,
csrAddrWidth: Int = 32,
ifShareTempAddrGenLoopBounds: Boolean = true,
tagName: String = ""
tagName: String = "",
hasTranspose: Boolean = false
) extends HasStreamerCoreParams
with HasStreamerInferredParams
with CommonParams
16 changes: 16 additions & 0 deletions hw/chisel/src/main/scala/snax/streamer/Streamer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ class StreamerCsrIO(
val ptr_i =
Vec(params.dataMoverNum, UInt(params.addrWidth.W))

// only has transpose function for the data readers
val ifTranspose =
if (params.hasTranspose) Some(Vec(params.dataReaderNum, Bool())) else None
}

// data related io
Expand Down Expand Up @@ -393,6 +396,19 @@ class Streamer(
}
}

// the cfg for transpose for only reader
val transpose_cfg = RegInit(0.U.asTypeOf(Vec(params.dataReaderNum, Bool())))
// store the transpose configuration when the cfg is valid

if (params.hasTranspose) {
when(io.csr.valid) {
transpose_cfg := io.csr.bits.ifTranspose.get
}
for (i <- 0 until params.dataReaderNum) {
data_reader(i).io.ifTranspose.get := transpose_cfg(i)
}
}

// data reader and data writer <> address generation units interface
for (i <- 0 until params.dataMoverNum) {
if (i < params.dataReaderNum) {
Expand Down
14 changes: 12 additions & 2 deletions hw/chisel/src/main/scala/snax/streamer/StreamerTop.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ class StreamerTop(
override val desiredName = params.tagName + "StreamerTop"

var csrNumReadWrite: Int = 0
val transposeCSRNum = if (params.hasTranspose) 1 else 0
if (params.ifShareTempAddrGenLoopBounds == true) {
csrNumReadWrite =
params.temporalDimInt + params.dataMoverNum * params.temporalDimInt + params.spatialDim.sum + params.dataMoverNum + 1
params.temporalDimInt + params.dataMoverNum * params.temporalDimInt + params.spatialDim.sum + params.dataMoverNum + transposeCSRNum + 1
} else {
csrNumReadWrite =
params.temporalDimSeq.sum + params.temporalDimSeq.sum + params.spatialDim.sum + params.dataMoverNum + 1
params.temporalDimSeq.sum + params.temporalDimSeq.sum + params.spatialDim.sum + params.dataMoverNum + transposeCSRNum + 1
}

val io = IO(
Expand Down Expand Up @@ -194,6 +195,15 @@ class StreamerTop(
}
}

// transpose configruations
require(params.dataReaderNum <= 32, "dataMoverNum should be less than 32")
if (params.hasTranspose) {
for (i <- 0 until params.dataReaderNum) {
streamer.io.csr.bits.ifTranspose
.get(i) := csr_manager.io.csr_config_out.bits.last(i)
}
}

// io.data and streamer data ports connection
io.data <> streamer.io.data

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ object StreamerTestConstant extends CommonParams {
object StreamerWithReaderWriterTestConstant extends CommonParams {

def addrWidth = 17

def temporalAddrGenUnitParams: Seq[TemporalAddrGenUnitParams] =
Seq(
TemporalAddrGenUnitParams(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ class basicCounterTester extends AnyFlatSpec with ChiselScalatestTester {
}
}

class AddressGenUnitTester
extends AnyFlatSpec
with ChiselScalatestTester {
class AddressGenUnitTester extends AnyFlatSpec with ChiselScalatestTester {

println(
getVerilogString(new AddressGenUnit(AddressGenUnitParam()))
Expand Down
12 changes: 11 additions & 1 deletion hw/templates/stream_param_gen.scala.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ object StreamerParametersGen extends CommonParams {
def addrWidth = ${tcdm_addr_width}

% if "has_transpose" in cfg["snax_streamer_cfg"] and cfg["snax_streamer_cfg"]["has_transpose"]:
def hasTranspose = true
% else:
def hasTranspose = false
% endif

def temporalAddrGenUnitParams: Seq[TemporalAddrGenUnitParams] =
Seq(
% for idx in range(0,len(cfg["snax_streamer_cfg"]["temporal_addrgen_unit_params"]["loop_dim"])):
Expand Down Expand Up @@ -109,6 +115,7 @@ ${', ' if not loop.last else ''}
DataMoverParams(
tcdmPortsNum = ${cfg["snax_streamer_cfg"]["data_reader_params"]["tcdm_ports_num"][idx]},
addrWidth,
hasTranspose,
spatialBounds = Seq(\
% for c in cfg["snax_streamer_cfg"]["data_reader_params"]["spatial_bounds"][idx]:
${c}${', ' if not loop.last else ''}\
Expand All @@ -130,6 +137,7 @@ ${c}${', ' if not loop.last else ''}\
DataMoverParams(
tcdmPortsNum = ${cfg["snax_streamer_cfg"]["data_writer_params"]["tcdm_ports_num"][idx]},
addrWidth,
hasTranspose = false,
spatialBounds = Seq(\
% for c in cfg["snax_streamer_cfg"]["data_writer_params"]["spatial_bounds"][idx]:
${c}${', ' if not loop.last else ''}\
Expand All @@ -151,6 +159,7 @@ ${c}${', ' if not loop.last else ''}\
DataMoverParams(
tcdmPortsNum = ${cfg["snax_streamer_cfg"]["data_reader_writer_params"]["tcdm_ports_num"][idx]},
addrWidth,
hasTranspose = false,
spatialBounds = Seq(\
% for c in cfg["snax_streamer_cfg"]["data_reader_writer_params"]["spatial_bounds"][idx]:
${c}${', ' if not loop.last else ''}\
Expand Down Expand Up @@ -191,7 +200,8 @@ object StreamerTopGen {
addrWidth = StreamerParametersGen.addrWidth,
stationarity = StreamerParametersGen.stationarity,
ifShareTempAddrGenLoopBounds = StreamerParametersGen.ifShareTempAddrGenLoopBounds,
tagName = "${cfg["tag_name"]}_streamer_"
tagName = "${cfg["tag_name"]}_streamer_",
hasTranspose = StreamerParametersGen.hasTranspose
)
),
Array("--target-dir", outPath)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@
}

stationarity: [0,0,0,0,0]

has_transpose: true
},
// SNAX Streamer Templates
snax_data_reshuffler_streamer_template :{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,53 +155,56 @@ void set_gemmx_streamer_csr(

// base ptr for D32
write_csr(1010, (uint32_t)(delta_local_d32 + snrt_l1_next()));

// transpose or not
write_csr(1011, 0);
}

// Set CSR to start STREAMER
void set_gemmx_streamer_start() { write_csr(1011, 1); }
void set_gemmx_streamer_start() { write_csr(1012, 1); }

// Set GEMM configuration CSR
void set_gemmx_csr(int tempLoop0, int tempLoop1, int tempLoop2,
int subtractions, uint32_t csr0, uint32_t csr1,
uint32_t csr2, uint32_t temporal_loop_bound,
uint32_t bypassSIMD) {
// set loop bounds, from innermost to outermost, aka from K to N to M
write_csr(1014, tempLoop0);
write_csr(1015, tempLoop1);
write_csr(1016, tempLoop2);
write_csr(1015, tempLoop0);
write_csr(1016, tempLoop1);
write_csr(1017, tempLoop2);

// set subtraction a and b
write_csr(1017, subtractions);
write_csr(1018, subtractions);

// set the constants for the SIMD unit
write_csr(1018, csr0);
write_csr(1019, csr1);
write_csr(1020, csr2);
write_csr(1019, csr0);
write_csr(1020, csr1);
write_csr(1021, csr2);

// set the temporal loop bound
write_csr(1021, temporal_loop_bound);
write_csr(1022, bypassSIMD);
write_csr(1022, temporal_loop_bound);
write_csr(1023, bypassSIMD);
}

// Set CSR to start GEMM
void set_gemmx_start() { write_csr(1023, 1); }
void set_gemmx_start() { write_csr(1024, 1); }

// Stall until Streamer and GEMM accelerator finish
void wait_gemmx_and_streamer() {
write_csr(1011, 0);
write_csr(1011, 0);
write_csr(1023, 0);
write_csr(1012, 0);
write_csr(1012, 0);
write_csr(1024, 0);
}

// Read performance counter of the Streamer, a read-only CSR
uint32_t read_gemmx_streamer_perf_counter() {
uint32_t perf_counter = read_csr(1013);
uint32_t perf_counter = read_csr(1014);
return perf_counter;
}

// Read performance counter of GEMM, a read-only CSR
uint32_t read_gemmx_perf_counter() {
uint32_t perf_counter = read_csr(1025);
uint32_t perf_counter = read_csr(1026);
return perf_counter;
}

Expand Down
5 changes: 5 additions & 0 deletions util/snaxgen/snaxgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ def streamer_csr_num(acc_cfgs):
2 * num_loop_dim + num_spatial_dim + num_data_mover + 1 + 1 + 1
) # noqa: E501

# transpose csr
if "has_transpose" in acc_cfgs["snax_streamer_cfg"]:
if acc_cfgs["snax_streamer_cfg"]["has_transpose"]:
streamer_csr_num += 1

return streamer_csr_num


Expand Down

0 comments on commit 1e16c05

Please sign in to comment.