Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add identity match mir pass #77770

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions compiler/rustc_mir/src/transform/match_identity.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use crate::transform::MirPass;
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;
use rustc_target::abi::VariantIdx;

pub struct MatchIdentitySimplification;

impl<'tcx> MirPass<'tcx> for MatchIdentitySimplification {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
//let param_env = tcx.param_env(body.source.def_id());
let (bbs, local_decls) = body.basic_blocks_and_local_decls_mut();
for bb_idx in bbs.indices() {
let (read_discr, og_match) = match &bbs[bb_idx].statements[..] {
&[Statement {
kind: StatementKind::Assign(box (dst, Rvalue::Discriminant(src))),
..
}] => (dst, src),
_ => continue,
};
let (var_idx, fst, snd) = match bbs[bb_idx].terminator().kind {
TerminatorKind::SwitchInt {
discr: Operand::Copy(ref place) | Operand::Move(ref place),
ref targets,
ref values,
..
} if targets.len() == 2
&& values.len() == 1
&& targets[0] != targets[1]
// check that we're switching on the read discr
&& place == &read_discr
// check that this is actually
&& place.ty(local_decls, tcx).ty.is_enum() =>
{
(VariantIdx::from(values[0] as usize), targets[0], targets[1])
}
// Only optimize switch int statements
_ => continue,
};
let stmts_ok = |stmts: &[Statement<'_>], expected_variant| match stmts {
[Statement {
kind:
StatementKind::Assign(box (
dst0,
Rvalue::Use(Operand::Copy(from) | Operand::Move(from)),
)),
..
}, Statement {
kind: StatementKind::SetDiscriminant { place: box dst1, variant_index },
..
}] => *variant_index == expected_variant && dst0 == dst1 && og_match == *from,
_ => false,
};
let bb1 = &bbs[fst];
let bb2 = &bbs[snd];
if bb1.terminator().kind != bb2.terminator().kind
|| stmts_ok(&bb1.statements[..], var_idx)
|| stmts_ok(&bb2.statements[..], var_idx + 1)
{
continue;
}
let dst = match (&bb1.statements[0], &bb2.statements[0]) {
(
Statement { kind: StatementKind::Assign(box (dst0, _)), .. },
Statement { kind: StatementKind::Assign(box (dst1, _)), .. },
) if dst0 == dst1 => dst0.clone(),
_ => continue,
};
let term_kind = bb1.terminator().kind.clone();
// Reassign the output to just be the original
// Replace the terminator with the terminator of the output
bbs[bb_idx].statements[0].kind =
StatementKind::Assign(box (dst, Rvalue::Use(Operand::Copy(og_match))));
bbs[bb_idx].terminator_mut().kind = term_kind;
}
}
}
2 changes: 2 additions & 0 deletions compiler/rustc_mir/src/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub mod inline;
pub mod instcombine;
pub mod instrument_coverage;
pub mod match_branches;
pub mod match_identity;
pub mod multiple_return_terminators;
pub mod no_landing_pads;
pub mod nrvo;
Expand Down Expand Up @@ -412,6 +413,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&nrvo::RenameReturnPlace,
&simplify::SimplifyLocals,
&multiple_return_terminators::MultipleReturnTerminators,
&match_identity::MatchIdentitySimplification,
];

// Optimizations to run even if mir optimizations have been disabled.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
- // MIR for `flip_flop` before MatchBranchSimplification
+ // MIR for `flip_flop` after MatchBranchSimplification

fn flip_flop(_1: std::result::Result<i32, u32>) -> std::result::Result<u32, i32> {
debug a => _1; // in scope 0 at $DIR/match_identity.rs:16:18: 16:19
let mut _0: std::result::Result<u32, i32>; // return place in scope 0 at $DIR/match_identity.rs:16:38: 16:50
let mut _2: isize; // in scope 0 at $DIR/match_identity.rs:18:9: 18:14
let _3: i32; // in scope 0 at $DIR/match_identity.rs:18:12: 18:13
let mut _4: i32; // in scope 0 at $DIR/match_identity.rs:18:22: 18:23
let _5: u32; // in scope 0 at $DIR/match_identity.rs:19:13: 19:14
let mut _6: u32; // in scope 0 at $DIR/match_identity.rs:19:22: 19:23
scope 1 {
debug x => _3; // in scope 1 at $DIR/match_identity.rs:18:12: 18:13
}
scope 2 {
debug y => _5; // in scope 2 at $DIR/match_identity.rs:19:13: 19:14
}

bb0: {
_2 = discriminant(_1); // scope 0 at $DIR/match_identity.rs:18:9: 18:14
switchInt(move _2) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/match_identity.rs:18:9: 18:14
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn't look like your optimization actually gets applied

}

bb1: {
StorageLive(_5); // scope 0 at $DIR/match_identity.rs:19:13: 19:14
_5 = ((_1 as Err).0: u32); // scope 0 at $DIR/match_identity.rs:19:13: 19:14
StorageLive(_6); // scope 2 at $DIR/match_identity.rs:19:22: 19:23
_6 = _5; // scope 2 at $DIR/match_identity.rs:19:22: 19:23
((_0 as Ok).0: u32) = move _6; // scope 2 at $DIR/match_identity.rs:19:19: 19:24
discriminant(_0) = 0; // scope 2 at $DIR/match_identity.rs:19:19: 19:24
StorageDead(_6); // scope 2 at $DIR/match_identity.rs:19:23: 19:24
StorageDead(_5); // scope 0 at $DIR/match_identity.rs:19:23: 19:24
goto -> bb3; // scope 0 at $DIR/match_identity.rs:17:5: 20:6
}

bb2: {
StorageLive(_3); // scope 0 at $DIR/match_identity.rs:18:12: 18:13
_3 = ((_1 as Ok).0: i32); // scope 0 at $DIR/match_identity.rs:18:12: 18:13
StorageLive(_4); // scope 1 at $DIR/match_identity.rs:18:22: 18:23
_4 = _3; // scope 1 at $DIR/match_identity.rs:18:22: 18:23
((_0 as Err).0: i32) = move _4; // scope 1 at $DIR/match_identity.rs:18:18: 18:24
discriminant(_0) = 1; // scope 1 at $DIR/match_identity.rs:18:18: 18:24
StorageDead(_4); // scope 1 at $DIR/match_identity.rs:18:23: 18:24
StorageDead(_3); // scope 0 at $DIR/match_identity.rs:18:23: 18:24
goto -> bb3; // scope 0 at $DIR/match_identity.rs:17:5: 20:6
}

bb3: {
return; // scope 0 at $DIR/match_identity.rs:21:2: 21:2
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
- // MIR for `id_result` before MatchIdentitySimplification
+ // MIR for `id_result` after MatchIdentitySimplification

fn id_result(_1: std::result::Result<i32, u32>) -> std::result::Result<i32, u32> {
debug a => _1; // in scope 0 at $DIR/match_identity.rs:8:18: 8:19
let mut _0: std::result::Result<i32, u32>; // return place in scope 0 at $DIR/match_identity.rs:8:38: 8:50
let mut _2: isize; // in scope 0 at $DIR/match_identity.rs:10:9: 10:14
scope 1 {
debug x => ((_0 as Ok).0: i32); // in scope 1 at $DIR/match_identity.rs:10:12: 10:13
}
scope 2 {
debug y => ((_0 as Err).0: u32); // in scope 2 at $DIR/match_identity.rs:11:13: 11:14
}

bb0: {
_2 = discriminant(_1); // scope 0 at $DIR/match_identity.rs:10:9: 10:14
switchInt(move _2) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/match_identity.rs:10:9: 10:14
}

bb1: {
((_0 as Err).0: u32) = ((_1 as Err).0: u32); // scope 0 at $DIR/match_identity.rs:11:13: 11:14
discriminant(_0) = 1; // scope 2 at $DIR/match_identity.rs:11:19: 11:25
return; // scope 0 at $DIR/match_identity.rs:9:5: 12:6
}

bb2: {
((_0 as Ok).0: i32) = ((_1 as Ok).0: i32); // scope 0 at $DIR/match_identity.rs:10:12: 10:13
discriminant(_0) = 0; // scope 1 at $DIR/match_identity.rs:10:18: 10:23
return; // scope 0 at $DIR/match_identity.rs:9:5: 12:6
}
}

21 changes: 21 additions & 0 deletions src/test/mir-opt/match_identity.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#![crate_type = "lib"]

type T = i32;
type E = u32;

// EMIT_MIR_FOR_EACH_BIT_WIDTH
// EMIT_MIR match_identity.id_result.MatchIdentitySimplification.diff
pub fn id_result(a: Result<T, E>) -> Result<T, E> {
match a {
Ok(x) => Ok(x),
Err(y) => Err(y),
}
}

// EMIT_MIR match_identity.flip_flop.MatchBranchSimplification.diff
pub fn flip_flop(a: Result<T, E>) -> Result<E, T> {
match a {
Ok(x) => Err(x),
Err(y) => Ok(y),
}
}