From a1c30bb906ea7a5453128e560eb608d5cf4d2a24 Mon Sep 17 00:00:00 2001
From: my-mayfly
Date: Tue, 24 Sep 2024 01:04:16 +0800
Subject: [PATCH] fix(BPU): adjust fallThroughErr signal usage strategy (#3627)

---
Note: this copy was restored from a whitespace-collapsed dump. Line breaks and
blank lines are placed to match the hunk headers and the diffstat; indentation
is best-effort. The first -/+ pair in FTB.scala (s3_ftb_entry_dup) was a
whitespace-only change whose exact difference could not be recovered.

 src/main/scala/xiangshan/frontend/BPU.scala       |  7 ++++++-
 src/main/scala/xiangshan/frontend/FTB.scala       |  8 ++++++--
 .../scala/xiangshan/frontend/FrontendBundle.scala |  1 +
 src/main/scala/xiangshan/frontend/newRAS.scala    | 11 ++++-------
 4 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/src/main/scala/xiangshan/frontend/BPU.scala b/src/main/scala/xiangshan/frontend/BPU.scala
index bef4a76f8e..98bd35d2ea 100644
--- a/src/main/scala/xiangshan/frontend/BPU.scala
+++ b/src/main/scala/xiangshan/frontend/BPU.scala
@@ -725,11 +725,16 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst with H
     )
   }
 
+  val s3_redirect_target_dup = dup_wire(UInt(VAddrBits.W))
+  for ((((s3_redirect_target, s3_fall_thru_error), s3_fallThroughAddr), s3_target) <- s3_redirect_target_dup zip s3_redirect_on_fall_thru_error_dup zip resp.s3.fallThroughAddr zip resp.s3.getTarget){
+    s3_redirect_target := Mux(s3_fall_thru_error, s3_fallThroughAddr, s3_target)
+  }
+
   XSPerfAccumulate(f"s3_redirect_on_br_taken", s3_fire_dup(0) && s3_redirect_on_br_taken_dup(0))
   XSPerfAccumulate(f"s3_redirect_on_jalr_target", s3_fire_dup(0) && s3_redirect_on_jalr_target_dup(0))
   XSPerfAccumulate(f"s3_redirect_on_others", s3_redirect_dup(0) && !(s3_redirect_on_br_taken_dup(0) || s3_redirect_on_jalr_target_dup(0)))
 
-  for (((npcGen, s3_redirect), s3_target) <- npcGen_dup zip s3_redirect_dup zip resp.s3.getTarget)
+  for (((npcGen, s3_redirect), s3_target) <- npcGen_dup zip s3_redirect_dup zip s3_redirect_target_dup)
     npcGen.register(s3_redirect, s3_target, Some("s3_target"), 3)
   for (((foldedGhGen, s3_redirect), s3_predicted_fh) <- foldedGhGen_dup zip s3_redirect_dup zip s3_predicted_fh_dup)
     foldedGhGen.register(s3_redirect, s3_predicted_fh, Some("s3_FGH"), 3)
diff --git a/src/main/scala/xiangshan/frontend/FTB.scala b/src/main/scala/xiangshan/frontend/FTB.scala
index 0479182a3e..701417b2a3 100644
--- a/src/main/scala/xiangshan/frontend/FTB.scala
+++ b/src/main/scala/xiangshan/frontend/FTB.scala
@@ -643,12 +643,13 @@ class FTB(implicit p: Parameters) extends BasePredictor with FTBParams with BPUU
     s2_fauftb_ftb_entry_dup zip s2_ftbBank_dup zip s2_ftb_entry_dup){
       s2_ftb_entry := Mux(s2_close_ftb_req, s2_fauftb_entry, s2_ftbBank_entry)
     }
-  val s3_ftb_entry_dup = io.s2_fire.zip(s2_ftb_entry_dup).map {case (f, e) => RegEnable(Mux(s2_multi_hit_enable, s2_multi_hit_entry, e), f)}
+  val s3_ftb_entry_dup = io.s2_fire.zip(s2_ftb_entry_dup).map {case (f, e) => RegEnable(Mux(s2_multi_hit_enable, s2_multi_hit_entry, e), f)}
   val real_s2_ftb_entry = Mux(s2_multi_hit_enable, s2_multi_hit_entry, s2_ftb_entry_dup(0))
   val real_s2_pc = s2_pc_dup(0).getAddr()
   val real_s2_startLower = Cat(0.U(1.W), real_s2_pc(instOffsetBits+log2Ceil(PredictWidth)-1, instOffsetBits))
   val real_s2_endLowerwithCarry = Cat(real_s2_ftb_entry.carry, real_s2_ftb_entry.pftAddr)
   val real_s2_fallThroughErr = real_s2_startLower >= real_s2_endLowerwithCarry || real_s2_endLowerwithCarry > (real_s2_startLower + (PredictWidth).U)
+  val real_s3_fallThroughErr_dup = io.s2_fire.map {f => RegEnable(real_s2_fallThroughErr, f)}

   //After closing ftb, the hit output from s2 is the hit of FauFTB cached in s1.
   //s1_hit is the ftbBank hit.
@@ -659,7 +660,7 @@ class FTB(implicit p: Parameters) extends BasePredictor with FTBParams with BPUU
     s2_fauftb_ftb_entry_hit_dup zip s2_ftb_hit_dup zip s2_hit_dup){
       s2_hit := Mux(s2_close_ftb_req, s2_fauftb_hit, s2_ftb_hit)
     }
-  val s3_hit_dup = io.s2_fire.zip(s2_hit_dup).map {case (f, h) => RegEnable(Mux(s2_multi_hit_enable, s2_multi_hit, h) && !real_s2_fallThroughErr, 0.B, f)}
+  val s3_hit_dup = io.s2_fire.zip(s2_hit_dup).map {case (f, h) => RegEnable(Mux(s2_multi_hit_enable, s2_multi_hit, h), 0.B, f)}
   val s3_multi_hit_dup = io.s2_fire.map(f => RegEnable(s2_multi_hit_enable,f))
   val writeWay = Mux(s1_close_ftb_req, 0.U, ftbBank.io.read_hits.bits)
   val s2_ftb_meta = RegEnable(FTBMeta(writeWay.asUInt, s1_hit, GTimer()).asUInt, io.s1_fire(0))
@@ -728,6 +729,9 @@ class FTB(implicit p: Parameters) extends BasePredictor with FTBParams with BPUU
     io.out.s3.full_pred zip s3_ftb_entry_dup zip s3_pc_dup zip s2_pc_dup zip io.s2_fire)
       full_pred.fromFtbEntry(s3_ftb_entry, s3_pc.getAddr(), Some((s2_pc.getAddr(), s2_fire)))
 
+  // Overwrite the fallThroughErr value
+  io.out.s3.full_pred.zipWithIndex.map {case(fp, i) => fp.fallThroughErr := real_s3_fallThroughErr_dup(i)}
+
   io.out.last_stage_ftb_entry := s3_ftb_entry_dup(0)
   io.out.last_stage_meta := RegEnable(Mux(s2_multi_hit_enable, s2_multi_hit_meta, s2_ftb_meta), io.s2_fire(0))
   io.out.s1_ftbCloseReq := s1_close_ftb_req
diff --git a/src/main/scala/xiangshan/frontend/FrontendBundle.scala b/src/main/scala/xiangshan/frontend/FrontendBundle.scala
index f7a7da2931..d0756c7e29 100644
--- a/src/main/scala/xiangshan/frontend/FrontendBundle.scala
+++ b/src/main/scala/xiangshan/frontend/FrontendBundle.scala
@@ -710,6 +710,7 @@ class BranchPredictionBundle(implicit p: Parameters) extends XSBundle
   def shouldShiftVec = VecInit(full_pred.map(_.shouldShiftVec))
   def fallThruError = VecInit(full_pred.map(_.fallThruError))
   def ftbMultiHit = VecInit(full_pred.map(_.ftbMultiHit))
+  def fallThroughAddr = VecInit(full_pred.map(_.fallThroughAddr))

   def taken = VecInit(cfiIndex.map(_.valid))

diff --git a/src/main/scala/xiangshan/frontend/newRAS.scala b/src/main/scala/xiangshan/frontend/newRAS.scala
index 13bfd9bb54..737e8b505b 100644
--- a/src/main/scala/xiangshan/frontend/newRAS.scala
+++ b/src/main/scala/xiangshan/frontend/newRAS.scala
@@ -582,11 +582,8 @@ class RAS(implicit p: Parameters) extends BasePredictor {
   val s3_top = RegEnable(stack.spec_pop_addr, io.s2_fire(2))
   val s3_spec_new_addr = RegEnable(s2_spec_new_addr, io.s2_fire(2))
 
-  // val s3_jalr_target = io.out.s3.full_pred.jalr_target
-  // val s3_last_target_in = io.in.bits.resp_in(0).s3.full_pred(2).targets.last
-  // val s3_last_target_out = io.out.s3.full_pred(2).targets.last
-  val s3_is_jalr = io.in.bits.resp_in(0).s3.full_pred(2).is_jalr
-  val s3_is_ret = io.in.bits.resp_in(0).s3.full_pred(2).is_ret
+  val s3_is_jalr = io.in.bits.resp_in(0).s3.full_pred(2).is_jalr && !io.in.bits.resp_in(0).s3.full_pred(2).fallThroughErr
+  val s3_is_ret = io.in.bits.resp_in(0).s3.full_pred(2).is_ret && !io.in.bits.resp_in(0).s3.full_pred(2).fallThroughErr
   // assert(is_jalr && is_ret || !is_ret)
   when(s3_is_ret && io.ctrl.ras_enable) {
     io.out.s3.full_pred.map(_.jalr_target).foreach(_ := s3_top)
@@ -599,8 +596,8 @@ class RAS(implicit p: Parameters) extends BasePredictor {
   val s3_pushed_in_s2 = RegEnable(s2_spec_push, io.s2_fire(2))
   val s3_popped_in_s2 = RegEnable(s2_spec_pop, io.s2_fire(2))
 
-  val s3_push = io.in.bits.resp_in(0).s3.full_pred(2).hit_taken_on_call
-  val s3_pop = io.in.bits.resp_in(0).s3.full_pred(2).hit_taken_on_ret
+  val s3_push = io.in.bits.resp_in(0).s3.full_pred(2).hit_taken_on_call && !io.in.bits.resp_in(0).s3.full_pred(2).fallThroughErr
+  val s3_pop = io.in.bits.resp_in(0).s3.full_pred(2).hit_taken_on_ret && !io.in.bits.resp_in(0).s3.full_pred(2).fallThroughErr

   val s3_cancel = io.s3_fire(2) && (s3_pushed_in_s2 =/= s3_push || s3_popped_in_s2 =/= s3_pop)
   stack.s2_fire := io.s2_fire(2)