diff --git a/compiler/noirc_frontend/src/ast/traits.rs b/compiler/noirc_frontend/src/ast/traits.rs index 61ef6f6276d..723df775b1e 100644 --- a/compiler/noirc_frontend/src/ast/traits.rs +++ b/compiler/noirc_frontend/src/ast/traits.rs @@ -18,6 +18,7 @@ use super::{Documented, GenericTypeArgs, ItemVisibility}; pub struct NoirTrait { pub name: Ident, pub generics: UnresolvedGenerics, + pub bounds: Vec, pub where_clause: Vec, pub span: Span, pub items: Vec>, @@ -134,7 +135,12 @@ impl Display for NoirTrait { let generics = vecmap(&self.generics, |generic| generic.to_string()); let generics = if generics.is_empty() { "".into() } else { generics.join(", ") }; - writeln!(f, "trait {}{} {{", self.name, generics)?; + write!(f, "trait {}{}", self.name, generics)?; + if !self.bounds.is_empty() { + let bounds = vecmap(&self.bounds, |bound| bound.to_string()).join(" + "); + write!(f, ": {}", bounds)?; + } + writeln!(f, " {{")?; for item in self.items.iter() { let item = item.to_string(); diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index 06bcafc55c9..31a518ca97f 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -26,7 +26,7 @@ use crate::{ HirPrefixExpression, }, stmt::HirStatement, - traits::TraitConstraint, + traits::{ResolvedTraitBound, TraitConstraint}, }, node_interner::{DefinitionKind, ExprId, FuncId, InternedStatementKind, TraitMethodId}, token::Tokens, @@ -743,9 +743,11 @@ impl<'context> Elaborator<'context> { // that implements the trait. let constraint = TraitConstraint { typ: operand_type.clone(), - trait_id: trait_id.trait_id, - trait_generics: TraitGenerics::default(), - span, + trait_bound: ResolvedTraitBound { + trait_id: trait_id.trait_id, + trait_generics: TraitGenerics::default(), + span, + }, }; self.push_trait_constraint(constraint, expr_id); self.type_check_operator_method(expr_id, trait_id, operand_type, span); diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index a9173621fc7..fb901f3dd76 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -3,7 +3,7 @@ use std::{ rc::Rc, }; -use crate::{ast::ItemVisibility, StructField}; +use crate::{ast::ItemVisibility, hir_def::traits::ResolvedTraitBound, StructField, TypeBindings}; use crate::{ ast::{ BlockExpression, FunctionKind, GenericTypeArgs, Ident, NoirFunction, NoirStruct, Param, @@ -54,6 +54,7 @@ mod unquote; use fm::FileId; use iter_extended::vecmap; use noirc_errors::{Location, Span}; +use types::bind_ordered_generics; use self::traits::check_trait_impl_method_matches_declaration; @@ -433,7 +434,8 @@ impl<'context> Elaborator<'context> { // Now remove all the `where` clause constraints we added for constraint in &func_meta.trait_constraints { - self.interner.remove_assumed_trait_implementations_for_trait(constraint.trait_id); + self.interner + .remove_assumed_trait_implementations_for_trait(constraint.trait_bound.trait_id); } let func_scope_tree = self.scopes.end_function(); @@ -479,9 +481,9 @@ impl<'context> Elaborator<'context> { self.verify_trait_constraint( &constraint.typ, - constraint.trait_id, - &constraint.trait_generics.ordered, - &constraint.trait_generics.named, + constraint.trait_bound.trait_id, + &constraint.trait_bound.trait_generics.ordered, + &constraint.trait_bound.trait_generics.named, expr_id, span, ); @@ -510,7 +512,8 @@ impl<'context> Elaborator<'context> { let generic_type = Type::NamedGeneric(new_generic, Rc::new(name)); let trait_bound = TraitBound { trait_path, trait_id: None, trait_generics }; - if let Some(new_constraint) = self.resolve_trait_bound(&trait_bound, generic_type.clone()) { + if let Some(trait_bound) = self.resolve_trait_bound(&trait_bound) { + let new_constraint = TraitConstraint { typ: generic_type.clone(), trait_bound }; trait_constraints.push(new_constraint); } @@ -668,14 +671,11 @@ impl<'context> Elaborator<'context> { constraint: &UnresolvedTraitConstraint, ) -> Option { let typ = self.resolve_type(constraint.typ.clone()); - self.resolve_trait_bound(&constraint.trait_bound, typ) + let trait_bound = self.resolve_trait_bound(&constraint.trait_bound)?; + Some(TraitConstraint { typ, trait_bound }) } - pub fn resolve_trait_bound( - &mut self, - bound: &TraitBound, - typ: Type, - ) -> Option { + pub fn resolve_trait_bound(&mut self, bound: &TraitBound) -> Option { let the_trait = self.lookup_trait_or_error(bound.trait_path.clone())?; let trait_id = the_trait.id; let span = bound.trait_path.span; @@ -683,7 +683,7 @@ impl<'context> Elaborator<'context> { let (ordered, named) = self.resolve_type_args(bound.trait_generics.clone(), trait_id, span); let trait_generics = TraitGenerics { ordered, named }; - Some(TraitConstraint { typ, trait_id, trait_generics, span }) + Some(ResolvedTraitBound { trait_id, trait_generics, span }) } /// Extract metadata from a NoirFunction @@ -942,21 +942,52 @@ impl<'context> Elaborator<'context> { fn add_trait_constraints_to_scope(&mut self, func_meta: &FuncMeta) { for constraint in &func_meta.trait_constraints { - let object = constraint.typ.clone(); - let trait_id = constraint.trait_id; - let generics = constraint.trait_generics.clone(); - - if !self.interner.add_assumed_trait_implementation(object, trait_id, generics) { - if let Some(the_trait) = self.interner.try_get_trait(trait_id) { - let trait_name = the_trait.name.to_string(); - let typ = constraint.typ.clone(); - let span = func_meta.location.span; - self.push_err(TypeCheckError::UnneededTraitConstraint { - trait_name, - typ, - span, - }); + self.add_trait_bound_to_scope( + func_meta, + &constraint.typ, + &constraint.trait_bound, + constraint.trait_bound.trait_id, + ); + } + } + + fn add_trait_bound_to_scope( + &mut self, + func_meta: &FuncMeta, + object: &Type, + trait_bound: &ResolvedTraitBound, + starting_trait_id: TraitId, + ) { + let trait_id = trait_bound.trait_id; + let generics = trait_bound.trait_generics.clone(); + + if !self.interner.add_assumed_trait_implementation(object.clone(), trait_id, generics) { + if let Some(the_trait) = self.interner.try_get_trait(trait_id) { + let trait_name = the_trait.name.to_string(); + let typ = object.clone(); + let span = func_meta.location.span; + self.push_err(TypeCheckError::UnneededTraitConstraint { trait_name, typ, span }); + } + } + + // Also add assumed implementations for the parent traits, if any + if let Some(trait_bounds) = + self.interner.try_get_trait(trait_id).map(|the_trait| the_trait.trait_bounds.clone()) + { + for parent_trait_bound in trait_bounds { + // Avoid looping forever in case there are cycles + if parent_trait_bound.trait_id == starting_trait_id { + continue; } + + let parent_trait_bound = + self.instantiate_parent_trait_bound(trait_bound, &parent_trait_bound); + self.add_trait_bound_to_scope( + func_meta, + object, + &parent_trait_bound, + starting_trait_id, + ); } } } @@ -972,6 +1003,8 @@ impl<'context> Elaborator<'context> { self.file = trait_impl.file_id; self.local_module = trait_impl.module_id; + self.check_parent_traits_are_implemented(&trait_impl); + self.generics = trait_impl.resolved_generics; self.current_trait_impl = trait_impl.impl_id; @@ -988,6 +1021,73 @@ impl<'context> Elaborator<'context> { self.generics.clear(); } + fn check_parent_traits_are_implemented(&mut self, trait_impl: &UnresolvedTraitImpl) { + let Some(trait_id) = trait_impl.trait_id else { + return; + }; + + let Some(object_type) = &trait_impl.resolved_object_type else { + return; + }; + + let Some(the_trait) = self.interner.try_get_trait(trait_id) else { + return; + }; + + if the_trait.trait_bounds.is_empty() { + return; + } + + let impl_trait = the_trait.name.to_string(); + let the_trait_file = the_trait.location.file; + + let mut bindings = TypeBindings::new(); + bind_ordered_generics( + &the_trait.generics, + &trait_impl.resolved_trait_generics, + &mut bindings, + ); + + // Note: we only check if the immediate parents are implemented, we don't check recursively. + // Why? If a parent isn't implemented, we get an error. If a parent is implemented, we'll + // do the same check for the parent, so this trait's parents parents will be checked, so the + // recursion is guaranteed. + for parent_trait_bound in the_trait.trait_bounds.clone() { + let Some(parent_trait) = self.interner.try_get_trait(parent_trait_bound.trait_id) + else { + continue; + }; + + let parent_trait_bound = ResolvedTraitBound { + trait_generics: parent_trait_bound + .trait_generics + .map(|typ| typ.substitute(&bindings)), + ..parent_trait_bound + }; + + if self + .interner + .try_lookup_trait_implementation( + object_type, + parent_trait_bound.trait_id, + &parent_trait_bound.trait_generics.ordered, + &parent_trait_bound.trait_generics.named, + ) + .is_err() + { + let missing_trait = + format!("{}{}", parent_trait.name, parent_trait_bound.trait_generics); + self.push_err(ResolverError::TraitNotImplemented { + impl_trait: impl_trait.clone(), + missing_trait, + type_missing_trait: trait_impl.object_type.to_string(), + span: trait_impl.object_type.span, + missing_trait_location: Location::new(parent_trait_bound.span, the_trait_file), + }); + } + } + } + fn collect_impls( &mut self, module: LocalModuleId, diff --git a/compiler/noirc_frontend/src/elaborator/patterns.rs b/compiler/noirc_frontend/src/elaborator/patterns.rs index 0f6cb78fae7..4682afe2d97 100644 --- a/compiler/noirc_frontend/src/elaborator/patterns.rs +++ b/compiler/noirc_frontend/src/elaborator/patterns.rs @@ -655,7 +655,7 @@ impl<'context> Elaborator<'context> { if let ImplKind::TraitMethod(mut method) = ident.impl_kind { method.constraint.apply_bindings(&bindings); if method.assumed { - let trait_generics = method.constraint.trait_generics.clone(); + let trait_generics = method.constraint.trait_bound.trait_generics.clone(); let object_type = method.constraint.typ; let trait_impl = TraitImplKind::Assumed { object_type, trait_generics }; self.interner.select_impl_for_expression(expr_id, trait_impl); @@ -748,7 +748,7 @@ impl<'context> Elaborator<'context> { HirMethodReference::TraitMethodId(method_id, generics) => { let mut constraint = self.interner.get_trait(method_id.trait_id).as_constraint(span); - constraint.trait_generics = generics; + constraint.trait_bound.trait_generics = generics; ImplKind::TraitMethod(TraitMethod { method_id, constraint, assumed: false }) } }; diff --git a/compiler/noirc_frontend/src/elaborator/trait_impls.rs b/compiler/noirc_frontend/src/elaborator/trait_impls.rs index 858cfa5cdd6..20f048bed05 100644 --- a/compiler/noirc_frontend/src/elaborator/trait_impls.rs +++ b/compiler/noirc_frontend/src/elaborator/trait_impls.rs @@ -167,12 +167,14 @@ impl<'context> Elaborator<'context> { let mut substituted_method_ids = HashSet::default(); for method_constraint in method.trait_constraints.iter() { let substituted_constraint_type = method_constraint.typ.substitute(&bindings); - let substituted_trait_generics = - method_constraint.trait_generics.map(|generic| generic.substitute(&bindings)); + let substituted_trait_generics = method_constraint + .trait_bound + .trait_generics + .map(|generic| generic.substitute(&bindings)); substituted_method_ids.insert(( substituted_constraint_type, - method_constraint.trait_id, + method_constraint.trait_bound.trait_id, substituted_trait_generics, )); } @@ -180,7 +182,8 @@ impl<'context> Elaborator<'context> { for override_trait_constraint in override_meta.trait_constraints.clone() { let override_constraint_is_from_impl = trait_impl_where_clause.iter().any(|impl_constraint| { - impl_constraint.trait_id == override_trait_constraint.trait_id + impl_constraint.trait_bound.trait_id + == override_trait_constraint.trait_bound.trait_id }); if override_constraint_is_from_impl { continue; @@ -188,15 +191,16 @@ impl<'context> Elaborator<'context> { if !substituted_method_ids.contains(&( override_trait_constraint.typ.clone(), - override_trait_constraint.trait_id, - override_trait_constraint.trait_generics.clone(), + override_trait_constraint.trait_bound.trait_id, + override_trait_constraint.trait_bound.trait_generics.clone(), )) { - let the_trait = self.interner.get_trait(override_trait_constraint.trait_id); + let the_trait = + self.interner.get_trait(override_trait_constraint.trait_bound.trait_id); self.push_err(DefCollectorErrorKind::ImplIsStricterThanTrait { constraint_typ: override_trait_constraint.typ, constraint_name: the_trait.name.0.contents.clone(), - constraint_generics: override_trait_constraint.trait_generics, - constraint_span: override_trait_constraint.span, + constraint_generics: override_trait_constraint.trait_bound.trait_generics, + constraint_span: override_trait_constraint.trait_bound.span, trait_method_name: method.name.0.contents.clone(), trait_method_span: method.location.span, }); diff --git a/compiler/noirc_frontend/src/elaborator/traits.rs b/compiler/noirc_frontend/src/elaborator/traits.rs index b4042bd3e31..e877682972c 100644 --- a/compiler/noirc_frontend/src/elaborator/traits.rs +++ b/compiler/noirc_frontend/src/elaborator/traits.rs @@ -10,8 +10,11 @@ use crate::{ UnresolvedTraitConstraint, UnresolvedType, }, hir::{def_collector::dc_crate::UnresolvedTrait, type_check::TypeCheckError}, - hir_def::{function::Parameters, traits::TraitFunction}, - node_interner::{FuncId, NodeInterner, ReferenceId, TraitId}, + hir_def::{ + function::Parameters, + traits::{ResolvedTraitBound, TraitFunction}, + }, + node_interner::{DependencyId, FuncId, NodeInterner, ReferenceId, TraitId}, ResolvedGeneric, Type, TypeBindings, }; @@ -34,10 +37,17 @@ impl<'context> Elaborator<'context> { this.generics.push(associated_type.clone()); } + let resolved_trait_bounds = this.resolve_trait_bounds(unresolved_trait); + for bound in &resolved_trait_bounds { + this.interner + .add_trait_dependency(DependencyId::Trait(bound.trait_id), *trait_id); + } + let methods = this.resolve_trait_methods(*trait_id, unresolved_trait); this.interner.update_trait(*trait_id, |trait_def| { trait_def.set_methods(methods); + trait_def.set_trait_bounds(resolved_trait_bounds); }); }); @@ -53,6 +63,14 @@ impl<'context> Elaborator<'context> { self.current_trait = None; } + fn resolve_trait_bounds( + &mut self, + unresolved_trait: &UnresolvedTrait, + ) -> Vec { + let bounds = &unresolved_trait.trait_def.bounds; + bounds.iter().filter_map(|bound| self.resolve_trait_bound(bound)).collect() + } + fn resolve_trait_methods( &mut self, trait_id: TraitId, diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 82d14743428..8ffbd15bdab 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -28,7 +28,7 @@ use crate::{ }, function::{FuncMeta, Parameters}, stmt::HirStatement, - traits::{NamedType, TraitConstraint}, + traits::{NamedType, ResolvedTraitBound, Trait, TraitConstraint}, }, node_interner::{ DefinitionKind, DependencyId, ExprId, GlobalId, ImplSearchErrorKind, NodeInterner, TraitId, @@ -596,7 +596,7 @@ impl<'context> Elaborator<'context> { continue; } - let the_trait = self.interner.get_trait(constraint.trait_id); + let the_trait = self.interner.get_trait(constraint.trait_bound.trait_id); if let Some(method) = the_trait.find_method(path.last_name()) { return Some(TraitPathResolution { method: TraitMethod { method_id: method, constraint, assumed: true }, @@ -1376,15 +1376,16 @@ impl<'context> Elaborator<'context> { for constraint in &func_meta.trait_constraints { if *object_type == constraint.typ { - if let Some(the_trait) = self.interner.try_get_trait(constraint.trait_id) { - for (method_index, method) in the_trait.methods.iter().enumerate() { - if method.name.0.contents == method_name { - let trait_method = - TraitMethodId { trait_id: constraint.trait_id, method_index }; - - let generics = constraint.trait_generics.clone(); - return Some(HirMethodReference::TraitMethodId(trait_method, generics)); - } + if let Some(the_trait) = + self.interner.try_get_trait(constraint.trait_bound.trait_id) + { + if let Some(method) = self.lookup_method_in_trait( + the_trait, + method_name, + &constraint.trait_bound, + the_trait.id, + ) { + return Some(method); } } } @@ -1399,6 +1400,44 @@ impl<'context> Elaborator<'context> { None } + fn lookup_method_in_trait( + &self, + the_trait: &Trait, + method_name: &str, + trait_bound: &ResolvedTraitBound, + starting_trait_id: TraitId, + ) -> Option { + if let Some(trait_method) = the_trait.find_method(method_name) { + return Some(HirMethodReference::TraitMethodId( + trait_method, + trait_bound.trait_generics.clone(), + )); + } + + // Search in the parent traits, if any + for parent_trait_bound in &the_trait.trait_bounds { + if let Some(the_trait) = self.interner.try_get_trait(parent_trait_bound.trait_id) { + // Avoid looping forever in case there are cycles + if the_trait.id == starting_trait_id { + continue; + } + + let parent_trait_bound = + self.instantiate_parent_trait_bound(trait_bound, parent_trait_bound); + if let Some(method) = self.lookup_method_in_trait( + the_trait, + method_name, + &parent_trait_bound, + starting_trait_id, + ) { + return Some(method); + } + } + } + + None + } + pub(super) fn type_check_call( &mut self, call: &HirCallExpression, @@ -1801,55 +1840,86 @@ impl<'context> Elaborator<'context> { } pub fn bind_generics_from_trait_constraint( - &mut self, + &self, constraint: &TraitConstraint, assumed: bool, bindings: &mut TypeBindings, ) { - let the_trait = self.interner.get_trait(constraint.trait_id); - assert_eq!(the_trait.generics.len(), constraint.trait_generics.ordered.len()); - - for (param, arg) in the_trait.generics.iter().zip(&constraint.trait_generics.ordered) { - // Avoid binding t = t - if !arg.occurs(param.type_var.id()) { - bindings.insert( - param.type_var.id(), - (param.type_var.clone(), param.kind(), arg.clone()), - ); - } - } - - let mut associated_types = the_trait.associated_types.clone(); - assert_eq!(associated_types.len(), constraint.trait_generics.named.len()); - - for arg in &constraint.trait_generics.named { - let i = associated_types - .iter() - .position(|typ| *typ.name == arg.name.0.contents) - .unwrap_or_else(|| { - unreachable!("Expected to find associated type named {}", arg.name) - }); - - let param = associated_types.swap_remove(i); - - // Avoid binding t = t - if !arg.typ.occurs(param.type_var.id()) { - bindings.insert( - param.type_var.id(), - (param.type_var.clone(), param.kind(), arg.typ.clone()), - ); - } - } + self.bind_generics_from_trait_bound(&constraint.trait_bound, bindings); // If the trait impl is already assumed to exist we should add any type bindings for `Self`. // Otherwise `self` will be replaced with a fresh type variable, which will require the user // to specify a redundant type annotation. if assumed { + let the_trait = self.interner.get_trait(constraint.trait_bound.trait_id); let self_type = the_trait.self_type_typevar.clone(); let kind = the_trait.self_type_typevar.kind(); bindings.insert(self_type.id(), (self_type, kind, constraint.typ.clone())); } } + + pub fn bind_generics_from_trait_bound( + &self, + trait_bound: &ResolvedTraitBound, + bindings: &mut TypeBindings, + ) { + let the_trait = self.interner.get_trait(trait_bound.trait_id); + + bind_ordered_generics(&the_trait.generics, &trait_bound.trait_generics.ordered, bindings); + + let associated_types = the_trait.associated_types.clone(); + bind_named_generics(associated_types, &trait_bound.trait_generics.named, bindings); + } + + pub fn instantiate_parent_trait_bound( + &self, + trait_bound: &ResolvedTraitBound, + parent_trait_bound: &ResolvedTraitBound, + ) -> ResolvedTraitBound { + let mut bindings = TypeBindings::new(); + self.bind_generics_from_trait_bound(trait_bound, &mut bindings); + ResolvedTraitBound { + trait_generics: parent_trait_bound.trait_generics.map(|typ| typ.substitute(&bindings)), + ..*parent_trait_bound + } + } +} + +pub(crate) fn bind_ordered_generics( + params: &[ResolvedGeneric], + args: &[Type], + bindings: &mut TypeBindings, +) { + assert_eq!(params.len(), args.len()); + + for (param, arg) in params.iter().zip(args) { + bind_generic(param, arg, bindings); + } +} + +pub(crate) fn bind_named_generics( + mut params: Vec, + args: &[NamedType], + bindings: &mut TypeBindings, +) { + assert_eq!(params.len(), args.len()); + + for arg in args { + let i = params + .iter() + .position(|typ| *typ.name == arg.name.0.contents) + .unwrap_or_else(|| unreachable!("Expected to find associated type named {}", arg.name)); + + let param = params.swap_remove(i); + bind_generic(¶m, &arg.typ, bindings); + } +} + +fn bind_generic(param: &ResolvedGeneric, arg: &Type, bindings: &mut TypeBindings) { + // Avoid binding t = t + if !arg.occurs(param.type_var.id()) { + bindings.insert(param.type_var.id(), (param.type_var.clone(), param.kind(), arg.clone())); + } } pub fn try_eval_array_length_id( diff --git a/compiler/noirc_frontend/src/hir/comptime/display.rs b/compiler/noirc_frontend/src/hir/comptime/display.rs index fa45c41f8ec..60661211a09 100644 --- a/compiler/noirc_frontend/src/hir/comptime/display.rs +++ b/compiler/noirc_frontend/src/hir/comptime/display.rs @@ -509,8 +509,11 @@ impl<'token, 'interner> Display for TokenPrinter<'token, 'interner> { } fn display_trait_constraint(interner: &NodeInterner, trait_constraint: &TraitConstraint) -> String { - let trait_ = interner.get_trait(trait_constraint.trait_id); - format!("{}: {}{}", trait_constraint.typ, trait_.name, trait_constraint.trait_generics) + let trait_ = interner.get_trait(trait_constraint.trait_bound.trait_id); + format!( + "{}: {}{}", + trait_constraint.typ, trait_.name, trait_constraint.trait_bound.trait_generics + ) } // Returns a new Expression where all Interned and Resolved expressions have been turned into non-interned ExpressionKind. diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index 170cd67e146..bcda4f713b7 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -748,7 +748,7 @@ fn quoted_as_trait_constraint( )?; let bound = interpreter .elaborate_in_function(interpreter.current_function, |elaborator| { - elaborator.resolve_trait_bound(&trait_bound, Type::Unit) + elaborator.resolve_trait_bound(&trait_bound) }) .ok_or(InterpreterError::FailedToResolveTraitBound { trait_bound, location })?; @@ -2733,7 +2733,7 @@ fn trait_def_as_trait_constraint( let trait_id = get_trait_def(argument)?; let constraint = interner.get_trait(trait_id).as_constraint(location.span); - Ok(Value::TraitConstraint(trait_id, constraint.trait_generics)) + Ok(Value::TraitConstraint(trait_id, constraint.trait_bound.trait_generics)) } /// Creates a value that holds an `Option`. diff --git a/compiler/noirc_frontend/src/hir/resolution/errors.rs b/compiler/noirc_frontend/src/hir/resolution/errors.rs index a6eb1864d13..4f9907d6a16 100644 --- a/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -1,6 +1,6 @@ use acvm::FieldElement; pub use noirc_errors::Span; -use noirc_errors::{CustomDiagnostic as Diagnostic, FileDiagnostic}; +use noirc_errors::{CustomDiagnostic as Diagnostic, FileDiagnostic, Location}; use thiserror::Error; use crate::{ @@ -150,6 +150,14 @@ pub enum ResolverError { AttributeFunctionIsNotAPath { function: String, span: Span }, #[error("Attribute function `{name}` is not in scope")] AttributeFunctionNotInScope { name: String, span: Span }, + #[error("The trait `{missing_trait}` is not implemented for `{type_missing_trait}")] + TraitNotImplemented { + impl_trait: String, + missing_trait: String, + type_missing_trait: String, + span: Span, + missing_trait_location: Location, + }, } impl ResolverError { @@ -579,6 +587,14 @@ impl<'a> From<&'a ResolverError> for Diagnostic { *span, ) }, + ResolverError::TraitNotImplemented { impl_trait, missing_trait: the_trait, type_missing_trait: typ, span, missing_trait_location} => { + let mut diagnostic = Diagnostic::simple_error( + format!("The trait bound `{typ}: {the_trait}` is not satisfied"), + format!("The trait `{the_trait}` is not implemented for `{typ}") + , *span); + diagnostic.add_secondary_with_file(format!("required by this bound in `{impl_trait}"), missing_trait_location.span, missing_trait_location.file); + diagnostic + }, } } } diff --git a/compiler/noirc_frontend/src/hir/type_check/errors.rs b/compiler/noirc_frontend/src/hir/type_check/errors.rs index d8dae1f6549..99de6bca434 100644 --- a/compiler/noirc_frontend/src/hir/type_check/errors.rs +++ b/compiler/noirc_frontend/src/hir/type_check/errors.rs @@ -506,8 +506,8 @@ impl NoMatchingImplFoundError { let constraints = failing_constraints .into_iter() .map(|constraint| { - let r#trait = interner.try_get_trait(constraint.trait_id)?; - let name = format!("{}{}", r#trait.name, constraint.trait_generics); + let r#trait = interner.try_get_trait(constraint.trait_bound.trait_id)?; + let name = format!("{}{}", r#trait.name, constraint.trait_bound.trait_generics); Some((constraint.typ, name)) }) .collect::>>()?; diff --git a/compiler/noirc_frontend/src/hir/type_check/generics.rs b/compiler/noirc_frontend/src/hir/type_check/generics.rs index 86fc2d25d4e..370223f1f11 100644 --- a/compiler/noirc_frontend/src/hir/type_check/generics.rs +++ b/compiler/noirc_frontend/src/hir/type_check/generics.rs @@ -133,6 +133,10 @@ impl TraitGenerics { vecmap(&self.named, |named| NamedType { name: named.name.clone(), typ: f(&named.typ) }); TraitGenerics { ordered, named } } + + pub fn is_empty(&self) -> bool { + self.ordered.is_empty() && self.named.is_empty() + } } impl std::fmt::Display for TraitGenerics { diff --git a/compiler/noirc_frontend/src/hir_def/expr.rs b/compiler/noirc_frontend/src/hir_def/expr.rs index 71e0256a7e8..5d3fe632a74 100644 --- a/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/compiler/noirc_frontend/src/hir_def/expr.rs @@ -11,7 +11,7 @@ use crate::token::Tokens; use crate::Shared; use super::stmt::HirPattern; -use super::traits::TraitConstraint; +use super::traits::{ResolvedTraitBound, TraitConstraint}; use super::types::{StructType, Type}; /// A HirExpression is the result of an Expression in the AST undergoing @@ -250,9 +250,11 @@ impl HirMethodCallExpression { let id = interner.trait_method_id(method_id); let constraint = TraitConstraint { typ: object_type, - trait_id: method_id.trait_id, - trait_generics, - span: location.span, + trait_bound: ResolvedTraitBound { + trait_id: method_id.trait_id, + trait_generics, + span: location.span, + }, }; (id, ImplKind::TraitMethod(TraitMethod { method_id, constraint, assumed: false })) } diff --git a/compiler/noirc_frontend/src/hir_def/traits.rs b/compiler/noirc_frontend/src/hir_def/traits.rs index 3859db26e39..534805c2dad 100644 --- a/compiler/noirc_frontend/src/hir_def/traits.rs +++ b/compiler/noirc_frontend/src/hir_def/traits.rs @@ -72,6 +72,9 @@ pub struct Trait { /// match the definition in the trait, we bind this TypeVariable to whatever /// the correct Self type is for that particular impl block. pub self_type_typevar: TypeVariable, + + /// The resolved trait bounds (for example in `trait Foo: Bar + Baz`, this would be `Bar + Baz`) + pub trait_bounds: Vec, } #[derive(Debug)] @@ -101,15 +104,25 @@ pub struct TraitImpl { #[derive(Debug, Clone, PartialEq, Eq)] pub struct TraitConstraint { pub typ: Type, - pub trait_id: TraitId, - pub trait_generics: TraitGenerics, - pub span: Span, + pub trait_bound: ResolvedTraitBound, } impl TraitConstraint { pub fn apply_bindings(&mut self, type_bindings: &TypeBindings) { self.typ = self.typ.substitute(type_bindings); + self.trait_bound.apply_bindings(type_bindings); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ResolvedTraitBound { + pub trait_id: TraitId, + pub trait_generics: TraitGenerics, + pub span: Span, +} +impl ResolvedTraitBound { + pub fn apply_bindings(&mut self, type_bindings: &TypeBindings) { for typ in &mut self.trait_generics.ordered { *typ = typ.substitute(type_bindings); } @@ -137,6 +150,10 @@ impl Trait { self.methods = methods; } + pub fn set_trait_bounds(&mut self, trait_bounds: Vec) { + self.trait_bounds = trait_bounds; + } + pub fn find_method(&self, name: &str) -> Option { for (idx, method) in self.methods.iter().enumerate() { if &method.name == name { @@ -169,9 +186,11 @@ impl Trait { TraitConstraint { typ: Type::TypeVariable(self.self_type_typevar.clone()), - trait_generics: TraitGenerics { ordered, named }, - trait_id: self.id, - span, + trait_bound: ResolvedTraitBound { + trait_generics: TraitGenerics { ordered, named }, + trait_id: self.id, + span, + }, } } } diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index b80c37c2ce4..ca7e0c6aa59 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -24,6 +24,7 @@ use crate::hir::def_map::DefMaps; use crate::hir::def_map::{LocalModuleId, ModuleDefId, ModuleId}; use crate::hir::type_check::generics::TraitGenerics; use crate::hir_def::traits::NamedType; +use crate::hir_def::traits::ResolvedTraitBound; use crate::usage_tracker::UnusedItem; use crate::usage_tracker::UsageTracker; use crate::QuotedType; @@ -291,6 +292,7 @@ pub enum DependencyId { Global(GlobalId), Function(FuncId), Alias(TypeAliasId), + Trait(TraitId), Variable(Location), } @@ -732,6 +734,7 @@ impl NodeInterner { methods: Vec::new(), method_ids: unresolved_trait.method_ids.clone(), associated_types, + trait_bounds: Vec::new(), }; self.traits.insert(type_id, new_trait); @@ -1531,9 +1534,11 @@ impl NodeInterner { let named = trait_associated_types.to_vec(); TraitConstraint { typ: object_type.clone(), - trait_id, - trait_generics: TraitGenerics { ordered, named }, - span: Span::default(), + trait_bound: ResolvedTraitBound { + trait_id, + trait_generics: TraitGenerics { ordered, named }, + span: Span::default(), + }, } }; @@ -1613,9 +1618,11 @@ impl NodeInterner { let constraint = TraitConstraint { typ: existing_object_type, - trait_id, - trait_generics, - span: Span::default(), + trait_bound: ResolvedTraitBound { + trait_id, + trait_generics, + span: Span::default(), + }, }; matching_impls.push((impl_kind.clone(), fresh_bindings, constraint)); } @@ -1635,8 +1642,8 @@ impl NodeInterner { Err(ImplSearchErrorKind::Nested(errors)) } else { let impls = vecmap(matching_impls, |(_, _, constraint)| { - let name = &self.get_trait(constraint.trait_id).name; - format!("{}: {name}{}", constraint.typ, constraint.trait_generics) + let name = &self.get_trait(constraint.trait_bound.trait_id).name; + format!("{}: {name}{}", constraint.typ, constraint.trait_bound.trait_generics) }); Err(ImplSearchErrorKind::MultipleMatching(impls)) } @@ -1658,20 +1665,22 @@ impl NodeInterner { let constraint_type = constraint.typ.force_substitute(instantiation_bindings).substitute(type_bindings); - let trait_generics = vecmap(&constraint.trait_generics.ordered, |generic| { - generic.force_substitute(instantiation_bindings).substitute(type_bindings) - }); + let trait_generics = + vecmap(&constraint.trait_bound.trait_generics.ordered, |generic| { + generic.force_substitute(instantiation_bindings).substitute(type_bindings) + }); - let trait_associated_types = vecmap(&constraint.trait_generics.named, |generic| { - let typ = generic.typ.force_substitute(instantiation_bindings); - NamedType { name: generic.name.clone(), typ: typ.substitute(type_bindings) } - }); + let trait_associated_types = + vecmap(&constraint.trait_bound.trait_generics.named, |generic| { + let typ = generic.typ.force_substitute(instantiation_bindings); + NamedType { name: generic.name.clone(), typ: typ.substitute(type_bindings) } + }); // We can ignore any associated types on the constraint since those should not affect // which impl we choose. self.lookup_trait_implementation_helper( &constraint_type, - constraint.trait_id, + constraint.trait_bound.trait_id, &trait_generics, &trait_associated_types, // Use a fresh set of type bindings here since the constraint_type originates from @@ -2016,6 +2025,10 @@ impl NodeInterner { self.add_dependency(dependent, DependencyId::Alias(dependency)); } + pub fn add_trait_dependency(&mut self, dependent: DependencyId, dependency: TraitId) { + self.add_dependency(dependent, DependencyId::Trait(dependency)); + } + pub fn add_dependency(&mut self, dependent: DependencyId, dependency: DependencyId) { let dependent_index = self.get_or_insert_dependency(dependent); let dependency_index = self.get_or_insert_dependency(dependency); @@ -2071,6 +2084,11 @@ impl NodeInterner { push_error(alias.name.to_string(), &scc, i, alias.location); break; } + DependencyId::Trait(trait_id) => { + let the_trait = self.get_trait(trait_id); + push_error(the_trait.name.to_string(), &scc, i, the_trait.location); + break; + } // Mutually recursive functions are allowed DependencyId::Function(_) => (), // Local variables should never be in a dependency cycle, scoping rules @@ -2099,6 +2117,7 @@ impl NodeInterner { DependencyId::Global(id) => { Cow::Borrowed(self.get_global(id).ident.0.contents.as_ref()) } + DependencyId::Trait(id) => Cow::Owned(self.get_trait(id).name.to_string()), DependencyId::Variable(loc) => { unreachable!("Variable used at location {loc:?} caught in a dependency cycle") } diff --git a/compiler/noirc_frontend/src/parser/parser/traits.rs b/compiler/noirc_frontend/src/parser/parser/traits.rs index 3bae152e75f..fead6a34c82 100644 --- a/compiler/noirc_frontend/src/parser/parser/traits.rs +++ b/compiler/noirc_frontend/src/parser/parser/traits.rs @@ -11,7 +11,7 @@ use super::parse_many::without_separator; use super::Parser; impl<'a> Parser<'a> { - /// Trait = 'trait' identifier Generics WhereClause TraitBody + /// Trait = 'trait' identifier Generics ( ':' TraitBounds )? WhereClause TraitBody pub(crate) fn parse_trait( &mut self, attributes: Vec<(Attribute, Span)>, @@ -26,12 +26,14 @@ impl<'a> Parser<'a> { }; let generics = self.parse_generics(); + let bounds = if self.eat_colon() { self.parse_trait_bounds() } else { Vec::new() }; let where_clause = self.parse_where_clause(); let items = self.parse_trait_body(); NoirTrait { name, generics, + bounds, where_clause, span: self.span_since(start_span), items, @@ -180,6 +182,7 @@ fn empty_trait( NoirTrait { name: Ident::default(), generics: Vec::new(), + bounds: Vec::new(), where_clause: Vec::new(), span, items: Vec::new(), @@ -292,4 +295,16 @@ mod tests { }; assert!(body.is_some()); } + + #[test] + fn parse_trait_inheirtance() { + let src = "trait Foo: Bar + Baz {}"; + let noir_trait = parse_trait_no_errors(src); + assert_eq!(noir_trait.bounds.len(), 2); + + assert_eq!(noir_trait.bounds[0].to_string(), "Bar"); + assert_eq!(noir_trait.bounds[1].to_string(), "Baz"); + + assert_eq!(noir_trait.to_string(), "trait Foo: Bar + Baz {\n}"); + } } diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 8b54095973c..b2800717d90 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -4,6 +4,7 @@ mod bound_checks; mod imports; mod name_shadowing; mod references; +mod traits; mod turbofish; mod unused_items; mod visibility; diff --git a/compiler/noirc_frontend/src/tests/traits.rs b/compiler/noirc_frontend/src/tests/traits.rs new file mode 100644 index 00000000000..ee84cc0e890 --- /dev/null +++ b/compiler/noirc_frontend/src/tests/traits.rs @@ -0,0 +1,150 @@ +use crate::hir::def_collector::dc_crate::CompilationError; +use crate::hir::resolution::errors::ResolverError; +use crate::tests::get_program_errors; + +use super::assert_no_errors; + +#[test] +fn trait_inheritance() { + let src = r#" + pub trait Foo { + fn foo(self) -> Field; + } + + pub trait Bar { + fn bar(self) -> Field; + } + + pub trait Baz: Foo + Bar { + fn baz(self) -> Field; + } + + pub fn foo(baz: T) -> (Field, Field, Field) where T: Baz { + (baz.foo(), baz.bar(), baz.baz()) + } + + fn main() {} + "#; + assert_no_errors(src); +} + +#[test] +fn trait_inheritance_with_generics() { + let src = r#" + trait Foo { + fn foo(self) -> T; + } + + trait Bar: Foo { + fn bar(self); + } + + pub fn foo(x: T) -> i32 where T: Bar { + x.foo() + } + + fn main() {} + "#; + assert_no_errors(src); +} + +#[test] +fn trait_inheritance_with_generics_2() { + let src = r#" + pub trait Foo { + fn foo(self) -> T; + } + + pub trait Bar: Foo { + fn bar(self) -> (T, U); + } + + pub fn foo(x: T) -> i32 where T: Bar { + x.foo() + } + + fn main() {} + "#; + assert_no_errors(src); +} + +#[test] +fn trait_inheritance_with_generics_3() { + let src = r#" + trait Foo {} + + trait Bar: Foo {} + + impl Foo for () {} + + impl Bar for () {} + + fn main() {} + "#; + assert_no_errors(src); +} + +#[test] +fn trait_inheritance_with_generics_4() { + let src = r#" + trait Foo { type A; } + + trait Bar: Foo {} + + impl Foo for () { type A = i32; } + + impl Bar for () {} + + fn main() {} + "#; + assert_no_errors(src); +} + +#[test] +fn trait_inheritance_dependency_cycle() { + let src = r#" + trait Foo: Bar {} + trait Bar: Foo {} + fn main() {} + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + assert!(matches!( + errors[0].0, + CompilationError::ResolverError(ResolverError::DependencyCycle { .. }) + )); +} + +#[test] +fn trait_inheritance_missing_parent_implementation() { + let src = r#" + pub trait Foo {} + + pub trait Bar: Foo {} + + pub struct Struct {} + + impl Bar for Struct {} + + fn main() { + let _ = Struct {}; // silence Struct never constructed warning + } + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + let CompilationError::ResolverError(ResolverError::TraitNotImplemented { + impl_trait, + missing_trait: the_trait, + type_missing_trait: typ, + .. + }) = &errors[0].0 + else { + panic!("Expected a TraitNotImplemented error, got {:?}", &errors[0].0); + }; + + assert_eq!(the_trait, "Foo"); + assert_eq!(typ, "Struct"); + assert_eq!(impl_trait, "Bar"); +} diff --git a/cspell.json b/cspell.json index dbc5fb5a43e..6fd25a77182 100644 --- a/cspell.json +++ b/cspell.json @@ -194,8 +194,11 @@ "stdlib", "structs", "subexpression", + "subtrait", "subshell", "subtyping", + "supertrait", + "supertraits", "swcurve", "Taiko", "tarjan", diff --git a/docs/docs/noir/concepts/traits.md b/docs/docs/noir/concepts/traits.md index b3235a1a29b..9da00a77587 100644 --- a/docs/docs/noir/concepts/traits.md +++ b/docs/docs/noir/concepts/traits.md @@ -464,6 +464,32 @@ Since we have an impl for our own type, the behavior of this code will not chang to provide its own `impl Default for Foo`. The downside of this pattern is that it requires extra wrapping and unwrapping of values when converting to and from the `Wrapper` and `Foo` types. +### Trait Inheritance + +Sometimes, you might need one trait to use another trait’s functionality (like "inheritance" in some other languages). In this case, you can specify this relationship by listing any child traits after the parent trait's name and a colon. Now, whenever the parent trait is implemented it will require the child traits to be implemented as well. A parent trait is also called a "super trait." + +```rust +trait Person { + fn name(self) -> String; +} + +// Person is a supertrait of Student. +// Implementing Student requires you to also impl Person. +trait Student: Person { + fn university(self) -> String; +} + +trait Programmer { + fn fav_language(self) -> String; +} + +// CompSciStudent (computer science student) is a subtrait of both Programmer +// and Student. Implementing CompSciStudent requires you to impl both supertraits. +trait CompSciStudent: Programmer + Student { + fn git_username(self) -> String; +} +``` + ### Visibility By default, like functions, traits are private to the module they exist in. You can use `pub` diff --git a/test_programs/execution_success/trait_inheritance/Nargo.toml b/test_programs/execution_success/trait_inheritance/Nargo.toml new file mode 100644 index 00000000000..b8390fc800d --- /dev/null +++ b/test_programs/execution_success/trait_inheritance/Nargo.toml @@ -0,0 +1,5 @@ +[package] +name = "trait_inheritance" +type = "bin" +authors = [""] +[dependencies] \ No newline at end of file diff --git a/test_programs/execution_success/trait_inheritance/Prover.toml b/test_programs/execution_success/trait_inheritance/Prover.toml new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test_programs/execution_success/trait_inheritance/src/main.nr b/test_programs/execution_success/trait_inheritance/src/main.nr new file mode 100644 index 00000000000..1d17d386189 --- /dev/null +++ b/test_programs/execution_success/trait_inheritance/src/main.nr @@ -0,0 +1,33 @@ +trait Foo { + fn foo(self) -> Field; +} + +trait Bar: Foo { + fn bar(self) -> Field { + self.foo() + 1 + } + + fn baz(self) -> Field; +} + +struct Struct { + x: Field, +} + +impl Foo for Struct { + fn foo(self) -> Field { + self.x + } +} + +impl Bar for Struct { + fn baz(self) -> Field { + self.foo() + 2 + } +} + +fn main() { + let s = Struct { x: 1 }; + assert_eq(s.bar(), 2); + assert_eq(s.baz(), 3); +} diff --git a/tooling/lsp/src/trait_impl_method_stub_generator.rs b/tooling/lsp/src/trait_impl_method_stub_generator.rs index 14b40858bb1..b433ee2ec88 100644 --- a/tooling/lsp/src/trait_impl_method_stub_generator.rs +++ b/tooling/lsp/src/trait_impl_method_stub_generator.rs @@ -98,9 +98,9 @@ impl<'a> TraitImplMethodStubGenerator<'a> { } self.append_type(&constraint.typ); self.string.push_str(": "); - let trait_ = self.interner.get_trait(constraint.trait_id); + let trait_ = self.interner.get_trait(constraint.trait_bound.trait_id); self.string.push_str(&trait_.name.0.contents); - self.append_trait_generics(&constraint.trait_generics); + self.append_trait_generics(&constraint.trait_bound.trait_generics); } } diff --git a/tooling/nargo_fmt/tests/expected/trait.nr b/tooling/nargo_fmt/tests/expected/trait.nr new file mode 100644 index 00000000000..0467585fac3 --- /dev/null +++ b/tooling/nargo_fmt/tests/expected/trait.nr @@ -0,0 +1,2 @@ +trait Foo: Bar + Baz {} + diff --git a/tooling/nargo_fmt/tests/input/trait.nr b/tooling/nargo_fmt/tests/input/trait.nr new file mode 100644 index 00000000000..0467585fac3 --- /dev/null +++ b/tooling/nargo_fmt/tests/input/trait.nr @@ -0,0 +1,2 @@ +trait Foo: Bar + Baz {} +