Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle discriminant in DataflowConstProp #107411

Merged
merged 8 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1640,6 +1640,14 @@ impl<'tcx> PlaceRef<'tcx> {
}
}

/// Returns `true` if this `Place` contains a `Deref` projection.
///
/// If `Place::is_indirect` returns false, the caller knows that the `Place` refers to the
/// same region of memory as its base.
pub fn is_indirect(&self) -> bool {
self.projection.iter().any(|elem| elem.is_indirect())
}

/// If MirPhase >= Derefered and if projection contains Deref,
/// It's guaranteed to be in the first place
pub fn has_deref(&self) -> bool {
Expand Down
4 changes: 3 additions & 1 deletion compiler/rustc_mir_dataflow/src/impls/borrowed_locals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ where
// for now. See discussion on [#61069].
//
// [#61069]: https://github.com/rust-lang/rust/pull/61069
self.trans.gen(dropped_place.local);
if !dropped_place.is_indirect() {
self.trans.gen(dropped_place.local);
}
}

TerminatorKind::Abort
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_mir_dataflow/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![feature(associated_type_defaults)]
#![feature(box_patterns)]
#![feature(exact_size_is_empty)]
#![feature(let_chains)]
#![feature(min_specialization)]
#![feature(once_cell)]
#![feature(stmt_expr_attributes)]
Expand Down
275 changes: 214 additions & 61 deletions compiler/rustc_mir_dataflow/src/value_analysis.rs

Large diffs are not rendered by default.

127 changes: 95 additions & 32 deletions compiler/rustc_mir_transform/src/dataflow_const_prop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use rustc_mir_dataflow::value_analysis::{Map, State, TrackElem, ValueAnalysis, V
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, ResultsVisitor, SwitchIntEdgeEffects};
use rustc_span::DUMMY_SP;
use rustc_target::abi::Align;
use rustc_target::abi::VariantIdx;

use crate::MirPass;

Expand All @@ -30,14 +31,12 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {

#[instrument(skip_all level = "debug")]
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
debug!(def_id = ?body.source.def_id());
if tcx.sess.mir_opt_level() < 4 && body.basic_blocks.len() > BLOCK_LIMIT {
debug!("aborted dataflow const prop due too many basic blocks");
return;
}

// Decide which places to track during the analysis.
let map = Map::from_filter(tcx, body, Ty::is_scalar);

// We want to have a somewhat linear runtime w.r.t. the number of statements/terminators.
// Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function
// applications, where `h` is the height of the lattice. Because the height of our lattice
Expand All @@ -46,10 +45,10 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
// `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of
// map nodes is strongly correlated to the number of tracked places, this becomes more or
// less `O(n)` if we place a constant limit on the number of tracked places.
if tcx.sess.mir_opt_level() < 4 && map.tracked_places() > PLACE_LIMIT {
debug!("aborted dataflow const prop due to too many tracked places");
return;
}
let place_limit = if tcx.sess.mir_opt_level() < 4 { Some(PLACE_LIMIT) } else { None };

// Decide which places to track during the analysis.
let map = Map::from_filter(tcx, body, Ty::is_scalar, place_limit);

// Perform the actual dataflow analysis.
let analysis = ConstAnalysis::new(tcx, body, map);
Expand All @@ -63,14 +62,31 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
}
}

struct ConstAnalysis<'tcx> {
struct ConstAnalysis<'a, 'tcx> {
map: Map,
tcx: TyCtxt<'tcx>,
local_decls: &'a LocalDecls<'tcx>,
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
param_env: ty::ParamEnv<'tcx>,
}

impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
impl<'tcx> ConstAnalysis<'_, 'tcx> {
fn eval_discriminant(
&self,
enum_ty: Ty<'tcx>,
variant_index: VariantIdx,
) -> Option<ScalarTy<'tcx>> {
if !enum_ty.is_enum() {
return None;
}
let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
let discr_value = Scalar::try_from_uint(discr.val, discr_layout.size)?;
Some(ScalarTy(discr_value, discr.ty))
}
}

impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
type Value = FlatSet<ScalarTy<'tcx>>;

const NAME: &'static str = "ConstAnalysis";
Expand All @@ -79,6 +95,25 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
&self.map
}

fn handle_statement(&self, statement: &Statement<'tcx>, state: &mut State<Self::Value>) {
match statement.kind {
StatementKind::SetDiscriminant { box ref place, variant_index } => {
state.flood_discr(place.as_ref(), &self.map);
if self.map.find_discr(place.as_ref()).is_some() {
let enum_ty = place.ty(self.local_decls, self.tcx).ty;
if let Some(discr) = self.eval_discriminant(enum_ty, variant_index) {
state.assign_discr(
place.as_ref(),
ValueOrPlace::Value(FlatSet::Elem(discr)),
&self.map,
);
}
}
}
_ => self.super_statement(statement, state),
}
}

fn handle_assign(
&self,
target: Place<'tcx>,
Expand All @@ -87,36 +122,47 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
) {
match rvalue {
Rvalue::Aggregate(kind, operands) => {
let target = self.map().find(target.as_ref());
if let Some(target) = target {
state.flood_idx_with(target, self.map(), FlatSet::Bottom);
let field_based = match **kind {
AggregateKind::Tuple | AggregateKind::Closure(..) => true,
AggregateKind::Adt(def_id, ..) => {
matches!(self.tcx.def_kind(def_id), DefKind::Struct)
state.flood_with(target.as_ref(), self.map(), FlatSet::Bottom);
if let Some(target_idx) = self.map().find(target.as_ref()) {
let (variant_target, variant_index) = match **kind {
AggregateKind::Tuple | AggregateKind::Closure(..) => {
(Some(target_idx), None)
}
_ => false,
AggregateKind::Adt(def_id, variant_index, ..) => {
match self.tcx.def_kind(def_id) {
DefKind::Struct => (Some(target_idx), None),
DefKind::Enum => (Some(target_idx), Some(variant_index)),
_ => (None, None),
}
}
_ => (None, None),
};
if field_based {
if let Some(target) = variant_target {
for (field_index, operand) in operands.iter().enumerate() {
if let Some(field) = self
.map()
.apply(target, TrackElem::Field(Field::from_usize(field_index)))
{
let result = self.handle_operand(operand, state);
state.assign_idx(field, result, self.map());
state.insert_idx(field, result, self.map());
}
}
}
if let Some(variant_index) = variant_index
&& let Some(discr_idx) = self.map().apply(target_idx, TrackElem::Discriminant)
{
let enum_ty = target.ty(self.local_decls, self.tcx).ty;
if let Some(discr_val) = self.eval_discriminant(enum_ty, variant_index) {
state.insert_value_idx(discr_idx, FlatSet::Elem(discr_val), &self.map);
}
}
}
}
Rvalue::CheckedBinaryOp(op, box (left, right)) => {
// Flood everything now, so we can use `insert_value_idx` directly later.
state.flood(target.as_ref(), self.map());

let target = self.map().find(target.as_ref());
if let Some(target) = target {
// We should not track any projections other than
// what is overwritten below, but just in case...
state.flood_idx(target, self.map());
}

let value_target = target
.and_then(|target| self.map().apply(target, TrackElem::Field(0_u32.into())));
Expand All @@ -127,7 +173,8 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
let (val, overflow) = self.binary_op(state, *op, left, right);

if let Some(value_target) = value_target {
state.assign_idx(value_target, ValueOrPlace::Value(val), self.map());
// We have flooded `target` earlier.
state.insert_value_idx(value_target, val, self.map());
}
if let Some(overflow_target) = overflow_target {
let overflow = match overflow {
Expand All @@ -142,11 +189,8 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
}
FlatSet::Bottom => FlatSet::Bottom,
};
state.assign_idx(
overflow_target,
ValueOrPlace::Value(overflow),
self.map(),
);
// We have flooded `target` earlier.
state.insert_value_idx(overflow_target, overflow, self.map());
}
}
}
Expand Down Expand Up @@ -195,6 +239,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
FlatSet::Bottom => ValueOrPlace::Value(FlatSet::Bottom),
FlatSet::Top => ValueOrPlace::Value(FlatSet::Top),
},
Rvalue::Discriminant(place) => {
ValueOrPlace::Value(state.get_discr(place.as_ref(), self.map()))
}
_ => self.super_rvalue(rvalue, state),
}
}
Expand Down Expand Up @@ -268,12 +315,13 @@ impl<'tcx> std::fmt::Debug for ScalarTy<'tcx> {
}
}

impl<'tcx> ConstAnalysis<'tcx> {
pub fn new(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, map: Map) -> Self {
impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> {
pub fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map) -> Self {
let param_env = tcx.param_env(body.source.def_id());
Self {
map,
tcx,
local_decls: &body.local_decls,
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
param_env: param_env,
}
Expand Down Expand Up @@ -466,6 +514,21 @@ impl<'tcx, 'map, 'a> Visitor<'tcx> for OperandCollector<'tcx, 'map, 'a> {
_ => (),
}
}

fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
match rvalue {
Rvalue::Discriminant(place) => {
match self.state.get_discr(place.as_ref(), self.visitor.map) {
FlatSet::Top => (),
FlatSet::Elem(value) => {
self.visitor.before_effect.insert((location, *place), value);
}
FlatSet::Bottom => (),
}
}
_ => self.super_rvalue(rvalue, location),
}
}
}

struct DummyMachine;
Expand Down
18 changes: 10 additions & 8 deletions compiler/rustc_mir_transform/src/sroa.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::MirPass;
use rustc_index::bit_set::BitSet;
use rustc_index::bit_set::{BitSet, GrowableBitSet};
use rustc_index::vec::IndexVec;
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::visit::*;
Expand All @@ -26,10 +26,12 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
debug!(?replacements);
let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
if !all_dead_locals.is_empty() {
for local in excluded.indices() {
excluded[local] |= all_dead_locals.contains(local);
}
excluded.raw.resize(body.local_decls.len(), false);
excluded.union(&all_dead_locals);
excluded = {
let mut growable = GrowableBitSet::from(excluded);
growable.ensure(body.local_decls.len());
growable.into()
};
} else {
break;
}
Expand All @@ -44,11 +46,11 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
/// - the locals is a union or an enum;
/// - the local's address is taken, and thus the relative addresses of the fields are observable to
/// client code.
fn escaping_locals(excluded: &IndexVec<Local, bool>, body: &Body<'_>) -> BitSet<Local> {
fn escaping_locals(excluded: &BitSet<Local>, body: &Body<'_>) -> BitSet<Local> {
let mut set = BitSet::new_empty(body.local_decls.len());
set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
for (local, decl) in body.local_decls().iter_enumerated() {
if decl.ty.is_union() || decl.ty.is_enum() || excluded[local] {
if decl.ty.is_union() || decl.ty.is_enum() || excluded.contains(local) {
set.insert(local);
}
}
Expand Down Expand Up @@ -172,7 +174,7 @@ fn replace_flattened_locals<'tcx>(
body: &mut Body<'tcx>,
replacements: ReplacementMap<'tcx>,
) -> BitSet<Local> {
let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
let mut all_dead_locals = BitSet::new_empty(replacements.fragments.len());
for (local, replacements) in replacements.fragments.iter_enumerated() {
if replacements.is_some() {
all_dead_locals.insert(local);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
- // MIR for `mutate_discriminant` before DataflowConstProp
+ // MIR for `mutate_discriminant` after DataflowConstProp

fn mutate_discriminant() -> u8 {
let mut _0: u8; // return place in scope 0 at $DIR/enum.rs:+0:29: +0:31
let mut _1: std::option::Option<NonZeroUsize>; // in scope 0 at $SRC_DIR/core/src/intrinsics/mir.rs:LL:COL
let mut _2: isize; // in scope 0 at $SRC_DIR/core/src/intrinsics/mir.rs:LL:COL

bb0: {
discriminant(_1) = 1; // scope 0 at $DIR/enum.rs:+4:13: +4:34
(((_1 as variant#1).0: NonZeroUsize).0: usize) = const 0_usize; // scope 0 at $DIR/enum.rs:+6:13: +6:64
_2 = discriminant(_1); // scope 0 at $SRC_DIR/core/src/intrinsics/mir.rs:LL:COL
switchInt(_2) -> [0: bb1, otherwise: bb2]; // scope 0 at $DIR/enum.rs:+9:13: +12:14
}

bb1: {
_0 = const 1_u8; // scope 0 at $DIR/enum.rs:+15:13: +15:20
return; // scope 0 at $DIR/enum.rs:+16:13: +16:21
}

bb2: {
_0 = const 2_u8; // scope 0 at $DIR/enum.rs:+19:13: +19:20
unreachable; // scope 0 at $DIR/enum.rs:+20:13: +20:26
}
}

45 changes: 42 additions & 3 deletions tests/mir-opt/dataflow-const-prop/enum.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,52 @@
// unit-test: DataflowConstProp

// Not trackable, because variants could be aliased.
#![feature(custom_mir, core_intrinsics, rustc_attrs)]

use std::intrinsics::mir::*;

enum E {
V1(i32),
V2(i32)
}

// EMIT_MIR enum.main.DataflowConstProp.diff
fn main() {
// EMIT_MIR enum.simple.DataflowConstProp.diff
fn simple() {
let e = E::V1(0);
let x = match e { E::V1(x) => x, E::V2(x) => x };
}

#[rustc_layout_scalar_valid_range_start(1)]
#[rustc_nonnull_optimization_guaranteed]
struct NonZeroUsize(usize);

// EMIT_MIR enum.mutate_discriminant.DataflowConstProp.diff
#[custom_mir(dialect = "runtime", phase = "post-cleanup")]
fn mutate_discriminant() -> u8 {
mir!(
let x: Option<NonZeroUsize>;
{
SetDiscriminant(x, 1);
// This assignment overwrites the niche in which the discriminant is stored.
place!(Field(Field(Variant(x, 1), 0), 0)) = 0_usize;
// So we cannot know the value of this discriminant.
let a = Discriminant(x);
match a {
0 => bb1,
_ => bad,
}
}
bb1 = {
RET = 1;
Return()
}
bad = {
RET = 2;
Unreachable()
}
)
}

fn main() {
simple();
mutate_discriminant();
}
Loading