From 1e16c053df186b357ed679f2af6894802a509c08 Mon Sep 17 00:00:00 2001 From: Xiaoling Yi <143962462+xiaoling-yi@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:29:28 +0200 Subject: [PATCH] add transpose feature for reader (#279) * add transpose feature for reader * rm trans in r&w * fix wire connect for no trans * scalafmr --- .../main/scala/snax/streamer/DataReader.scala | 38 ++++++++++++++++++- .../main/scala/snax/streamer/Parameters.scala | 6 ++- .../main/scala/snax/streamer/Streamer.scala | 16 ++++++++ .../scala/snax/streamer/StreamerTop.scala | 14 ++++++- .../snax/streamer/StreamerTestParameter.scala | 2 +- .../xdmaStreamer/AddressGenUnitTester.scala | 4 +- hw/templates/stream_param_gen.scala.tpl | 12 +++++- .../snax-kul-cluster-mixed-narrow-wide.hjson | 2 + .../src/snax-streamer-gemm-conv-simd-lib.c | 35 +++++++++-------- util/snaxgen/snaxgen.py | 5 +++ 10 files changed, 109 insertions(+), 25 deletions(-) diff --git a/hw/chisel/src/main/scala/snax/streamer/DataReader.scala b/hw/chisel/src/main/scala/snax/streamer/DataReader.scala index aeec185f3..a5b53959c 100644 --- a/hw/chisel/src/main/scala/snax/streamer/DataReader.scala +++ b/hw/chisel/src/main/scala/snax/streamer/DataReader.scala @@ -27,6 +27,8 @@ class DataReaderIO( "fifoWidth should match with TCDM datawidth for now!" ) + val ifTranspose = if (params.hasTranspose) Some(Input(Bool())) else None + } /** This class is data reader module,.It is responsible for sending read request @@ -120,8 +122,42 @@ class DataReader( ) } + // transpose the data + // !!!attention: only works for 8x8 matrix!!! 
+ val data_fifo_input_concat = Cat(data_fifo_input.reverse) + val data_fifo_input_transpose = Wire(Vec(8, Vec(8, UInt(8.W)))) + if (params.hasTranspose) { + require( + params.tcdmDataWidth == 64, + "transposeInWidth must be tcdmDataWidth = 64 for now" + ) + require( + params.tcdmPortsNum == 8, + "transposeOutWidth must be params.tcdmPortsNum = 8 for now" + ) + } + // gether all the response data - io.data_fifo_o.bits := Cat(data_fifo_input.reverse) + if (params.hasTranspose) { + for (i <- 0 until 8) { + for (j <- 0 until 8) { + data_fifo_input_transpose(i)(j) := data_fifo_input_concat( + i * 8 + j * 8 * 8 + 7, + i * 8 + j * 8 * 8 + 0 + ) + } + } + when(io.ifTranspose.get === true.B) { + io.data_fifo_o.bits := data_fifo_input_transpose.asUInt + }.otherwise { + io.data_fifo_o.bits := Cat(data_fifo_input.reverse) + } + } else { + io.data_fifo_o.bits := Cat(data_fifo_input.reverse) + data_fifo_input_transpose := VecInit( + Seq.fill(8)(VecInit(Seq.fill(8)(0.U(8.W)))) + ) + } // ************************************************************ // ********** Logic for handling fifo handshake *************** diff --git a/hw/chisel/src/main/scala/snax/streamer/Parameters.scala b/hw/chisel/src/main/scala/snax/streamer/Parameters.scala index 35aafd381..1f96c6c09 100644 --- a/hw/chisel/src/main/scala/snax/streamer/Parameters.scala +++ b/hw/chisel/src/main/scala/snax/streamer/Parameters.scala @@ -64,6 +64,7 @@ case class SpatialAddrGenUnitParams( case class DataMoverParams( tcdmPortsNum: Int, addrWidth: Int, + hasTranspose: Boolean = false, spatialBounds: Seq[Int], spatialDim: Int, elementWidth: Int, @@ -116,6 +117,8 @@ trait HasStreamerCoreParams { val ifShareTempAddrGenLoopBounds: Boolean val addrWidth: Int + + val hasTranspose: Boolean } /** trait for Streamer inferred parameters @@ -206,7 +209,8 @@ case class StreamerParams( readOnlyCsrNum: Int = 2, csrAddrWidth: Int = 32, ifShareTempAddrGenLoopBounds: Boolean = true, - tagName: String = "" + tagName: String = "", + 
hasTranspose: Boolean = false ) extends HasStreamerCoreParams with HasStreamerInferredParams with CommonParams diff --git a/hw/chisel/src/main/scala/snax/streamer/Streamer.scala b/hw/chisel/src/main/scala/snax/streamer/Streamer.scala index 44d4a6509..04456a24a 100644 --- a/hw/chisel/src/main/scala/snax/streamer/Streamer.scala +++ b/hw/chisel/src/main/scala/snax/streamer/Streamer.scala @@ -62,6 +62,9 @@ class StreamerCsrIO( val ptr_i = Vec(params.dataMoverNum, UInt(params.addrWidth.W)) + // only has transpose function for the data readers + val ifTranspose = + if (params.hasTranspose) Some(Vec(params.dataReaderNum, Bool())) else None } // data related io @@ -393,6 +396,19 @@ class Streamer( } } + // the cfg for transpose for only reader + val transpose_cfg = RegInit(0.U.asTypeOf(Vec(params.dataReaderNum, Bool()))) + // store the transpose configuration when the cfg is valid + + if (params.hasTranspose) { + when(io.csr.valid) { + transpose_cfg := io.csr.bits.ifTranspose.get + } + for (i <- 0 until params.dataReaderNum) { + data_reader(i).io.ifTranspose.get := transpose_cfg(i) + } + } + // data reader and data writer <> address generation units interface for (i <- 0 until params.dataMoverNum) { if (i < params.dataReaderNum) { diff --git a/hw/chisel/src/main/scala/snax/streamer/StreamerTop.scala b/hw/chisel/src/main/scala/snax/streamer/StreamerTop.scala index 083eda857..332ca2d84 100644 --- a/hw/chisel/src/main/scala/snax/streamer/StreamerTop.scala +++ b/hw/chisel/src/main/scala/snax/streamer/StreamerTop.scala @@ -42,12 +42,13 @@ class StreamerTop( override val desiredName = params.tagName + "StreamerTop" var csrNumReadWrite: Int = 0 + val transposeCSRNum = if (params.hasTranspose) 1 else 0 if (params.ifShareTempAddrGenLoopBounds == true) { csrNumReadWrite = - params.temporalDimInt + params.dataMoverNum * params.temporalDimInt + params.spatialDim.sum + params.dataMoverNum + 1 + params.temporalDimInt + params.dataMoverNum * params.temporalDimInt + params.spatialDim.sum 
+ params.dataMoverNum + transposeCSRNum + 1 } else { csrNumReadWrite = - params.temporalDimSeq.sum + params.temporalDimSeq.sum + params.spatialDim.sum + params.dataMoverNum + 1 + params.temporalDimSeq.sum + params.temporalDimSeq.sum + params.spatialDim.sum + params.dataMoverNum + transposeCSRNum + 1 } val io = IO( @@ -194,6 +195,15 @@ class StreamerTop( } } + // transpose configurations + require(params.dataReaderNum <= 32, "dataMoverNum should be less than 32") + if (params.hasTranspose) { + for (i <- 0 until params.dataReaderNum) { + streamer.io.csr.bits.ifTranspose + .get(i) := csr_manager.io.csr_config_out.bits.last(i) + } + } + // io.data and streamer data ports connection io.data <> streamer.io.data diff --git a/hw/chisel/src/test/scala/snax/streamer/StreamerTestParameter.scala b/hw/chisel/src/test/scala/snax/streamer/StreamerTestParameter.scala index 37d6e6c7c..ea62cd3e4 100644 --- a/hw/chisel/src/test/scala/snax/streamer/StreamerTestParameter.scala +++ b/hw/chisel/src/test/scala/snax/streamer/StreamerTestParameter.scala @@ -83,7 +83,7 @@ object StreamerTestConstant extends CommonParams { object StreamerWithReaderWriterTestConstant extends CommonParams { def addrWidth = 17 - + def temporalAddrGenUnitParams: Seq[TemporalAddrGenUnitParams] = Seq( TemporalAddrGenUnitParams( diff --git a/hw/chisel/src/test/scala/snax/xdma/xdmaStreamer/AddressGenUnitTester.scala b/hw/chisel/src/test/scala/snax/xdma/xdmaStreamer/AddressGenUnitTester.scala index 040edc3df..8d54124ce 100644 --- a/hw/chisel/src/test/scala/snax/xdma/xdmaStreamer/AddressGenUnitTester.scala +++ b/hw/chisel/src/test/scala/snax/xdma/xdmaStreamer/AddressGenUnitTester.scala @@ -21,9 +21,7 @@ class basicCounterTester extends AnyFlatSpec with ChiselScalatestTester { } } -class AddressGenUnitTester - extends AnyFlatSpec - with ChiselScalatestTester { +class AddressGenUnitTester extends AnyFlatSpec with ChiselScalatestTester { println( getVerilogString(new AddressGenUnit(AddressGenUnitParam())) diff --git
a/hw/templates/stream_param_gen.scala.tpl b/hw/templates/stream_param_gen.scala.tpl index 04b33e4ef..bc6877439 100644 --- a/hw/templates/stream_param_gen.scala.tpl +++ b/hw/templates/stream_param_gen.scala.tpl @@ -51,6 +51,12 @@ object StreamerParametersGen extends CommonParams { def addrWidth = ${tcdm_addr_width} +% if "has_transpose" in cfg["snax_streamer_cfg"] and cfg["snax_streamer_cfg"]["has_transpose"]: + def hasTranspose = true +% else: + def hasTranspose = false +% endif + def temporalAddrGenUnitParams: Seq[TemporalAddrGenUnitParams] = Seq( % for idx in range(0,len(cfg["snax_streamer_cfg"]["temporal_addrgen_unit_params"]["loop_dim"])): @@ -109,6 +115,7 @@ ${', ' if not loop.last else ''} DataMoverParams( tcdmPortsNum = ${cfg["snax_streamer_cfg"]["data_reader_params"]["tcdm_ports_num"][idx]}, addrWidth, + hasTranspose, spatialBounds = Seq(\ % for c in cfg["snax_streamer_cfg"]["data_reader_params"]["spatial_bounds"][idx]: ${c}${', ' if not loop.last else ''}\ @@ -130,6 +137,7 @@ ${c}${', ' if not loop.last else ''}\ DataMoverParams( tcdmPortsNum = ${cfg["snax_streamer_cfg"]["data_writer_params"]["tcdm_ports_num"][idx]}, addrWidth, + hasTranspose = false, spatialBounds = Seq(\ % for c in cfg["snax_streamer_cfg"]["data_writer_params"]["spatial_bounds"][idx]: ${c}${', ' if not loop.last else ''}\ @@ -151,6 +159,7 @@ ${c}${', ' if not loop.last else ''}\ DataMoverParams( tcdmPortsNum = ${cfg["snax_streamer_cfg"]["data_reader_writer_params"]["tcdm_ports_num"][idx]}, addrWidth, + hasTranspose = false, spatialBounds = Seq(\ % for c in cfg["snax_streamer_cfg"]["data_reader_writer_params"]["spatial_bounds"][idx]: ${c}${', ' if not loop.last else ''}\ @@ -191,7 +200,8 @@ object StreamerTopGen { addrWidth = StreamerParametersGen.addrWidth, stationarity = StreamerParametersGen.stationarity, ifShareTempAddrGenLoopBounds = StreamerParametersGen.ifShareTempAddrGenLoopBounds, - tagName = "${cfg["tag_name"]}_streamer_" + tagName = "${cfg["tag_name"]}_streamer_", + 
hasTranspose = StreamerParametersGen.hasTranspose ) ), Array("--target-dir", outPath) diff --git a/target/snitch_cluster/cfg/snax-kul-cluster-mixed-narrow-wide.hjson b/target/snitch_cluster/cfg/snax-kul-cluster-mixed-narrow-wide.hjson index d538a7bd2..3aabdae10 100644 --- a/target/snitch_cluster/cfg/snax-kul-cluster-mixed-narrow-wide.hjson +++ b/target/snitch_cluster/cfg/snax-kul-cluster-mixed-narrow-wide.hjson @@ -215,6 +215,8 @@ } stationarity: [0,0,0,0,0] + + has_transpose: true }, // SNAX Streamer Templates snax_data_reshuffler_streamer_template :{ diff --git a/target/snitch_cluster/sw/snax/streamer-gemm-conv-simd/src/snax-streamer-gemm-conv-simd-lib.c b/target/snitch_cluster/sw/snax/streamer-gemm-conv-simd/src/snax-streamer-gemm-conv-simd-lib.c index 1fe72404e..4eff93968 100644 --- a/target/snitch_cluster/sw/snax/streamer-gemm-conv-simd/src/snax-streamer-gemm-conv-simd-lib.c +++ b/target/snitch_cluster/sw/snax/streamer-gemm-conv-simd/src/snax-streamer-gemm-conv-simd-lib.c @@ -155,10 +155,13 @@ void set_gemmx_streamer_csr( // base ptr for D32 write_csr(1010, (uint32_t)(delta_local_d32 + snrt_l1_next())); + + // transpose or not + write_csr(1011, 0); } // Set CSR to start STREAMER -void set_gemmx_streamer_start() { write_csr(1011, 1); } +void set_gemmx_streamer_start() { write_csr(1012, 1); } // Set GEMM configuration CSR void set_gemmx_csr(int tempLoop0, int tempLoop1, int tempLoop2, @@ -166,42 +169,42 @@ void set_gemmx_csr(int tempLoop0, int tempLoop1, int tempLoop2, uint32_t csr2, uint32_t temporal_loop_bound, uint32_t bypassSIMD) { // set loop bounds, from innermost to outermost, aka from K to N to M - write_csr(1014, tempLoop0); - write_csr(1015, tempLoop1); - write_csr(1016, tempLoop2); + write_csr(1015, tempLoop0); + write_csr(1016, tempLoop1); + write_csr(1017, tempLoop2); // set subtraction a and b - write_csr(1017, subtractions); + write_csr(1018, subtractions); // set the constants for the SIMD unit - write_csr(1018, csr0); - write_csr(1019, csr1); - 
write_csr(1020, csr2); + write_csr(1019, csr0); + write_csr(1020, csr1); + write_csr(1021, csr2); // set the temporal loop bound - write_csr(1021, temporal_loop_bound); - write_csr(1022, bypassSIMD); + write_csr(1022, temporal_loop_bound); + write_csr(1023, bypassSIMD); } // Set CSR to start GEMM -void set_gemmx_start() { write_csr(1023, 1); } +void set_gemmx_start() { write_csr(1024, 1); } // Stall until Streamer and GEMM accelerator finish void wait_gemmx_and_streamer() { - write_csr(1011, 0); - write_csr(1011, 0); - write_csr(1023, 0); + write_csr(1012, 0); + write_csr(1012, 0); + write_csr(1024, 0); } // Read performance counter of the Streamer, a read-only CSR uint32_t read_gemmx_streamer_perf_counter() { - uint32_t perf_counter = read_csr(1013); + uint32_t perf_counter = read_csr(1014); return perf_counter; } // Read performance counter of GEMM, a read-only CSR uint32_t read_gemmx_perf_counter() { - uint32_t perf_counter = read_csr(1025); + uint32_t perf_counter = read_csr(1026); return perf_counter; } diff --git a/util/snaxgen/snaxgen.py b/util/snaxgen/snaxgen.py index dfcc1a268..22794e546 100755 --- a/util/snaxgen/snaxgen.py +++ b/util/snaxgen/snaxgen.py @@ -147,6 +147,11 @@ def streamer_csr_num(acc_cfgs): 2 * num_loop_dim + num_spatial_dim + num_data_mover + 1 + 1 + 1 ) # noqa: E501 + # transpose csr + if "has_transpose" in acc_cfgs["snax_streamer_cfg"]: + if acc_cfgs["snax_streamer_cfg"]["has_transpose"]: + streamer_csr_num += 1 + return streamer_csr_num