Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement trait dispatch in the comptime interpreter #5376

Merged
merged 7 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ impl<'context> Elaborator<'context> {
trait_id: trait_id.trait_id,
trait_generics: Vec::new(),
};
self.trait_constraints.push((constraint, expr_id));
self.push_trait_constraint(constraint, expr_id);
self.type_check_operator_method(expr_id, trait_id, &lhs_type, span);
}
typ
Expand Down Expand Up @@ -663,7 +663,14 @@ impl<'context> Elaborator<'context> {
}

fn elaborate_comptime_block(&mut self, block: BlockExpression, span: Span) -> (ExprId, Type) {
// We have to push a new FunctionContext so that we can resolve any constraints
// in this comptime block early before the function as a whole finishes elaborating.
// Otherwise the interpreter below may find expressions for which the underlying trait
// call is not yet solved for.
self.function_context.push(Default::default());
let (block, _typ) = self.elaborate_block_expression(block);
self.check_and_pop_function_context();

let mut interpreter =
Interpreter::new(self.interner, &mut self.comptime_scopes, self.crate_id);
let value = interpreter.evaluate_block(block);
Expand Down
101 changes: 58 additions & 43 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,14 @@ pub struct Elaborator<'context> {

current_function: Option<FuncId>,

/// All type variables created in the current function.
/// This map is used to default any integer type variables at the end of
/// a function (before checking trait constraints) if a type wasn't already chosen.
type_variables: Vec<Type>,

/// Trait constraints are collected during type checking until they are
/// verified at the end of a function. This is because constraints arise
/// on each variable, but it is only until function calls when the types
/// needed for the trait constraint may become known.
trait_constraints: Vec<(TraitConstraint, ExprId)>,
/// This is a stack of function contexts. Most of the time, for each function we
/// expect this to be of length one, containing each type variable and trait constraint
/// used in the function. This is also pushed to when a `comptime {}` block is used within
/// the function. Since it can force us to resolve that block's trait constraints earlier
/// so that they are resolved when the interpreter is run before the enclosing function
/// is finished elaborating. When this happens, we need to resolve any type variables
/// that were made within this block as well so that we can solve these traits.
function_context: Vec<FunctionContext>,

/// The current module this elaborator is in.
/// Initially empty, it is set whenever a new top-level item is resolved.
Expand All @@ -166,6 +164,20 @@ pub struct Elaborator<'context> {
unresolved_globals: BTreeMap<GlobalId, UnresolvedGlobal>,
}

#[derive(Default)]
struct FunctionContext {
/// All type variables created in the current function.
/// This map is used to default any integer type variables at the end of
/// a function (before checking trait constraints) if a type wasn't already chosen.
type_variables: Vec<Type>,

/// Trait constraints are collected during type checking until they are
/// verified at the end of a function. This is because constraints arise
/// on each variable, but it is only until function calls when the types
/// needed for the trait constraint may become known.
trait_constraints: Vec<(TraitConstraint, ExprId)>,
}

impl<'context> Elaborator<'context> {
pub fn new(context: &'context mut Context, crate_id: CrateId) -> Self {
Self {
Expand All @@ -185,8 +197,7 @@ impl<'context> Elaborator<'context> {
resolving_ids: BTreeSet::new(),
trait_bounds: Vec::new(),
current_function: None,
type_variables: Vec::new(),
trait_constraints: Vec::new(),
function_context: vec![FunctionContext::default()],
current_trait_impl: None,
comptime_scopes: vec![HashMap::default()],
unresolved_globals: BTreeMap::new(),
Expand Down Expand Up @@ -326,6 +337,7 @@ impl<'context> Elaborator<'context> {
let func_meta = func_meta.clone();

self.trait_bounds = func_meta.trait_constraints.clone();
self.function_context.push(FunctionContext::default());

// Introduce all numeric generics into scope
for generic in &func_meta.all_generics {
Expand Down Expand Up @@ -367,34 +379,11 @@ impl<'context> Elaborator<'context> {
self.type_check_function_body(body_type, &func_meta, hir_func.as_expr());
}

// Default any type variables that still need defaulting.
// Default any type variables that still need defaulting and
// verify any remaining trait constraints arising from the function body.
// This is done before trait impl search since leaving them bindable can lead to errors
// when multiple impls are available. Instead we default first to choose the Field or u64 impl.
for typ in &self.type_variables {
if let Type::TypeVariable(variable, kind) = typ.follow_bindings() {
let msg = "TypeChecker should only track defaultable type vars";
variable.bind(kind.default_type().expect(msg));
}
}

// Verify any remaining trait constraints arising from the function body
for (mut constraint, expr_id) in std::mem::take(&mut self.trait_constraints) {
let span = self.interner.expr_span(&expr_id);

if matches!(&constraint.typ, Type::MutableReference(_)) {
let (_, dereferenced_typ) =
self.insert_auto_dereferences(expr_id, constraint.typ.clone());
constraint.typ = dereferenced_typ;
}

self.verify_trait_constraint(
&constraint.typ,
constraint.trait_id,
&constraint.trait_generics,
expr_id,
span,
);
}
self.check_and_pop_function_context();

// Now remove all the `where` clause constraints we added
for constraint in &func_meta.trait_constraints {
Expand All @@ -417,12 +406,42 @@ impl<'context> Elaborator<'context> {
meta.function_body = FunctionBody::Resolved;

self.trait_bounds.clear();
self.type_variables.clear();
self.interner.update_fn(id, hir_func);
self.current_function = old_function;
self.current_item = old_item;
}

/// Defaults all type variables used in this function context then solves
/// all still-unsolved trait constraints in this context.
fn check_and_pop_function_context(&mut self) {
jfecher marked this conversation as resolved.
Show resolved Hide resolved
let context = self.function_context.pop().expect("Imbalanced function_context pushes");

for typ in context.type_variables {
if let Type::TypeVariable(variable, kind) = typ.follow_bindings() {
let msg = "TypeChecker should only track defaultable type vars";
variable.bind(kind.default_type().expect(msg));
}
}

for (mut constraint, expr_id) in context.trait_constraints {
let span = self.interner.expr_span(&expr_id);

if matches!(&constraint.typ, Type::MutableReference(_)) {
let (_, dereferenced_typ) =
self.insert_auto_dereferences(expr_id, constraint.typ.clone());
constraint.typ = dereferenced_typ;
}

self.verify_trait_constraint(
&constraint.typ,
constraint.trait_id,
&constraint.trait_generics,
expr_id,
span,
);
}
}

/// This turns function parameters of the form:
/// `fn foo(x: impl Bar)`
///
Expand Down Expand Up @@ -1339,10 +1358,6 @@ impl<'context> Elaborator<'context> {
self.elaborate_comptime_global(global_id);
}

// Avoid defaulting the types of globals here since they may be used in any function.
// Otherwise we may prematurely default to a Field inside the next function if this
// global was unused there, even if it is consistently used as a u8 everywhere else.
self.type_variables.clear();
self.local_module = old_module;
self.file = old_file;
self.current_item = old_item;
Expand Down
5 changes: 3 additions & 2 deletions compiler/noirc_frontend/src/elaborator/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ impl<'context> Elaborator<'context> {
let expr = self.resolve_variable(variable);

let id = self.interner.push_expr(HirExpression::Ident(expr.clone(), generics.clone()));

self.interner.push_expr_location(id, span, self.file);
let typ = self.type_check_variable(expr, id, generics);
self.interner.push_expr_type(id, typ.clone());
Expand Down Expand Up @@ -516,7 +517,7 @@ impl<'context> Elaborator<'context> {

for mut constraint in function.trait_constraints.clone() {
constraint.apply_bindings(&bindings);
self.trait_constraints.push((constraint, expr_id));
self.push_trait_constraint(constraint, expr_id);
}
}
}
Expand All @@ -533,7 +534,7 @@ impl<'context> Elaborator<'context> {
// Currently only one impl can be selected per expr_id, so this
// constraint needs to be pushed after any other constraints so
// that monomorphization can resolve this trait method to the correct impl.
self.trait_constraints.push((constraint, expr_id));
self.push_trait_constraint(constraint, expr_id);
}
}

Expand Down
7 changes: 7 additions & 0 deletions compiler/noirc_frontend/src/elaborator/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,15 @@ impl<'context> Elaborator<'context> {
}

fn elaborate_comptime_statement(&mut self, statement: Statement) -> (HirStatement, Type) {
// We have to push a new FunctionContext so that we can resolve any constraints
// in this comptime block early before the function as a whole finishes elaborating.
// Otherwise the interpreter below may find expressions for which the underlying trait
// call is not yet solved for.
self.function_context.push(Default::default());
let span = statement.span;
let (hir_statement, _typ) = self.elaborate_statement(statement);
self.check_and_pop_function_context();

let mut interpreter =
Interpreter::new(self.interner, &mut self.comptime_scopes, self.crate_id);
let value = interpreter.evaluate_statement(hir_statement);
Expand Down
46 changes: 23 additions & 23 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
errors::ResolverError,
resolver::{verify_mutable_reference, SELF_TYPE_NAME, WILDCARD_TYPE},
},
type_check::{Source, TypeCheckError},
type_check::{NoMatchingImplFoundError, Source, TypeCheckError},
},
hir_def::{
expr::{
Expand Down Expand Up @@ -615,15 +615,15 @@ impl<'context> Elaborator<'context> {
/// in self.type_variables to default it later.
pub(super) fn polymorphic_integer_or_field(&mut self) -> Type {
let typ = Type::polymorphic_integer_or_field(self.interner);
self.type_variables.push(typ.clone());
self.push_type_variable(typ.clone());
typ
}

/// Return a fresh integer type variable and log it
/// in self.type_variables to default it later.
pub(super) fn polymorphic_integer(&mut self) -> Type {
let typ = Type::polymorphic_integer(self.interner);
self.type_variables.push(typ.clone());
self.push_type_variable(typ.clone());
typ
}

Expand Down Expand Up @@ -1410,26 +1410,10 @@ impl<'context> Elaborator<'context> {
Err(erroring_constraints) => {
if erroring_constraints.is_empty() {
self.push_err(TypeCheckError::TypeAnnotationsNeeded { span });
} else {
// Don't show any errors where try_get_trait returns None.
// This can happen if a trait is used that was never declared.
let constraints = erroring_constraints
.into_iter()
.map(|constraint| {
let r#trait = self.interner.try_get_trait(constraint.trait_id)?;
let mut name = r#trait.name.to_string();
if !constraint.trait_generics.is_empty() {
let generics =
vecmap(&constraint.trait_generics, ToString::to_string);
name += &format!("<{}>", generics.join(", "));
}
Some((constraint.typ, name))
})
.collect::<Option<Vec<_>>>();

if let Some(constraints) = constraints {
self.push_err(TypeCheckError::NoMatchingImplFound { constraints, span });
}
} else if let Some(error) =
NoMatchingImplFoundError::new(self.interner, erroring_constraints, span)
{
self.push_err(TypeCheckError::NoMatchingImplFound(error));
}
}
}
Expand Down Expand Up @@ -1557,4 +1541,20 @@ impl<'context> Elaborator<'context> {
}
}
}

/// Push a type variable into the current FunctionContext to be defaulted if needed
/// at the end of the earlier of either the current function or the current comptime scope.
fn push_type_variable(&mut self, typ: Type) {
let context = self.function_context.last_mut();
let context = context.expect("The function_context stack should always be non-empty");
context.type_variables.push(typ);
}

/// Push a trait constraint into the current FunctionContext to be solved if needed
/// at the end of the earlier of either the current function or the current comptime scope.
pub fn push_trait_constraint(&mut self, constraint: TraitConstraint, expr_id: ExprId) {
let context = self.function_context.last_mut();
let context = context.expect("The function_context stack should always be non-empty");
context.trait_constraints.push((constraint, expr_id));
}
}
16 changes: 15 additions & 1 deletion compiler/noirc_frontend/src/hir/comptime/errors.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::rc::Rc;

use crate::{
hir::def_collector::dc_crate::CompilationError, parser::ParserError, token::Tokens, Type,
hir::{def_collector::dc_crate::CompilationError, type_check::NoMatchingImplFoundError},
parser::ParserError,
token::Tokens,
Type,
};
use acvm::{acir::AcirField, FieldElement};
use fm::FileId;
Expand Down Expand Up @@ -44,6 +47,8 @@
FailedToParseMacro { error: ParserError, tokens: Rc<Tokens>, rule: &'static str, file: FileId },
UnsupportedTopLevelItemUnquote { item: String, location: Location },
NonComptimeFnCallInSameCrate { function: String, location: Location },
NoImpl { location: Location },
NoMatchingImplFound { error: NoMatchingImplFoundError, file: FileId },

Unimplemented { item: String, location: Location },

Expand Down Expand Up @@ -106,11 +111,15 @@
| InterpreterError::UnsupportedTopLevelItemUnquote { location, .. }
| InterpreterError::NonComptimeFnCallInSameCrate { location, .. }
| InterpreterError::Unimplemented { location, .. }
| InterpreterError::NoImpl { location, .. }
| InterpreterError::BreakNotInLoop { location, .. }
| InterpreterError::ContinueNotInLoop { location, .. } => *location,
InterpreterError::FailedToParseMacro { error, file, .. } => {
Location::new(error.span(), *file)
}
InterpreterError::NoMatchingImplFound { error, file } => {
Location::new(error.span, *file)
}
InterpreterError::Break | InterpreterError::Continue => {
panic!("Tried to get the location of Break/Continue error!")
}
Expand Down Expand Up @@ -277,7 +286,7 @@
let message = format!("Failed to parse macro's token stream into {rule}");
let tokens = vecmap(&tokens.0, ToString::to_string).join(" ");

// 10 is an aribtrary number of tokens here chosen to fit roughly onto one line

Check warning on line 289 in compiler/noirc_frontend/src/hir/comptime/errors.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (aribtrary)
let token_stream = if tokens.len() > 10 {
format!("The resulting token stream was: {tokens}")
} else {
Expand Down Expand Up @@ -324,6 +333,11 @@
let msg = "There is no loop to continue!".into();
CustomDiagnostic::simple_error(msg, String::new(), location.span)
}
InterpreterError::NoImpl { location } => {
let msg = "No impl found due to prior type error".into();
CustomDiagnostic::simple_error(msg, String::new(), location.span)
}
InterpreterError::NoMatchingImplFound { error, .. } => error.into(),
InterpreterError::Break => unreachable!("Uncaught InterpreterError::Break"),
InterpreterError::Continue => unreachable!("Uncaught InterpreterError::Continue"),
}
Expand Down
Loading
Loading