Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store the types of ty::Expr arguments in the ty::Expr #125968

Merged
merged 2 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compiler/rustc_middle/src/ty/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub type ConstKind<'tcx> = ir::ConstKind<TyCtxt<'tcx>>;
pub type UnevaluatedConst<'tcx> = ir::UnevaluatedConst<TyCtxt<'tcx>>;

#[cfg(target_pointer_width = "64")]
rustc_data_structures::static_assert_size!(ConstKind<'_>, 32);
rustc_data_structures::static_assert_size!(ConstKind<'_>, 24);

/// Use this rather than `ConstData`, whenever possible.
#[derive(Copy, Clone, PartialEq, Eq, Hash, HashStable)]
Expand Down Expand Up @@ -58,7 +58,7 @@ pub struct ConstData<'tcx> {
}

#[cfg(target_pointer_width = "64")]
rustc_data_structures::static_assert_size!(ConstData<'_>, 40);
rustc_data_structures::static_assert_size!(ConstData<'_>, 32);

impl<'tcx> Const<'tcx> {
#[inline]
Expand Down
125 changes: 118 additions & 7 deletions compiler/rustc_middle/src/ty/consts/kind.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::Const;
use crate::mir;
use crate::ty::abstract_const::CastKind;
use crate::ty::{self, visit::TypeVisitableExt as _, List, Ty, TyCtxt};
use crate::ty::{self, visit::TypeVisitableExt as _, Ty, TyCtxt};
use rustc_macros::{extension, HashStable, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable};

#[extension(pub(crate) trait UnevaluatedConstEvalExt<'tcx>)]
Expand Down Expand Up @@ -40,14 +40,125 @@ impl<'tcx> ty::UnevaluatedConst<'tcx> {
}
}

#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
#[derive(HashStable, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum ExprKind {
Binop(mir::BinOp),
UnOp(mir::UnOp),
FunctionCall,
Cast(CastKind),
}
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
#[derive(HashStable, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum Expr<'tcx> {
Binop(mir::BinOp, Const<'tcx>, Const<'tcx>),
UnOp(mir::UnOp, Const<'tcx>),
FunctionCall(Const<'tcx>, &'tcx List<Const<'tcx>>),
Cast(CastKind, Const<'tcx>, Ty<'tcx>),
pub struct Expr<'tcx> {
pub kind: ExprKind,
args: ty::GenericArgsRef<'tcx>,
}
impl<'tcx> Expr<'tcx> {
pub fn new_binop(
tcx: TyCtxt<'tcx>,
binop: mir::BinOp,
lhs_ty: Ty<'tcx>,
rhs_ty: Ty<'tcx>,
lhs_ct: Const<'tcx>,
rhs_ct: Const<'tcx>,
) -> Self {
let args = tcx.mk_args_from_iter::<_, ty::GenericArg<'tcx>>(
[lhs_ty.into(), rhs_ty.into(), lhs_ct.into(), rhs_ct.into()].into_iter(),
);

Self { kind: ExprKind::Binop(binop), args }
}

pub fn binop_args(self) -> (Ty<'tcx>, Ty<'tcx>, Const<'tcx>, Const<'tcx>) {
assert!(matches!(self.kind, ExprKind::Binop(_)));

match self.args().as_slice() {
[lhs_ty, rhs_ty, lhs_ct, rhs_ct] => (
lhs_ty.expect_ty(),
rhs_ty.expect_ty(),
lhs_ct.expect_const(),
rhs_ct.expect_const(),
),
_ => bug!("Invalid args for `Binop` expr {self:?}"),
}
}

pub fn new_unop(tcx: TyCtxt<'tcx>, unop: mir::UnOp, ty: Ty<'tcx>, ct: Const<'tcx>) -> Self {
let args =
tcx.mk_args_from_iter::<_, ty::GenericArg<'tcx>>([ty.into(), ct.into()].into_iter());

Self { kind: ExprKind::UnOp(unop), args }
}

pub fn unop_args(self) -> (Ty<'tcx>, Const<'tcx>) {
assert!(matches!(self.kind, ExprKind::UnOp(_)));

match self.args().as_slice() {
[ty, ct] => (ty.expect_ty(), ct.expect_const()),
_ => bug!("Invalid args for `UnOp` expr {self:?}"),
}
}

pub fn new_call(
tcx: TyCtxt<'tcx>,
func_ty: Ty<'tcx>,
func_expr: Const<'tcx>,
arguments: impl Iterator<Item = Const<'tcx>>,
) -> Self {
let args = tcx.mk_args_from_iter::<_, ty::GenericArg<'tcx>>(
[func_ty.into(), func_expr.into()].into_iter().chain(arguments.map(|ct| ct.into())),
);

Self { kind: ExprKind::FunctionCall, args }
}

pub fn call_args(self) -> (Ty<'tcx>, Const<'tcx>, impl Iterator<Item = Const<'tcx>>) {
assert!(matches!(self.kind, ExprKind::FunctionCall));

match self.args().as_slice() {
[func_ty, func, rest @ ..] => (
func_ty.expect_ty(),
func.expect_const(),
rest.iter().map(|arg| arg.expect_const()),
),
_ => bug!("Invalid args for `Call` expr {self:?}"),
}
}

pub fn new_cast(
tcx: TyCtxt<'tcx>,
cast: CastKind,
value_ty: Ty<'tcx>,
value: Const<'tcx>,
to_ty: Ty<'tcx>,
) -> Self {
let args = tcx.mk_args_from_iter::<_, ty::GenericArg<'tcx>>(
[value_ty.into(), value.into(), to_ty.into()].into_iter(),
);

Self { kind: ExprKind::Cast(cast), args }
}

pub fn cast_args(self) -> (Ty<'tcx>, Const<'tcx>, Ty<'tcx>) {
assert!(matches!(self.kind, ExprKind::Cast(_)));

match self.args().as_slice() {
[value_ty, value, to_ty] => {
(value_ty.expect_ty(), value.expect_const(), to_ty.expect_ty())
}
_ => bug!("Invalid args for `Cast` expr {self:?}"),
}
}

pub fn new(kind: ExprKind, args: ty::GenericArgsRef<'tcx>) -> Self {
Self { kind, args }
}

pub fn args(&self) -> ty::GenericArgsRef<'tcx> {
self.args
}
}

#[cfg(target_pointer_width = "64")]
rustc_data_structures::static_assert_size!(Expr<'_>, 24);
rustc_data_structures::static_assert_size!(Expr<'_>, 16);
21 changes: 1 addition & 20 deletions compiler/rustc_middle/src/ty/flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,26 +374,7 @@ impl FlagComputation {
self.add_flags(TypeFlags::STILL_FURTHER_SPECIALIZABLE);
}
ty::ConstKind::Value(_) => {}
ty::ConstKind::Expr(e) => {
use ty::Expr;
match e {
Expr::Binop(_, l, r) => {
self.add_const(l);
self.add_const(r);
}
Expr::UnOp(_, v) => self.add_const(v),
Expr::FunctionCall(f, args) => {
self.add_const(f);
for arg in args {
self.add_const(arg);
}
}
Expr::Cast(_, c, t) => {
self.add_ty(t);
self.add_const(c);
}
}
}
ty::ConstKind::Expr(e) => self.add_args(e.args()),
ty::ConstKind::Error(_) => self.add_flags(TypeFlags::HAS_ERROR),
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ pub use self::closure::{
CAPTURE_STRUCT_LOCAL,
};
pub use self::consts::{
Const, ConstData, ConstInt, ConstKind, Expr, ScalarInt, UnevaluatedConst, ValTree,
Const, ConstData, ConstInt, ConstKind, Expr, ExprKind, ScalarInt, UnevaluatedConst, ValTree,
};
pub use self::context::{
tls, CtxtInterners, CurrentGcx, DeducedParamAttrs, Feed, FreeRegionInfo, GlobalCtxt, Lift,
Expand Down
102 changes: 44 additions & 58 deletions compiler/rustc_middle/src/ty/print/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1533,8 +1533,10 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
print_ty: bool,
) -> Result<(), PrintError> {
define_scoped_cx!(self);
match expr {
Expr::Binop(op, c1, c2) => {
match expr.kind {
ty::ExprKind::Binop(op) => {
let (_, _, c1, c2) = expr.binop_args();

let precedence = |binop: rustc_middle::mir::BinOp| {
use rustc_ast::util::parser::AssocOp;
AssocOp::from_ast_binop(binop.to_hir_binop().into()).precedence()
Expand All @@ -1543,22 +1545,26 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
let formatted_op = op.to_hir_binop().as_str();
let (lhs_parenthesized, rhs_parenthesized) = match (c1.kind(), c2.kind()) {
(
ty::ConstKind::Expr(Expr::Binop(lhs_op, _, _)),
ty::ConstKind::Expr(Expr::Binop(rhs_op, _, _)),
ty::ConstKind::Expr(ty::Expr { kind: ty::ExprKind::Binop(lhs_op), .. }),
ty::ConstKind::Expr(ty::Expr { kind: ty::ExprKind::Binop(rhs_op), .. }),
) => (precedence(lhs_op) < op_precedence, precedence(rhs_op) < op_precedence),
(ty::ConstKind::Expr(Expr::Binop(lhs_op, ..)), ty::ConstKind::Expr(_)) => {
(precedence(lhs_op) < op_precedence, true)
}
(ty::ConstKind::Expr(_), ty::ConstKind::Expr(Expr::Binop(rhs_op, ..))) => {
(true, precedence(rhs_op) < op_precedence)
}
(
ty::ConstKind::Expr(ty::Expr { kind: ty::ExprKind::Binop(lhs_op), .. }),
ty::ConstKind::Expr(_),
) => (precedence(lhs_op) < op_precedence, true),
(
ty::ConstKind::Expr(_),
ty::ConstKind::Expr(ty::Expr { kind: ty::ExprKind::Binop(rhs_op), .. }),
) => (true, precedence(rhs_op) < op_precedence),
(ty::ConstKind::Expr(_), ty::ConstKind::Expr(_)) => (true, true),
(ty::ConstKind::Expr(Expr::Binop(lhs_op, ..)), _) => {
(precedence(lhs_op) < op_precedence, false)
}
(_, ty::ConstKind::Expr(Expr::Binop(rhs_op, ..))) => {
(false, precedence(rhs_op) < op_precedence)
}
(
ty::ConstKind::Expr(ty::Expr { kind: ty::ExprKind::Binop(lhs_op), .. }),
_,
) => (precedence(lhs_op) < op_precedence, false),
(
_,
ty::ConstKind::Expr(ty::Expr { kind: ty::ExprKind::Binop(rhs_op), .. }),
) => (false, precedence(rhs_op) < op_precedence),
(ty::ConstKind::Expr(_), _) => (true, false),
(_, ty::ConstKind::Expr(_)) => (false, true),
_ => (false, false),
Expand All @@ -1574,7 +1580,9 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
rhs_parenthesized,
)?;
}
Expr::UnOp(op, ct) => {
ty::ExprKind::UnOp(op) => {
let (_, ct) = expr.unop_args();

use rustc_middle::mir::UnOp;
let formatted_op = match op {
UnOp::Not => "!",
Expand All @@ -1583,7 +1591,9 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
};
let parenthesized = match ct.kind() {
_ if op == UnOp::PtrMetadata => true,
ty::ConstKind::Expr(Expr::UnOp(c_op, ..)) => c_op != op,
ty::ConstKind::Expr(ty::Expr { kind: ty::ExprKind::UnOp(c_op), .. }) => {
c_op != op
}
ty::ConstKind::Expr(_) => true,
_ => false,
};
Expand All @@ -1593,61 +1603,37 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
parenthesized,
)?
}
Expr::FunctionCall(fn_def, fn_args) => {
use ty::TyKind;
match fn_def.ty().kind() {
TyKind::FnDef(def_id, gen_args) => {
p!(print_value_path(*def_id, gen_args), "(");
if print_ty {
let tcx = self.tcx();
let sig = tcx.fn_sig(def_id).instantiate(tcx, gen_args).skip_binder();

let mut args_with_ty = fn_args.iter().map(|ct| (ct, ct.ty()));
let output_ty = sig.output();

if let Some((ct, ty)) = args_with_ty.next() {
self.typed_value(
|this| this.pretty_print_const(ct, print_ty),
|this| this.pretty_print_type(ty),
": ",
)?;
for (ct, ty) in args_with_ty {
p!(", ");
self.typed_value(
|this| this.pretty_print_const(ct, print_ty),
|this| this.pretty_print_type(ty),
": ",
)?;
}
}
p!(write(") -> {output_ty}"));
} else {
p!(comma_sep(fn_args.iter()), ")");
}
}
_ => bug!("unexpected type of fn def"),
}
ty::ExprKind::FunctionCall => {
let (_, fn_def, fn_args) = expr.call_args();

write!(self, "(")?;
self.pretty_print_const(fn_def, print_ty)?;
p!(")(", comma_sep(fn_args), ")");
}
Expr::Cast(kind, ct, ty) => {
ty::ExprKind::Cast(kind) => {
let (_, value, to_ty) = expr.cast_args();

use ty::abstract_const::CastKind;
if kind == CastKind::As || (kind == CastKind::Use && self.should_print_verbose()) {
let parenthesized = match ct.kind() {
ty::ConstKind::Expr(Expr::Cast(_, _, _)) => false,
let parenthesized = match value.kind() {
ty::ConstKind::Expr(ty::Expr {
kind: ty::ExprKind::Cast { .. }, ..
}) => false,
ty::ConstKind::Expr(_) => true,
_ => false,
};
self.maybe_parenthesized(
|this| {
this.typed_value(
|this| this.pretty_print_const(ct, print_ty),
|this| this.pretty_print_type(ty),
|this| this.pretty_print_const(value, print_ty),
|this| this.pretty_print_type(to_ty),
" as ",
)
},
parenthesized,
)?;
} else {
self.pretty_print_const(ct, print_ty)?
self.pretty_print_const(value, print_ty)?
}
}
}
Expand Down
Loading
Loading