Skip to content

Commit

Permalink
feat(traits): Add impl Trait as function return type noir-lang#2397
Browse files Browse the repository at this point in the history
  • Loading branch information
ymadzhunkov committed Oct 16, 2023
1 parent 1c1afbc commit 88f174b
Show file tree
Hide file tree
Showing 14 changed files with 217 additions and 43 deletions.
11 changes: 11 additions & 0 deletions compiler/noirc_frontend/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ pub enum UnresolvedTypeData {
/// A Named UnresolvedType can be a struct type or a type variable
Named(Path, Vec<UnresolvedType>),

/// A Trait as return type or parameter of function, including it's generics
TraitAsType(Path, Vec<UnresolvedType>),

/// &mut T
MutableReference(Box<UnresolvedType>),

Expand Down Expand Up @@ -112,6 +115,14 @@ impl std::fmt::Display for UnresolvedTypeData {
write!(f, "{}<{}>", s, args.join(", "))
}
}
TraitAsType(s, args) => {
let args = vecmap(args, |arg| ToString::to_string(&arg.typ));
if args.is_empty() {
write!(f, "impl {s}")
} else {
write!(f, "impl {}<{}>", s, args.join(", "))
}
}
Tuple(elements) => {
let elements = vecmap(elements, ToString::to_string);
write!(f, "({})", elements.join(", "))
Expand Down
27 changes: 27 additions & 0 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,8 @@ impl<'a> Resolver<'a> {
Unspecified => Type::Error,
Error => Type::Error,
Named(path, args) => self.resolve_named_type(path, args, new_variables),
TraitAsType(path, args) => self.resolve_trait_as_type(path, args, new_variables),

Tuple(fields) => {
Type::Tuple(vecmap(fields, |field| self.resolve_type_inner(field, new_variables)))
}
Expand Down Expand Up @@ -479,6 +481,19 @@ impl<'a> Resolver<'a> {
}
}

fn resolve_trait_as_type(
&mut self,
path: Path,
_args: Vec<UnresolvedType>,
_new_variables: &mut Generics,
) -> Type {
if let Some(t) = self.lookup_trait_or_error(path) {
Type::TraitAsType(t)
} else {
Type::Error
}
}

fn verify_generics_count(
&mut self,
expected_count: usize,
Expand Down Expand Up @@ -874,6 +889,7 @@ impl<'a> Resolver<'a> {
| Type::Constant(_)
| Type::NamedGeneric(_, _)
| Type::NotConstant
| Type::TraitAsType(_)
| Type::Forall(_, _) => (),

Type::Array(length, element_type) => {
Expand Down Expand Up @@ -1430,6 +1446,17 @@ impl<'a> Resolver<'a> {
}
}

/// Lookup a given trait by name/path.
fn lookup_trait_or_error(&mut self, path: Path) -> Option<Trait> {
match self.lookup(path) {
Ok(trait_id) => Some(self.get_trait(trait_id)),
Err(error) => {
self.push_err(error);
None
}
}
}

/// Looks up a given type by name.
/// This will also instantiate any struct types found.
fn lookup_type_or_error(&mut self, path: Path) -> Option<Type> {
Expand Down
7 changes: 5 additions & 2 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ impl<'interner> TypeChecker<'interner> {
}
}
}

/// Infers a type for a given expression, and return this type.
/// As a side-effect, this function will also remember this type in the NodeInterner
/// for the given expr_id key.
Expand All @@ -50,7 +51,7 @@ impl<'interner> TypeChecker<'interner> {
// E.g. `fn foo<T>(t: T, field: Field) -> T` has type `forall T. fn(T, Field) -> T`.
// We must instantiate identifiers at every call site to replace this T with a new type
// variable to handle generic functions.
let t = self.interner.id_type(ident.id);
let t = self.interner.id_type_substitute_trait_as_type(ident.id);
let (typ, bindings) = t.instantiate(self.interner);
self.interner.store_instantiation_bindings(*expr_id, bindings);
typ
Expand Down Expand Up @@ -131,7 +132,6 @@ impl<'interner> TypeChecker<'interner> {
HirExpression::Index(index_expr) => self.check_index_expression(expr_id, index_expr),
HirExpression::Call(call_expr) => {
self.check_if_deprecated(&call_expr.func);

let function = self.check_expression(&call_expr.func);
let args = vecmap(&call_expr.arguments, |arg| {
let typ = self.check_expression(arg);
Expand Down Expand Up @@ -839,6 +839,9 @@ impl<'interner> TypeChecker<'interner> {
}
}
}
Type::TraitAsType(_trait) => {
unreachable!("unexpected lookup on trait as return type")
}
Type::NamedGeneric(_, _) => {
let func_meta = self.interner.function_meta(
&self.current_function.expect("unexpected method outside a function"),
Expand Down
59 changes: 37 additions & 22 deletions compiler/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub use errors::TypeCheckError;

use crate::{
hir_def::{expr::HirExpression, stmt::HirStatement},
node_interner::{ExprId, FuncId, NodeInterner, StmtId},
node_interner::{ExprId, FuncId, NodeInterner, StmtId, TraitImplKey},
Type,
};

Expand Down Expand Up @@ -63,30 +63,45 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type
// Check declared return type and actual return type
if !can_ignore_ret {
let (expr_span, empty_function) = function_info(interner, function_body_id);

//let body_type = type_checker.check_expression(function_body_id);
let func_span = interner.expr_span(function_body_id); // XXX: We could be more specific and return the span of the last stmt, however stmts do not have spans yet
function_last_type.unify_with_coercions(
&declared_return_type,
*function_body_id,
interner,
&mut errors,
|| {
let mut error = TypeCheckError::TypeMismatchWithSource {
expected: declared_return_type.clone(),
actual: function_last_type.clone(),
span: func_span,
source: Source::Return(meta.return_type, expr_span),
};

if empty_function {
error = error.add_context(
if let Type::TraitAsType(t) = &declared_return_type {
let key = TraitImplKey { typ: function_last_type.follow_bindings(), trait_id: t.id };
match interner.get_trait_implementation(&key) {
Some(_implementation) => {}
None => {
let error = TypeCheckError::TypeMismatchWithSource {
expected: declared_return_type.clone(),
actual: function_last_type.clone(),
span: func_span,
source: Source::Return(meta.return_type, expr_span),
};
errors.push(error);
}
}
} else {
function_last_type.unify_with_coercions(
&declared_return_type,
*function_body_id,
interner,
&mut errors,
|| {
let mut error = TypeCheckError::TypeMismatchWithSource {
expected: declared_return_type.clone(),
actual: function_last_type.clone(),
span: func_span,
source: Source::Return(meta.return_type, expr_span),
};

if empty_function {
error = error.add_context(
"implicitly returns `()` as its body has no tail or `return` expression",
);
}

error
},
);
}
error
},
);
}
}

errors
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/hir_def/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub struct TraitType {
/// Represents a trait in the type system. Each instance of this struct
/// will be shared across all Type::Trait variants that represent
/// the same trait.
#[derive(Clone, Debug)]
#[derive(Debug, Eq, Clone)]
pub struct Trait {
/// A unique id representing this trait type. Used to check if two
/// struct traits are equal.
Expand Down
29 changes: 22 additions & 7 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ use noirc_printable_type::PrintableType;

use crate::{node_interner::StructId, Ident, Signedness};

use super::expr::{HirCallExpression, HirExpression, HirIdent};
use super::{
expr::{HirCallExpression, HirExpression, HirIdent},
traits::Trait,
};

#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub enum Type {
Expand Down Expand Up @@ -62,6 +65,8 @@ pub enum Type {
/// different argument types each time.
TypeVariable(TypeVariable, TypeVariableKind),

TraitAsType(Trait),

/// NamedGenerics are the 'T' or 'U' in a user-defined generic function
/// like `fn foo<T, U>(...) {}`. Unlike TypeVariables, they cannot be bound over.
NamedGeneric(TypeVariable, Rc<String>),
Expand Down Expand Up @@ -483,7 +488,8 @@ impl Type {
| Type::Constant(_)
| Type::NamedGeneric(_, _)
| Type::NotConstant
| Type::Forall(_, _) => false,
| Type::Forall(_, _)
| Type::TraitAsType(_) => false,

Type::Array(length, elem) => {
elem.contains_numeric_typevar(target_id) || named_generic_id_matches_target(length)
Expand Down Expand Up @@ -560,6 +566,9 @@ impl std::fmt::Display for Type {
write!(f, "{}<{}>", s.borrow(), args.join(", "))
}
}
Type::TraitAsType(tr) => {
write!(f, "impl {}", tr.name)
}
Type::Tuple(elements) => {
let elements = vecmap(elements, ToString::to_string);
write!(f, "({})", elements.join(", "))
Expand Down Expand Up @@ -1057,6 +1066,7 @@ impl Type {
let fields = vecmap(fields, |field| field.substitute(type_bindings));
Type::Tuple(fields)
}
Type::TraitAsType(_) => todo!(),
Type::Forall(typevars, typ) => {
// Trying to substitute a variable defined within a nested Forall
// is usually impossible and indicative of an error in the type checker somewhere.
Expand Down Expand Up @@ -1096,6 +1106,7 @@ impl Type {
let field_occurs = fields.occurs(target_id);
len_occurs || field_occurs
}
Type::TraitAsType(_) => todo!(),
Type::Struct(_, generic_args) => generic_args.iter().any(|arg| arg.occurs(target_id)),
Type::Tuple(fields) => fields.iter().any(|field| field.occurs(target_id)),
Type::NamedGeneric(binding, _) | Type::TypeVariable(binding, _) => {
Expand Down Expand Up @@ -1147,7 +1158,6 @@ impl Type {
Struct(def.clone(), args)
}
Tuple(args) => Tuple(vecmap(args, |arg| arg.follow_bindings())),

TypeVariable(var, _) | NamedGeneric(var, _) => {
if let TypeBinding::Bound(typ) = &*var.borrow() {
return typ.follow_bindings();
Expand All @@ -1166,10 +1176,14 @@ impl Type {

// Expect that this function should only be called on instantiated types
Forall(..) => unreachable!(),

FieldElement | Integer(_, _) | Bool | Constant(_) | Unit | Error | NotConstant => {
self.clone()
}
TraitAsType(_)
| FieldElement
| Integer(_, _)
| Bool
| Constant(_)
| Unit
| Error
| NotConstant => self.clone(),
}
}
}
Expand Down Expand Up @@ -1270,6 +1284,7 @@ impl From<&Type> for PrintableType {
let fields = vecmap(fields, |(name, typ)| (name, typ.into()));
PrintableType::Struct { fields, name: struct_type.name.to_string() }
}
Type::TraitAsType(name) => PrintableType::Trait { name: name.name.to_string() },
Type::Tuple(_) => todo!("printing tuple types is not yet implemented"),
Type::TypeVariable(_, _) => unreachable!(),
Type::NamedGeneric(..) => unreachable!(),
Expand Down
29 changes: 18 additions & 11 deletions compiler/noirc_frontend/src/monomorphization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,15 @@ impl<'interner> Monomorphizer<'interner> {
let modifiers = self.interner.function_modifiers(&f);
let name = self.interner.function_name(&f).to_owned();

let return_type = self.convert_type(meta.return_type());
let body_expr_id = *self.interner.function(&f).as_expr();
let body_return_type = self.interner.id_type(body_expr_id);
let return_type = self.convert_type(match meta.return_type() {
Type::TraitAsType(_) => &body_return_type,
_ => meta.return_type(),
});

let parameters = self.parameters(meta.parameters);
let body = self.expr(*self.interner.function(&f).as_expr());
let body = self.expr(body_expr_id);
let unconstrained = modifiers.is_unconstrained
|| matches!(modifiers.contract_function_type, Some(ContractFunctionType::Open));

Expand Down Expand Up @@ -381,8 +387,8 @@ impl<'interner> Monomorphizer<'interner> {
}
}

HirExpression::MethodCall(_) => {
unreachable!("Encountered HirExpression::MethodCall during monomorphization")
HirExpression::MethodCall(hir_method_call) => {
unreachable!("Encountered HirExpression::MethodCall during monomorphization {hir_method_call:?}")
}
HirExpression::Error => unreachable!("Encountered Error node during monomorphization"),
}
Expand Down Expand Up @@ -635,7 +641,6 @@ impl<'interner> Monomorphizer<'interner> {
let location = Some(ident.location);
let name = definition.name.clone();
let typ = self.interner.id_type(expr_id);

let definition = self.lookup_function(*func_id, expr_id, &typ);
let typ = self.convert_type(&typ);
let ident = ast::Ident { location, mutable, definition, name, typ: typ.clone() };
Expand Down Expand Up @@ -686,7 +691,6 @@ impl<'interner> Monomorphizer<'interner> {
ast::Type::FmtString(size, fields)
}
HirType::Unit => ast::Type::Unit,

HirType::Array(length, element) => {
let element = Box::new(self.convert_type(element.as_ref()));

Expand All @@ -696,7 +700,9 @@ impl<'interner> Monomorphizer<'interner> {
ast::Type::Slice(element)
}
}

HirType::TraitAsType(_) => {
unreachable!("All TraitAsType should be replaced before calling convert_type");
}
HirType::NamedGeneric(binding, _) => {
if let TypeBinding::Bound(binding) = &*binding.borrow() {
return self.convert_type(binding);
Expand Down Expand Up @@ -780,8 +786,7 @@ impl<'interner> Monomorphizer<'interner> {
}
}

fn is_function_closure(&self, raw_func_id: node_interner::ExprId) -> bool {
let t = self.convert_type(&self.interner.id_type(raw_func_id));
fn is_function_closure(&self, t: ast::Type) -> bool {
if self.is_function_closure_type(&t) {
true
} else if let ast::Type::Tuple(elements) = t {
Expand Down Expand Up @@ -850,6 +855,7 @@ impl<'interner> Monomorphizer<'interner> {
let func: Box<ast::Expression>;
let return_type = self.interner.id_type(id);
let return_type = self.convert_type(&return_type);

let location = call.location;

if let ast::Expression::Ident(ident) = original_func.as_ref() {
Expand All @@ -863,8 +869,9 @@ impl<'interner> Monomorphizer<'interner> {
}

let mut block_expressions = vec![];

let is_closure = self.is_function_closure(call.func);
let func_type = self.interner.id_type(call.func);
let func_type = self.convert_type(&func_type);
let is_closure = self.is_function_closure(func_type);
if is_closure {
let local_id = self.next_local_id();

Expand Down
Loading

0 comments on commit 88f174b

Please sign in to comment.