Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More powerful const panic #90488

Closed
wants to merge 8 commits into from
Closed
18 changes: 14 additions & 4 deletions compiler/rustc_builtin_macros/src/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -889,11 +889,21 @@ impl<'a, 'b> Context<'a, 'b> {
return ecx.expr_call_global(macsp, path, vec![arg]);
}
};
let new_fn_name = match trait_ {
"Display" => "new_display",
"Debug" => "new_debug",
"LowerExp" => "new_lower_exp",
"UpperExp" => "new_upper_exp",
"Octal" => "new_octal",
"Pointer" => "new_pointer",
"Binary" => "new_binary",
"LowerHex" => "new_lower_hex",
"UpperHex" => "new_upper_hex",
_ => unreachable!(),
};

let path = ecx.std_path(&[sym::fmt, Symbol::intern(trait_), sym::fmt]);
let format_fn = ecx.path_global(sp, path);
let path = ecx.std_path(&[sym::fmt, sym::ArgumentV1, sym::new]);
ecx.expr_call_global(macsp, path, vec![arg, ecx.expr_path(format_fn)])
let path = ecx.std_path(&[sym::fmt, sym::ArgumentV1, Symbol::intern(new_fn_name)]);
ecx.expr_call_global(sp, path, vec![arg])
}
}

Expand Down
34 changes: 10 additions & 24 deletions compiler/rustc_const_eval/src/const_eval/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,36 +48,22 @@ impl<'mir, 'tcx> InterpCx<'mir, 'tcx, CompileTimeInterpreter<'mir, 'tcx>> {
.unwrap(),
));
}
} else if Some(def_id) == self.tcx.lang_items().panic_display()
|| Some(def_id) == self.tcx.lang_items().begin_panic_fn()
{
// &str or &&str
} else if Some(def_id) == self.tcx.lang_items().begin_panic_fn() {
assert!(args.len() == 1);

let mut msg_place = self.deref_operand(&args[0])?;
while msg_place.layout.ty.is_ref() {
msg_place = self.deref_operand(&msg_place.into())?;
}

let msg = Symbol::intern(self.read_str(&msg_place)?);
let msg = self.eval_const_panic_any(args[0])?;
let msg = Symbol::intern(&msg);
let span = self.find_closest_untracked_caller_location();
let (file, line, col) = self.location_triple_for_span(span);
return Err(ConstEvalErrKind::Panic { msg, file, line, col }.into());
} else if Some(def_id) == self.tcx.lang_items().panic_fmt() {
// For panic_fmt, call const_panic_fmt instead.
if let Some(const_panic_fmt) = self.tcx.lang_items().const_panic_fmt() {
return Ok(Some(
ty::Instance::resolve(
*self.tcx,
ty::ParamEnv::reveal_all(),
const_panic_fmt,
self.tcx.intern_substs(&[]),
)
.unwrap()
.unwrap(),
));
}
assert!(args.len() == 1);
let msg = self.eval_const_panic_fmt(args[0])?;
let msg = Symbol::intern(&msg);
let span = self.find_closest_untracked_caller_location();
let (file, line, col) = self.location_triple_for_span(span);
return Err(ConstEvalErrKind::Panic { msg, file, line, col }.into());
}

Ok(None)
}
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_const_eval/src/const_eval/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod error;
mod eval_queries;
mod fn_queries;
mod machine;
mod panic;

pub use error::*;
pub use eval_queries::*;
Expand Down
262 changes: 262 additions & 0 deletions compiler/rustc_const_eval/src/const_eval/panic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
use rustc_hir::def_id::DefId;
use rustc_middle::ty::{self, layout::LayoutOf, subst::Subst};
use rustc_span::sym;
use std::cell::Cell;
use std::fmt::{self, Debug, Formatter};

use crate::interpret::{FnVal, InterpCx, InterpErrorInfo, InterpResult, OpTy};

use super::CompileTimeInterpreter;

/// A single value awaiting compile-time formatting, bundled with everything `fmt_arg`
/// needs to render it.
///
/// The `Debug` impl of this type forwards to `InterpCx::fmt_arg`, so an `Arg` can be handed
/// to the host formatting machinery as an ordinary `Debug` value.
struct Arg<'mir, 'tcx, 'err> {
/// Interpreter context used to read the value.
cx: &'err InterpCx<'mir, 'tcx, CompileTimeInterpreter<'mir, 'tcx>>,
/// Operand holding the value to format.
arg: OpTy<'tcx>,
/// `DefId` of the formatting trait; `fmt_arg` dispatches on this trait's item name.
fmt_trait: DefId,
/// Side channel for an interpreter error raised while formatting, since `fmt::Error`
/// itself cannot carry any payload.
err: &'err Cell<Option<InterpErrorInfo<'tcx>>>,
}

/// `Debug` is used purely as a carrier here: the actual formatting trait is the one stored
/// in `fmt_trait`, and the rendering is done by the interpreter via `fmt_arg`.
impl Debug for Arg<'_, '_, '_> {
    /// Renders the wrapped operand into `f`.
    ///
    /// On an interpreter failure the error is stored in the shared `err` cell and a bare
    /// `fmt::Error` is reported to the formatter.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        self.cx.fmt_arg(self.arg, self.fmt_trait, f).map_err(|interp_err| {
            // `fmt::Error` has no payload, so stash the real error for the caller to pick up.
            self.err.set(Some(interp_err));
            fmt::Error
        })
    }
}

impl<'mir, 'tcx> InterpCx<'mir, 'tcx, CompileTimeInterpreter<'mir, 'tcx>> {
fn fmt_arg(
&self,
arg: OpTy<'tcx>,
fmt_trait: DefId,
f: &mut Formatter<'_>,
) -> InterpResult<'tcx> {
let fmt_trait_sym = self.tcx.item_name(fmt_trait);
let fmt_trait_name = fmt_trait_sym.as_str();

macro_rules! dispatch_fmt {
($e: expr, $($t: ident)|*) => {
let _ = match &*fmt_trait_name {
$(stringify!($t) => fmt::$t::fmt($e, f),)*
_ => Debug::fmt($e, f),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment explaining what this macro does? I cannot make sense of it.

Also please note explicitly that we are dispatching to the standard library formatting machinery here, so we are guaranteed that this stays in sync with the logic in the standard library -- is that correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This just generates match arms like "Display" => fmt::Display::fmt(e, f),.

Yes, it's correct that this would mean the formatting result to be in sync with std. Though I don't think we want to guarantee that to the user.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems unfortunate to do string comparisons on the trait name. I wonder if it would be better to make Display etc. lang items instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They certainly shouldn't be lang items but could be made into diagnostic items. With that said, format_args! expansion in rustc_builtin_macros is using strings to name these as well, so I don't think it's a big deal.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's correct that this would mean the formatting result to be in sync with std. Though I don't think we want to guarantee that to the user.

I think we have to make this guarantee. All this machinery should be entirely invisible, it just works around limitations in the main CTFE machinery.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, instead of having a Debug fallback case, shouldn't this bug! in case of an unknown trait name, to ensure we do dispatch to the right code?

};
}
}

match arg.layout.ty.kind() {
ty::Bool => {
let v = self.read_scalar(&arg)?.to_bool()?;
dispatch_fmt!(&v, Display);
}
ty::Char => {
let v = self.read_scalar(&arg)?.to_char()?;
dispatch_fmt!(&v, Display);
}
ty::Int(int_ty) => {
let v = self.read_scalar(&arg)?.check_init()?;
let v = match int_ty {
ty::IntTy::I8 => v.to_i8()?.into(),
ty::IntTy::I16 => v.to_i16()?.into(),
ty::IntTy::I32 => v.to_i32()?.into(),
ty::IntTy::I64 => v.to_i64()?.into(),
ty::IntTy::I128 => v.to_i128()?,
ty::IntTy::Isize => v.to_machine_isize(self)?.into(),
};
dispatch_fmt!(
&v,
Display | Binary | Octal | LowerHex | UpperHex | LowerExp | UpperExp
);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes some assumptions about how integers are formatted, namely that using a larger integer type will not make a difference. This assumption should at least be explicitly documented in a comment here.

Also, do these LowerHex etc things here mean that hex output is supported? I don't see the tests covering that possibility.

}
ty::Uint(int_ty) => {
let v = self.read_scalar(&arg)?.check_init()?;
let v = match int_ty {
ty::UintTy::U8 => v.to_u8()?.into(),
ty::UintTy::U16 => v.to_u16()?.into(),
ty::UintTy::U32 => v.to_u32()?.into(),
ty::UintTy::U64 => v.to_u64()?.into(),
ty::UintTy::U128 => v.to_u128()?,
ty::UintTy::Usize => v.to_machine_usize(self)?.into(),
};
dispatch_fmt!(
&v,
Display | Binary | Octal | LowerHex | UpperHex | LowerExp | UpperExp
);
}
ty::Float(ty::FloatTy::F32) => {
let v = f32::from_bits(self.read_scalar(&arg)?.to_u32()?);
dispatch_fmt!(&v, Display);
}
ty::Float(ty::FloatTy::F64) => {
let v = f64::from_bits(self.read_scalar(&arg)?.to_u64()?);
dispatch_fmt!(&v, Display);
}
ty::Str => {
let Ok(place) = arg.try_as_mplace() else {
bug!("str is not in MemPlace");
};
let v = self.read_str(&place)?;
dispatch_fmt!(v, Display);
}
ty::Array(..) | ty::Slice(..) => {
let Ok(place) = arg.try_as_mplace() else {
bug!("array/slice is not in MemPlace");
};
let err = Cell::new(None);
let mut debug_list = f.debug_list();
for field in self.mplace_array_fields(&place)? {
debug_list.entry(&Arg { cx: self, arg: field?.into(), fmt_trait, err: &err });
}
let _ = debug_list.finish();
if let Some(e) = err.into_inner() {
return Err(e);
}
}
ty::RawPtr(..) | ty::FnPtr(..) => {
// This isn't precisely how Pointer is implemented, but it's best we can do.
let ptr = self.read_pointer(&arg)?;
let _ = write!(f, "{:?}", ptr);
}
ty::Tuple(substs) => {
let err = Cell::new(None);
let mut debug_tuple = f.debug_tuple("");
for i in 0..substs.len() {
debug_tuple.field(&Arg {
cx: self,
arg: self.operand_field(&arg, i)?,
fmt_trait,
err: &err,
});
}
let _ = debug_tuple.finish();
if let Some(e) = err.into_inner() {
return Err(e);
}
}

// FIXME(nbdd0121): extend to allow fmt trait as super trait
ty::Dynamic(list, _) if list.principal_def_id() == Some(fmt_trait) => {
let Ok(place) = arg.try_as_mplace() else {
bug!("dyn is not in MemPlace");
};
let place = self.unpack_dyn_trait(&place)?.1;
return self.fmt_arg(place.into(), fmt_trait, f);
}

ty::Ref(..) if fmt_trait_name == "Pointer" => {
let ptr = self.read_pointer(&arg)?;
let _ = write!(f, "{:?}", ptr);
}
ty::Ref(..) => {
// FIXME(nbdd0121): User can implement trait on &UserType, so this isn't always correct.
let place = self.deref_operand(&arg)?;
return self.fmt_arg(place.into(), fmt_trait, f);
}

ty::Adt(adt, _) if self.tcx.is_diagnostic_item(sym::Arguments, adt.did) => {
return self.fmt_arguments(arg, f);
}
ty::Adt(adt, _) if self.tcx.is_diagnostic_item(sym::String, adt.did) => {
// NOTE(nbdd0121): const `String` can only be empty.
dispatch_fmt!("", Display);
}

// FIXME(nbdd0121): ty::Adt(..) => (),
_ => {
let _ = write!(f, "<failed to format {}>", arg.layout.ty);
}
}
Ok(())
}

/// Formats a `fmt::Arguments` value held in interpreter memory into `f`.
///
/// Only the form whose field 1 has variant index 0 is evaluated; for any other variant a
/// placeholder message is written instead. Each argument is wrapped in an `Arg` and
/// rendered through the host `fmt::Arguments` machinery. Errors from writing to `f` are
/// ignored, but an interpreter error recorded while formatting an argument is returned.
fn fmt_arguments(&self, arguments: OpTy<'tcx>, f: &mut Formatter<'_>) -> InterpResult<'tcx> {
// Check we are dealing with the simple form
// NOTE(review): field 1 is presumably `fmt: Option<..>`, so variant 0 would be `None` --
// confirm against the `Arguments` definition in `core::fmt`.
let fmt_variant_idx = self.read_discriminant(&self.operand_field(&arguments, 1)?)?.1;
if fmt_variant_idx.as_usize() != 0 {
// FIXME(nbdd0121): implement complex format
let _ = write!(f, "<cannot evaluate complex format>");
return Ok(());
}

// `pieces: &[&str]`
// Read every literal piece out of interpreter memory up front.
let pieces_place = self.deref_operand(&self.operand_field(&arguments, 0)?)?;
let mut pieces = Vec::new();
for piece in self.mplace_array_fields(&pieces_place)? {
let piece: OpTy<'tcx> = piece?.into();
pieces.push(self.read_str(&self.deref_operand(&piece)?)?);
}

// `args: &[ArgumentV1]`
let args_place = self.deref_operand(&self.operand_field(&arguments, 2)?)?;
let mut args = Vec::new();
// Shared cell in which any `Arg` records an interpreter error hit while formatting.
let err = Cell::new(None);
for arg in self.mplace_array_fields(&args_place)? {
let arg: OpTy<'tcx> = arg?.into();

// Field 1 of each argument is read as a function pointer and resolved to its instance.
let fmt_fn = self.memory.get_fn(self.read_pointer(&self.operand_field(&arg, 1)?)?)?;
let fmt_fn = match fmt_fn {
FnVal::Instance(instance) => instance,
// The empty match shows that the `Other` payload is uninhabited here.
FnVal::Other(o) => match o {},
};

// The formatter must be an instance of the `fmt` method of a fmt trait.
let Some(fmt_impl) = self.tcx.impl_of_method(fmt_fn.def_id()) else {
throw_unsup_format!("fmt function is not from trait impl")
};
let Some(fmt_trait) = self.tcx.impl_trait_ref(fmt_impl) else {
throw_unsup_format!("fmt function is not from trait impl")
};

// Retrieve the trait ref with concrete self ty.
let fmt_trait = fmt_trait.subst(*self.tcx, &fmt_fn.substs);

// Change the opaque type into the actual type.
// Field 0 is the value reference; its layout is replaced by that of the impl's self type.
let mut value_place = self.deref_operand(&self.operand_field(&arg, 0)?)?;
value_place.layout = self.layout_of(fmt_trait.self_ty())?;

args.push(Arg {
cx: self,
arg: value_place.into(),
fmt_trait: fmt_trait.def_id,
err: &err,
});
}

// SAFETY: This transmutes `&[&str]` to `&[&'static str]` so it can be used in
// `core::fmt::Arguments`. The slice will not be used after `write_fmt`.
let static_pieces = unsafe { core::mem::transmute(&pieces[..]) };
// Every wrapped argument is registered with `Debug::fmt`; `Arg`'s `Debug` impl then
// dispatches on the real formatting trait via `fmt_arg`.
let arg_v1s = args.iter().map(|x| fmt::ArgumentV1::new(x, Debug::fmt)).collect::<Vec<_>>();
let fmt_args = fmt::Arguments::new_v1(static_pieces, &arg_v1s);
// The `fmt::Result` is discarded on purpose: the only failure reported is the interpreter
// error captured in `err`.
let _ = f.write_fmt(fmt_args);
if let Some(v) = err.into_inner() {
return Err(v);
}
Ok(())
}

/// Renders the `fmt::Arguments` operand of a const panic into an owned message string.
///
/// Any interpreter error raised by `fmt_arguments` is propagated to the caller.
pub(super) fn eval_const_panic_fmt(
    &mut self,
    arguments: OpTy<'tcx>,
) -> InterpResult<'tcx, String> {
    let mut rendered = String::new();
    // The temporary formatter writes straight into `rendered` and is gone once this
    // statement ends, so the buffer can be returned afterwards.
    self.fmt_arguments(arguments, &mut Formatter::new(&mut rendered))?;
    Ok(rendered)
}

/// Produces the message for a const panic whose payload is an arbitrary value.
///
/// * `&str` payloads yield the string they point to.
/// * `String` payloads yield the empty string.
/// * Every other payload yields the fixed text `Box<dyn Any>`.
pub(super) fn eval_const_panic_any(&mut self, arg: OpTy<'tcx>) -> InterpResult<'tcx, String> {
    let payload_kind = arg.layout.ty.kind();

    if let ty::Ref(_, pointee, _) = payload_kind {
        if pointee.is_str() {
            let str_place = self.deref_operand(&arg)?;
            return Ok(self.read_str(&str_place)?.to_string());
        }
    }

    if let ty::Adt(def, _) = payload_kind {
        if self.tcx.is_diagnostic_item(sym::String, def.did) {
            // NOTE(nbdd0121): const `String` can only be empty.
            return Ok(String::new());
        }
    }

    Ok("Box<dyn Any>".to_string())
}
}
4 changes: 2 additions & 2 deletions compiler/rustc_const_eval/src/interpret/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ where

// Iterates over all fields of an array. Much more efficient than doing the
// same by repeatedly calling `mplace_array`.
pub(super) fn mplace_array_fields(
pub(crate) fn mplace_array_fields(
&self,
base: &'a MPlaceTy<'tcx, Tag>,
) -> InterpResult<'tcx, impl Iterator<Item = InterpResult<'tcx, MPlaceTy<'tcx, Tag>>> + 'a>
Expand Down Expand Up @@ -1082,7 +1082,7 @@ where

/// Turn a place with a `dyn Trait` type into a place with the actual dynamic type.
/// Also return some more information so drop doesn't have to run the same code twice.
pub(super) fn unpack_dyn_trait(
pub(crate) fn unpack_dyn_trait(
&self,
mplace: &MPlaceTy<'tcx, M::PointerTag>,
) -> InterpResult<'tcx, (ty::Instance<'tcx>, MPlaceTy<'tcx, M::PointerTag>)> {
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_const_eval/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Rust MIR: a lowered representation of Rust.
#![feature(trusted_len)]
#![feature(trusted_step)]
#![feature(try_blocks)]
#![feature(fmt_internals)]
#![recursion_limit = "256"]

#[macro_use]
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_const_eval/src/transform/check_consts/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ impl Visitor<'tcx> for Checker<'mir, 'tcx> {
if Some(callee) == tcx.lang_items().begin_panic_fn() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment on the previous line seems outdated now? (Same below.)

match args[0].ty(&self.ccx.body.local_decls, tcx).kind() {
ty::Ref(_, ty, _) if ty.is_str() => return,
_ => self.check_op(ops::PanicNonStr),
_ => (),
}
}

Expand All @@ -904,7 +904,7 @@ impl Visitor<'tcx> for Checker<'mir, 'tcx> {
{
return;
}
_ => self.check_op(ops::PanicNonStr),
_ => (),
}
}

Expand Down
Loading