Skip to content

Commit

Permalink
Auto merge of #129047 - DianQK:early_otherwise_branch_scalar, r=cjgillot
Browse files Browse the repository at this point in the history
Apply `EarlyOtherwiseBranch` to scalar value

In the future, I'm thinking of hoisting discriminant via GVN so that we only need to write very little code here.

r? `@cjgillot`
  • Loading branch information
bors committed Sep 23, 2024
2 parents 702987f + e3a9eaf commit a772336
Show file tree
Hide file tree
Showing 5 changed files with 422 additions and 85 deletions.
260 changes: 175 additions & 85 deletions compiler/rustc_mir_transform/src/early_otherwise_branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,18 +133,29 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {

let mut patch = MirPatch::new(body);

// create temp to store second discriminant in, `_s` in example above
let second_discriminant_temp =
patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
let (second_discriminant_temp, second_operand) = if opt_data.need_hoist_discriminant {
// create temp to store second discriminant in, `_s` in example above
let second_discriminant_temp =
patch.new_temp(opt_data.child_ty, opt_data.child_source.span);

patch.add_statement(parent_end, StatementKind::StorageLive(second_discriminant_temp));
patch.add_statement(
parent_end,
StatementKind::StorageLive(second_discriminant_temp),
);

// create assignment of discriminant
patch.add_assign(
parent_end,
Place::from(second_discriminant_temp),
Rvalue::Discriminant(opt_data.child_place),
);
// create assignment of discriminant
patch.add_assign(
parent_end,
Place::from(second_discriminant_temp),
Rvalue::Discriminant(opt_data.child_place),
);
(
Some(second_discriminant_temp),
Operand::Move(Place::from(second_discriminant_temp)),
)
} else {
(None, Operand::Copy(opt_data.child_place))
};

// create temp to store inequality comparison between the two discriminants, `_t` in
// example above
Expand All @@ -153,11 +164,9 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));

// create inequality comparison between the two discriminants
let comp_rvalue = Rvalue::BinaryOp(
nequal,
Box::new((parent_op.clone(), Operand::Move(Place::from(second_discriminant_temp)))),
);
// create inequality comparison
let comp_rvalue =
Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
patch.add_statement(
parent_end,
StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
Expand Down Expand Up @@ -193,8 +202,13 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
);

// generate StorageDead for the second_discriminant_temp not in use anymore
patch.add_statement(parent_end, StatementKind::StorageDead(second_discriminant_temp));
if let Some(second_discriminant_temp) = second_discriminant_temp {
// generate StorageDead for the second_discriminant_temp not in use anymore
patch.add_statement(
parent_end,
StatementKind::StorageDead(second_discriminant_temp),
);
}

// Generate a StorageDead for comp_temp in each of the targets, since we moved it into
// the switch
Expand Down Expand Up @@ -222,6 +236,7 @@ struct OptimizationData<'tcx> {
child_place: Place<'tcx>,
child_ty: Ty<'tcx>,
child_source: SourceInfo,
need_hoist_discriminant: bool,
}

fn evaluate_candidate<'tcx>(
Expand All @@ -235,70 +250,128 @@ fn evaluate_candidate<'tcx>(
return None;
};
let parent_ty = parent_discr.ty(body.local_decls(), tcx);
if !bbs[targets.otherwise()].is_empty_unreachable() {
// Someone could write code like this:
// ```rust
// let Q = val;
// if discriminant(P) == otherwise {
// let ptr = &mut Q as *mut _ as *mut u8;
// // It may be difficult for us to effectively determine whether values are valid.
// // Invalid values can come from all sorts of corners.
// unsafe { *ptr = 10; }
// }
//
// match P {
// A => match Q {
// A => {
// // code
// }
// _ => {
// // don't use Q
// }
// }
// _ => {
// // don't use Q
// }
// };
// ```
//
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant
// of an invalid value, which is UB.
// In order to fix this, **we would either need to show that the discriminant computation of
// `place` is computed in all branches**.
// FIXME(#95162) For the moment, we adopt a conservative approach and
// consider only the `otherwise` branch has no statements and an unreachable terminator.
return None;
}
let (_, child) = targets.iter().next()?;
let child_terminator = &bbs[child].terminator();
let TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr } =
&child_terminator.kind

let Terminator {
kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
source_info,
} = bbs[child].terminator()
else {
return None;
};
let child_ty = child_discr.ty(body.local_decls(), tcx);
if child_ty != parent_ty {
return None;
}
let Some(StatementKind::Assign(boxed)) = &bbs[child].statements.first().map(|x| &x.kind) else {

// We only handle:
// ```
// bb4: {
// _8 = discriminant((_3.1: Enum1));
// switchInt(move _8) -> [2: bb7, otherwise: bb1];
// }
// ```
// and
// ```
// bb2: {
// switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
// }
// ```
if bbs[child].statements.len() > 1 {
return None;
}

// When thie BB has exactly one statement, this statement should be discriminant.
let need_hoist_discriminant = bbs[child].statements.len() == 1;
let child_place = if need_hoist_discriminant {
if !bbs[targets.otherwise()].is_empty_unreachable() {
// Someone could write code like this:
// ```rust
// let Q = val;
// if discriminant(P) == otherwise {
// let ptr = &mut Q as *mut _ as *mut u8;
// // It may be difficult for us to effectively determine whether values are valid.
// // Invalid values can come from all sorts of corners.
// unsafe { *ptr = 10; }
// }
//
// match P {
// A => match Q {
// A => {
// // code
// }
// _ => {
// // don't use Q
// }
// }
// _ => {
// // don't use Q
// }
// };
// ```
//
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
// invalid value, which is UB.
// In order to fix this, **we would either need to show that the discriminant computation of
// `place` is computed in all branches**.
// FIXME(#95162) For the moment, we adopt a conservative approach and
// consider only the `otherwise` branch has no statements and an unreachable terminator.
return None;
}
// Handle:
// ```
// bb4: {
// _8 = discriminant((_3.1: Enum1));
// switchInt(move _8) -> [2: bb7, otherwise: bb1];
// }
// ```
let [
Statement {
kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))),
..
},
] = bbs[child].statements.as_slice()
else {
return None;
};
*child_place
} else {
// Handle:
// ```
// bb2: {
// switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
// }
// ```
let Operand::Copy(child_place) = child_discr else {
return None;
};
*child_place
};
let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
return None;
let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
{
child_targets.otherwise()
} else {
targets.otherwise()
};
let destination = child_targets.otherwise();

// Verify that the optimization is legal for each branch
for (value, child) in targets.iter() {
if !verify_candidate_branch(&bbs[child], value, *child_place, destination) {
if !verify_candidate_branch(
&bbs[child],
value,
child_place,
destination,
need_hoist_discriminant,
) {
return None;
}
}
Some(OptimizationData {
destination,
child_place: *child_place,
child_place,
child_ty,
child_source: child_terminator.source_info,
child_source: *source_info,
need_hoist_discriminant,
})
}

Expand All @@ -307,31 +380,48 @@ fn verify_candidate_branch<'tcx>(
value: u128,
place: Place<'tcx>,
destination: BasicBlock,
need_hoist_discriminant: bool,
) -> bool {
// In order for the optimization to be correct, the branch must...
// ...have exactly one statement
if let [statement] = branch.statements.as_slice()
// ...assign the discriminant of `place` in that statement
&& let StatementKind::Assign(boxed) = &statement.kind
&& let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed
&& *from_place == place
// ...make that assignment to a local
&& discr_place.projection.is_empty()
// ...terminate on a `SwitchInt` that invalidates that local
&& let TerminatorKind::SwitchInt { discr: switch_op, targets, .. } =
&branch.terminator().kind
&& *switch_op == Operand::Move(*discr_place)
// ...fall through to `destination` if the switch misses
&& destination == targets.otherwise()
// ...have a branch for value `value`
&& let mut iter = targets.iter()
&& let Some((target_value, _)) = iter.next()
&& target_value == value
// ...and have no more branches
&& iter.next().is_none()
{
true
// In order for the optimization to be correct, the terminator must be a `SwitchInt`.
let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
return false;
};
if need_hoist_discriminant {
// If we need hoist discriminant, the branch must have exactly one statement.
let [statement] = branch.statements.as_slice() else {
return false;
};
// The statement must assign the discriminant of `place`.
let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) =
statement.kind
else {
return false;
};
if from_place != place {
return false;
}
// The assignment must invalidate a local that terminate on a `SwitchInt`.
if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
return false;
}
} else {
false
// If we don't need hoist discriminant, the branch must not have any statements.
if !branch.statements.is_empty() {
return false;
}
// The place on `SwitchInt` must be the same.
if *switch_op != Operand::Copy(place) {
return false;
}
}
// It must fall through to `destination` if the switch misses.
if destination != targets.otherwise() {
return false;
}
// It must have exactly one branch for value `value` and have no more branches.
let mut iter = targets.iter();
let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
return false;
};
target_value == value
}
Loading

0 comments on commit a772336

Please sign in to comment.