From 18a5514737a389d280e40e3fccbb923fd0d5cdd0 Mon Sep 17 00:00:00 2001 From: kadmin Date: Fri, 9 Oct 2020 20:44:22 +0000 Subject: [PATCH 1/2] Add identity match branch mir-optimization --- .../rustc_mir/src/transform/match_identity.rs | 76 +++++++++++++++++++ compiler/rustc_mir/src/transform/mod.rs | 1 + src/test/mir-opt/match_identity.rs | 15 ++++ 3 files changed, 92 insertions(+) create mode 100644 compiler/rustc_mir/src/transform/match_identity.rs create mode 100644 src/test/mir-opt/match_identity.rs diff --git a/compiler/rustc_mir/src/transform/match_identity.rs b/compiler/rustc_mir/src/transform/match_identity.rs new file mode 100644 index 0000000000000..2b255fe6e5f7d --- /dev/null +++ b/compiler/rustc_mir/src/transform/match_identity.rs @@ -0,0 +1,76 @@ +use crate::transform::MirPass; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use rustc_target::abi::VariantIdx; + +pub struct MatchIdentitySimplification; + +impl<'tcx> MirPass<'tcx> for MatchIdentitySimplification { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + //let param_env = tcx.param_env(body.source.def_id()); + let (bbs, local_decls) = body.basic_blocks_and_local_decls_mut(); + for bb_idx in bbs.indices() { + let (read_discr, og_match) = match &bbs[bb_idx].statements[..] { + &[Statement { + kind: StatementKind::Assign(box (dst, Rvalue::Discriminant(src))), + .. + }] => (dst, src), + _ => continue, + }; + let (var_idx, fst, snd) = match bbs[bb_idx].terminator().kind { + TerminatorKind::SwitchInt { + discr: Operand::Copy(ref place) | Operand::Move(ref place), + ref targets, + ref values, + .. + } if targets.len() == 2 + && values.len() == 1 + && targets[0] != targets[1] + // check that we're switching on the read discr + && place == &read_discr + // check that this is actually + && place.ty(local_decls, tcx).ty.is_enum() => + { + (VariantIdx::from(values[0] as usize), targets[0], targets[1]) + } + // Only optimize switch int statements + _ => continue, + }; + let stmts_ok = |stmts: &[Statement<'_>], expected_variant| match stmts { + [Statement { + kind: + StatementKind::Assign(box ( + dst0, + Rvalue::Use(Operand::Copy(from) | Operand::Move(from)), + )), + .. + }, Statement { + kind: StatementKind::SetDiscriminant { place: box dst1, variant_index }, + .. + }] => *variant_index == expected_variant && dst0 == dst1 && og_match == *from, + _ => false, + }; + let bb1 = &bbs[fst]; + let bb2 = &bbs[snd]; + if bb1.terminator().kind != bb2.terminator().kind + || stmts_ok(&bb1.statements[..], var_idx) + || stmts_ok(&bb2.statements[..], var_idx + 1) + { + continue; + } + let dst = match (&bb1.statements[0], &bb2.statements[0]) { + ( + Statement { kind: StatementKind::Assign(box (dst0, _)), .. }, + Statement { kind: StatementKind::Assign(box (dst1, _)), .. }, + ) if dst0 == dst1 => dst0.clone(), + _ => continue, + }; + let term_kind = bb1.terminator().kind.clone(); + // Reassign the output to just be the original + // Replace the terminator with the terminator of the output + bbs[bb_idx].statements[0].kind = + StatementKind::Assign(box (dst, Rvalue::Use(Operand::Copy(og_match)))); + bbs[bb_idx].terminator_mut().kind = term_kind; + } + } +} diff --git a/compiler/rustc_mir/src/transform/mod.rs b/compiler/rustc_mir/src/transform/mod.rs index b4f5947f5a339..3acb19527717a 100644 --- a/compiler/rustc_mir/src/transform/mod.rs +++ b/compiler/rustc_mir/src/transform/mod.rs @@ -33,6 +33,7 @@ pub mod inline; pub mod instcombine; pub mod instrument_coverage; pub mod match_branches; +pub mod match_identity; pub mod multiple_return_terminators; pub mod no_landing_pads; pub mod nrvo; diff --git a/src/test/mir-opt/match_identity.rs b/src/test/mir-opt/match_identity.rs new file mode 100644 index 0000000000000..70b6cad34db1d --- /dev/null +++ b/src/test/mir-opt/match_identity.rs @@ -0,0 +1,15 @@ +// EMIT_MIR match_identity.id_result.match_identity.diff +pub fn id_result(a: Result) -> Result { + match a { + Ok(x) => Ok(x), + Err(y) => Err(y), + } +} + +// EMIT_MIR match_identity.id_result.match_identity.diff +pub fn flip_flop(a: Result) -> Result { + match a { + Ok(x) => Err(x), + Err(y) => Ok(y), + } +} From cf8c66510c0aa2528571f8b7460960de4139afb1 Mon Sep 17 00:00:00 2001 From: kadmin Date: Fri, 9 Oct 2020 22:27:06 +0000 Subject: [PATCH 2/2] Add mir-opt tests --- compiler/rustc_mir/src/transform/mod.rs | 1 + ..._flop.MatchBranchSimplification.64bit.diff | 52 +++++++++++++++++++ ...ult.MatchIdentitySimplification.64bit.diff | 32 ++++++++++++ src/test/mir-opt/match_identity.rs | 14 +++-- 4 files changed, 95 insertions(+), 4 deletions(-) create mode 100644 src/test/mir-opt/match_identity.flip_flop.MatchBranchSimplification.64bit.diff create mode 100644 src/test/mir-opt/match_identity.id_result.MatchIdentitySimplification.64bit.diff diff --git a/compiler/rustc_mir/src/transform/mod.rs b/compiler/rustc_mir/src/transform/mod.rs index 3acb19527717a..54f0fb4f8148b 100644 --- a/compiler/rustc_mir/src/transform/mod.rs +++ b/compiler/rustc_mir/src/transform/mod.rs @@ -413,6 +413,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &nrvo::RenameReturnPlace, &simplify::SimplifyLocals, &multiple_return_terminators::MultipleReturnTerminators, + &match_identity::MatchIdentitySimplification, ]; // Optimizations to run even if mir optimizations have been disabled. diff --git a/src/test/mir-opt/match_identity.flip_flop.MatchBranchSimplification.64bit.diff b/src/test/mir-opt/match_identity.flip_flop.MatchBranchSimplification.64bit.diff new file mode 100644 index 0000000000000..3eb66e4e03c79 --- /dev/null +++ b/src/test/mir-opt/match_identity.flip_flop.MatchBranchSimplification.64bit.diff @@ -0,0 +1,52 @@ +- // MIR for `flip_flop` before MatchBranchSimplification ++ // MIR for `flip_flop` after MatchBranchSimplification + + fn flip_flop(_1: std::result::Result) -> std::result::Result { + debug a => _1; // in scope 0 at $DIR/match_identity.rs:16:18: 16:19 + let mut _0: std::result::Result; // return place in scope 0 at $DIR/match_identity.rs:16:38: 16:50 + let mut _2: isize; // in scope 0 at $DIR/match_identity.rs:18:9: 18:14 + let _3: i32; // in scope 0 at $DIR/match_identity.rs:18:12: 18:13 + let mut _4: i32; // in scope 0 at $DIR/match_identity.rs:18:22: 18:23 + let _5: u32; // in scope 0 at $DIR/match_identity.rs:19:13: 19:14 + let mut _6: u32; // in scope 0 at $DIR/match_identity.rs:19:22: 19:23 + scope 1 { + debug x => _3; // in scope 1 at $DIR/match_identity.rs:18:12: 18:13 + } + scope 2 { + debug y => _5; // in scope 2 at $DIR/match_identity.rs:19:13: 19:14 + } + + bb0: { + _2 = discriminant(_1); // scope 0 at $DIR/match_identity.rs:18:9: 18:14 + switchInt(move _2) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/match_identity.rs:18:9: 18:14 + } + + bb1: { + StorageLive(_5); // scope 0 at $DIR/match_identity.rs:19:13: 19:14 + _5 = ((_1 as Err).0: u32); // scope 0 at $DIR/match_identity.rs:19:13: 19:14 + StorageLive(_6); // scope 2 at $DIR/match_identity.rs:19:22: 19:23 + _6 = _5; // scope 2 at $DIR/match_identity.rs:19:22: 19:23 + ((_0 as Ok).0: u32) = move _6; // scope 2 at $DIR/match_identity.rs:19:19: 19:24 + discriminant(_0) = 0; // scope 2 at $DIR/match_identity.rs:19:19: 19:24 + StorageDead(_6); // scope 2 at $DIR/match_identity.rs:19:23: 19:24 + StorageDead(_5); // scope 0 at $DIR/match_identity.rs:19:23: 19:24 + goto -> bb3; // scope 0 at $DIR/match_identity.rs:17:5: 20:6 + } + + bb2: { + StorageLive(_3); // scope 0 at $DIR/match_identity.rs:18:12: 18:13 + _3 = ((_1 as Ok).0: i32); // scope 0 at $DIR/match_identity.rs:18:12: 18:13 + StorageLive(_4); // scope 1 at $DIR/match_identity.rs:18:22: 18:23 + _4 = _3; // scope 1 at $DIR/match_identity.rs:18:22: 18:23 + ((_0 as Err).0: i32) = move _4; // scope 1 at $DIR/match_identity.rs:18:18: 18:24 + discriminant(_0) = 1; // scope 1 at $DIR/match_identity.rs:18:18: 18:24 + StorageDead(_4); // scope 1 at $DIR/match_identity.rs:18:23: 18:24 + StorageDead(_3); // scope 0 at $DIR/match_identity.rs:18:23: 18:24 + goto -> bb3; // scope 0 at $DIR/match_identity.rs:17:5: 20:6 + } + + bb3: { + return; // scope 0 at $DIR/match_identity.rs:21:2: 21:2 + } + } + diff --git a/src/test/mir-opt/match_identity.id_result.MatchIdentitySimplification.64bit.diff b/src/test/mir-opt/match_identity.id_result.MatchIdentitySimplification.64bit.diff new file mode 100644 index 0000000000000..f7bebbe6607de --- /dev/null +++ b/src/test/mir-opt/match_identity.id_result.MatchIdentitySimplification.64bit.diff @@ -0,0 +1,32 @@ +- // MIR for `id_result` before MatchIdentitySimplification ++ // MIR for `id_result` after MatchIdentitySimplification + + fn id_result(_1: std::result::Result) -> std::result::Result { + debug a => _1; // in scope 0 at $DIR/match_identity.rs:8:18: 8:19 + let mut _0: std::result::Result; // return place in scope 0 at $DIR/match_identity.rs:8:38: 8:50 + let mut _2: isize; // in scope 0 at $DIR/match_identity.rs:10:9: 10:14 + scope 1 { + debug x => ((_0 as Ok).0: i32); // in scope 1 at $DIR/match_identity.rs:10:12: 10:13 + } + scope 2 { + debug y => ((_0 as Err).0: u32); // in scope 2 at $DIR/match_identity.rs:11:13: 11:14 + } + + bb0: { + _2 = discriminant(_1); // scope 0 at $DIR/match_identity.rs:10:9: 10:14 + switchInt(move _2) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/match_identity.rs:10:9: 10:14 + } + + bb1: { + ((_0 as Err).0: u32) = ((_1 as Err).0: u32); // scope 0 at $DIR/match_identity.rs:11:13: 11:14 + discriminant(_0) = 1; // scope 2 at $DIR/match_identity.rs:11:19: 11:25 + return; // scope 0 at $DIR/match_identity.rs:9:5: 12:6 + } + + bb2: { + ((_0 as Ok).0: i32) = ((_1 as Ok).0: i32); // scope 0 at $DIR/match_identity.rs:10:12: 10:13 + discriminant(_0) = 0; // scope 1 at $DIR/match_identity.rs:10:18: 10:23 + return; // scope 0 at $DIR/match_identity.rs:9:5: 12:6 + } + } + diff --git a/src/test/mir-opt/match_identity.rs b/src/test/mir-opt/match_identity.rs index 70b6cad34db1d..5e54cbcda31c2 100644 --- a/src/test/mir-opt/match_identity.rs +++ b/src/test/mir-opt/match_identity.rs @@ -1,13 +1,19 @@ -// EMIT_MIR match_identity.id_result.match_identity.diff -pub fn id_result(a: Result) -> Result { +#![crate_type = "lib"] + +type T = i32; +type E = u32; + +// EMIT_MIR_FOR_EACH_BIT_WIDTH +// EMIT_MIR match_identity.id_result.MatchIdentitySimplification.diff +pub fn id_result(a: Result) -> Result { match a { Ok(x) => Ok(x), Err(y) => Err(y), } } -// EMIT_MIR match_identity.id_result.match_identity.diff -pub fn flip_flop(a: Result) -> Result { +// EMIT_MIR match_identity.flip_flop.MatchBranchSimplification.diff +pub fn flip_flop(a: Result) -> Result { match a { Ok(x) => Err(x), Err(y) => Ok(y),