Skip to content

Commit

Permalink
fix: Implement generic functions in the interpreter (#5330)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #4924

## Summary\*

Implements the calling of generic functions in the interpreter

## Additional Context

I've removed the "type_check" method entirely since it was causing
errors. Certain types wouldn't match up (even with `follow_bindings`). I
couldn't fix this so I removed it since it is duplicated work from type
checking anyway.

## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
jfecher authored Jun 26, 2024
1 parent efdd818 commit d8b9870
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 60 deletions.
3 changes: 2 additions & 1 deletion compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,8 @@ impl<'context> Elaborator<'context> {
return None;
}

let result = interpreter.call_function(function, comptime_args, location);
let bindings = interpreter.interner.get_instantiation_bindings(func).clone();
let result = interpreter.call_function(function, comptime_args, bindings, location);
let (expr_id, typ) = self.inline_comptime_value(result, location.span);
Some((self.interner.expression(&expr_id), typ))
}
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::{
DefinitionId, DefinitionKind, DependencyId, ExprId, FuncId, GlobalId, TraitId, TypeAliasId,
},
parser::TopLevelStatement,
Shared, Type, TypeVariable,
Shared, Type, TypeBindings, TypeVariable,
};
use crate::{
ast::{TraitBound, UnresolvedGeneric, UnresolvedGenerics},
Expand Down Expand Up @@ -1273,7 +1273,7 @@ impl<'context> Elaborator<'context> {
let arguments = vec![(Value::TypeDefinition(struct_id), location)];

let value = interpreter
.call_function(function, arguments, location)
.call_function(function, arguments, TypeBindings::new(), location)
.map_err(|error| error.into_compilation_error_pair())?;

if value != Value::Unit {
Expand Down
78 changes: 31 additions & 47 deletions compiler/noirc_frontend/src/hir/comptime/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use rustc_hash::FxHashMap as HashMap;

use crate::ast::{BinaryOpKind, FunctionKind, IntegerBitSize, Signedness};
use crate::graph::CrateId;
use crate::monomorphization::{perform_instantiation_bindings, undo_instantiation_bindings};
use crate::token::Tokens;
use crate::{
hir_def::{
Expand All @@ -28,7 +29,7 @@ use crate::{
};

use super::errors::{IResult, InterpreterError};
use super::value::Value;
use super::value::{unwrap_rc, Value};

mod builtin;
mod unquote;
Expand Down Expand Up @@ -59,6 +60,19 @@ impl<'a> Interpreter<'a> {
}

pub(crate) fn call_function(
&mut self,
function: FuncId,
arguments: Vec<(Value, Location)>,
instantiation_bindings: TypeBindings,
location: Location,
) -> IResult<Value> {
perform_instantiation_bindings(&instantiation_bindings);
let result = self.call_function_inner(function, arguments, location);
undo_instantiation_bindings(instantiation_bindings);
result
}

fn call_function_inner(
&mut self,
function: FuncId,
arguments: Vec<(Value, Location)>,
Expand Down Expand Up @@ -200,7 +214,8 @@ impl<'a> Interpreter<'a> {
) -> IResult<()> {
match pattern {
HirPattern::Identifier(identifier) => {
self.define(identifier.id, typ, argument, location)
self.define(identifier.id, argument);
Ok(())
}
HirPattern::Mutable(pattern, _) => {
self.define_pattern(pattern, typ, argument, location)
Expand All @@ -222,8 +237,6 @@ impl<'a> Interpreter<'a> {
},
HirPattern::Struct(struct_type, pattern_fields, _) => {
self.push_scope();
self.type_check(typ, &argument, location)?;
self.type_check(struct_type, &argument, location)?;

let res = match argument {
Value::Struct(fields, struct_type) if fields.len() == pattern_fields.len() => {
Expand Down Expand Up @@ -259,30 +272,8 @@ impl<'a> Interpreter<'a> {
}

/// Define a new variable in the current scope
fn define(
&mut self,
id: DefinitionId,
typ: &Type,
argument: Value,
location: Location,
) -> IResult<()> {
// Temporarily disabled since this fails on generic types
// self.type_check(typ, &argument, location)?;
fn define(&mut self, id: DefinitionId, argument: Value) {
self.current_scope_mut().insert(id, argument);
Ok(())
}

/// Mutate an existing variable, potentially from a prior scope.
/// Also type checks the value being assigned
fn checked_mutate(
&mut self,
id: DefinitionId,
typ: &Type,
argument: Value,
location: Location,
) -> IResult<()> {
self.type_check(typ, &argument, location)?;
self.mutate(id, argument, location)
}

/// Mutate an existing variable, potentially from a prior scope
Expand Down Expand Up @@ -321,15 +312,6 @@ impl<'a> Interpreter<'a> {
Err(InterpreterError::NonComptimeVarReferenced { name, location })
}

fn type_check(&self, typ: &Type, value: &Value, location: Location) -> IResult<()> {
let typ = typ.follow_bindings();
let value_type = value.get_type();

typ.try_unify(&value_type, &mut TypeBindings::new()).map_err(|_| {
InterpreterError::TypeMismatch { expected: typ, value: value.clone(), location }
})
}

/// Evaluate an expression and return the result
pub fn evaluate(&mut self, id: ExprId) -> IResult<Value> {
match self.interner.expression(&id) {
Expand Down Expand Up @@ -367,8 +349,9 @@ impl<'a> Interpreter<'a> {

match &definition.kind {
DefinitionKind::Function(function_id) => {
let typ = self.interner.id_type(id);
Ok(Value::Function(*function_id, typ))
let typ = self.interner.id_type(id).follow_bindings();
let bindings = Rc::new(self.interner.get_instantiation_bindings(id).clone());
Ok(Value::Function(*function_id, typ, bindings))
}
DefinitionKind::Local(_) => self.lookup(&ident),
DefinitionKind::Global(global_id) => {
Expand Down Expand Up @@ -539,7 +522,7 @@ impl<'a> Interpreter<'a> {
}

fn evaluate_array(&mut self, array: HirArrayLiteral, id: ExprId) -> IResult<Value> {
let typ = self.interner.id_type(id);
let typ = self.interner.id_type(id).follow_bindings();

match array {
HirArrayLiteral::Standard(elements) => {
Expand Down Expand Up @@ -936,7 +919,7 @@ impl<'a> Interpreter<'a> {
})
.collect::<Result<_, _>>()?;

let typ = self.interner.id_type(id);
let typ = self.interner.id_type(id).follow_bindings();
Ok(Value::Struct(fields, typ))
}

Expand Down Expand Up @@ -977,7 +960,10 @@ impl<'a> Interpreter<'a> {
let location = self.interner.expr_location(&id);

match function {
Value::Function(function_id, _) => self.call_function(function_id, arguments, location),
Value::Function(function_id, _, bindings) => {
let bindings = unwrap_rc(bindings);
self.call_function(function_id, arguments, bindings, location)
}
Value::Closure(closure, env, _) => self.call_closure(closure, env, arguments, location),
value => Err(InterpreterError::NonFunctionCalled { value, location }),
}
Expand Down Expand Up @@ -1006,7 +992,7 @@ impl<'a> Interpreter<'a> {
};

if let Some(method) = method {
self.call_function(method, arguments, location)
self.call_function(method, arguments, TypeBindings::new(), location)
} else {
Err(InterpreterError::NoMethodFound { name: method_name.clone(), typ, location })
}
Expand Down Expand Up @@ -1151,7 +1137,7 @@ impl<'a> Interpreter<'a> {
let environment =
try_vecmap(&lambda.captures, |capture| self.lookup_id(capture.ident.id, location))?;

let typ = self.interner.id_type(id);
let typ = self.interner.id_type(id).follow_bindings();
Ok(Value::Closure(lambda, environment, typ))
}

Expand Down Expand Up @@ -1212,9 +1198,7 @@ impl<'a> Interpreter<'a> {

fn store_lvalue(&mut self, lvalue: HirLValue, rhs: Value) -> IResult<()> {
match lvalue {
HirLValue::Ident(ident, typ) => {
self.checked_mutate(ident.id, &typ, rhs, ident.location)
}
HirLValue::Ident(ident, typ) => self.mutate(ident.id, rhs, ident.location),
HirLValue::Dereference { lvalue, element_type: _, location } => {
match self.evaluate_lvalue(&lvalue)? {
Value::Pointer(value) => {
Expand All @@ -1233,7 +1217,7 @@ impl<'a> Interpreter<'a> {
}
Value::Struct(mut fields, typ) => {
fields.insert(Rc::new(field_name.0.contents), rhs);
self.store_lvalue(*object, Value::Struct(fields, typ))
self.store_lvalue(*object, Value::Struct(fields, typ.follow_bindings()))
}
value => {
Err(InterpreterError::NonTupleOrStructInMemberAccess { value, location })
Expand Down
17 changes: 16 additions & 1 deletion compiler/noirc_frontend/src/hir/comptime/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ fn interpret_helper(src: &str, func_namespace: Vec<String>) -> Result<Value, Int
let mut interpreter = Interpreter::new(&mut interner, &mut scopes, CrateId::Root(0));

let no_location = Location::dummy();
interpreter.call_function(main_id, Vec::new(), no_location)
interpreter.call_function(main_id, Vec::new(), HashMap::new(), no_location)
}

fn interpret(src: &str, func_namespace: Vec<String>) -> Value {
Expand Down Expand Up @@ -197,3 +197,18 @@ fn non_deterministic_recursion() {
let result = interpret(program, vec!["main".into(), "fib".into()]);
assert_eq!(result, Value::U64(55));
}

#[test]
fn generic_functions() {
let program = "
fn main() -> pub u8 {
apply(1, |x| x + 1)
}
fn apply<T, Env, U>(x: T, f: fn[Env](T) -> U) -> U {
f(x)
}
";
let result = interpret(program, vec!["main".into(), "apply".into()]);
assert!(matches!(result, Value::U8(2)));
}
20 changes: 13 additions & 7 deletions compiler/noirc_frontend/src/hir/comptime/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
node_interner::{ExprId, FuncId},
parser::{self, NoirParser, TopLevelStatement},
token::{SpannedToken, Token, Tokens},
QuotedType, Shared, Type,
QuotedType, Shared, Type, TypeBindings,
};
use rustc_hash::FxHashMap as HashMap;

Expand All @@ -36,7 +36,7 @@ pub enum Value {
U32(u32),
U64(u64),
String(Rc<String>),
Function(FuncId, Type),
Function(FuncId, Type, Rc<TypeBindings>),
Closure(HirLambda, Vec<Value>, Type),
Tuple(Vec<Value>),
Struct(HashMap<Rc<String>, Value>, Type),
Expand Down Expand Up @@ -65,7 +65,7 @@ impl Value {
let length = Type::Constant(value.len() as u32);
Type::String(Box::new(length))
}
Value::Function(_, typ) => return Cow::Borrowed(typ),
Value::Function(_, typ, _) => return Cow::Borrowed(typ),
Value::Closure(_, _, typ) => return Cow::Borrowed(typ),
Value::Tuple(fields) => {
Type::Tuple(vecmap(fields, |field| field.get_type().into_owned()))
Expand Down Expand Up @@ -128,13 +128,14 @@ impl Value {
ExpressionKind::Literal(Literal::Integer((value as u128).into(), false))
}
Value::String(value) => ExpressionKind::Literal(Literal::Str(unwrap_rc(value))),
Value::Function(id, typ) => {
Value::Function(id, typ, bindings) => {
let id = interner.function_definition_id(id);
let impl_kind = ImplKind::NotATraitMethod;
let ident = HirIdent { location, id, impl_kind };
let expr_id = interner.push_expr(HirExpression::Ident(ident, None));
interner.push_expr_location(expr_id, location.span, location.file);
interner.push_expr_type(expr_id, typ);
interner.store_instantiation_bindings(expr_id, unwrap_rc(bindings));
ExpressionKind::Resolved(expr_id)
}
Value::Closure(_lambda, _env, _typ) => {
Expand Down Expand Up @@ -247,10 +248,15 @@ impl Value {
HirExpression::Literal(HirLiteral::Integer((value as u128).into(), false))
}
Value::String(value) => HirExpression::Literal(HirLiteral::Str(unwrap_rc(value))),
Value::Function(id, _typ) => {
Value::Function(id, typ, bindings) => {
let id = interner.function_definition_id(id);
let impl_kind = ImplKind::NotATraitMethod;
HirExpression::Ident(HirIdent { location, id, impl_kind }, None)
let ident = HirIdent { location, id, impl_kind };
let expr_id = interner.push_expr(HirExpression::Ident(ident, None));
interner.push_expr_location(expr_id, location.span, location.file);
interner.push_expr_type(expr_id, typ);
interner.store_instantiation_bindings(expr_id, unwrap_rc(bindings));
return Ok(expr_id);
}
Value::Closure(_lambda, _env, _typ) => {
// TODO: How should a closure's environment be inlined?
Expand Down Expand Up @@ -362,7 +368,7 @@ impl Display for Value {
Value::U32(value) => write!(f, "{value}"),
Value::U64(value) => write!(f, "{value}"),
Value::String(value) => write!(f, "{value}"),
Value::Function(_, _) => write!(f, "(function)"),
Value::Function(..) => write!(f, "(function)"),
Value::Closure(_, _, _) => write!(f, "(closure)"),
Value::Tuple(fields) => {
let fields = vecmap(fields, ToString::to_string);
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/monomorphization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1816,13 +1816,13 @@ fn unwrap_struct_type(typ: &HirType) -> Vec<(String, HirType)> {
}
}

fn perform_instantiation_bindings(bindings: &TypeBindings) {
pub fn perform_instantiation_bindings(bindings: &TypeBindings) {
for (var, binding) in bindings.values() {
var.force_bind(binding.clone());
}
}

fn undo_instantiation_bindings(bindings: TypeBindings) {
pub fn undo_instantiation_bindings(bindings: TypeBindings) {
for (id, (var, _)) in bindings {
var.unbind(id);
}
Expand Down

0 comments on commit d8b9870

Please sign in to comment.