Skip to content

Commit

Permalink
feat: Allow a trait to be implemented multiple times for the same str…
Browse files Browse the repository at this point in the history
…uct (#3292)
  • Loading branch information
jfecher authored and guipublic committed Oct 30, 2023
1 parent 2d3d6f6 commit 593f3c1
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 71 deletions.
92 changes: 39 additions & 53 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::dc_mod::collect_defs;
use super::errors::{DefCollectorErrorKind, DuplicateType};
use crate::graph::CrateId;
use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleDefId, ModuleId};
use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleData, ModuleDefId, ModuleId};
use crate::hir::resolution::errors::ResolverError;
use crate::hir::resolution::import::PathResolutionError;
use crate::hir::resolution::path_resolver::PathResolver;
Expand Down Expand Up @@ -126,7 +126,7 @@ pub struct DefCollector {
pub enum CompilationError {
ParseError(ParserError),
DefinitionError(DefCollectorErrorKind),
ResolveError(ResolverError),
ResolverError(ResolverError),
TypeError(TypeCheckError),
}

Expand All @@ -135,7 +135,7 @@ impl From<CompilationError> for CustomDiagnostic {
match value {
CompilationError::ParseError(error) => error.into(),
CompilationError::DefinitionError(error) => error.into(),
CompilationError::ResolveError(error) => error.into(),
CompilationError::ResolverError(error) => error.into(),
CompilationError::TypeError(error) => error.into(),
}
}
Expand All @@ -155,7 +155,7 @@ impl From<DefCollectorErrorKind> for CompilationError {

impl From<ResolverError> for CompilationError {
fn from(value: ResolverError) -> Self {
CompilationError::ResolveError(value)
CompilationError::ResolverError(value)
}
}
impl From<TypeCheckError> for CompilationError {
Expand Down Expand Up @@ -296,12 +296,6 @@ impl DefCollector {
// globals will need to reference the struct type they're initialized to to ensure they are valid.
resolved_globals.extend(resolve_globals(context, other_globals, crate_id));

// Before we resolve any function symbols we must go through our impls and
// re-collect the methods within into their proper module. This cannot be
// done before resolution since we need to be able to resolve the type of the
// impl since that determines the module we should collect into.
errors.extend(collect_impls(context, crate_id, &def_collector.collected_impls));

// Bind trait impls to their trait. Collect trait functions, that have a
// default implementation, which hasn't been overridden.
errors.extend(collect_trait_impls(
Expand All @@ -310,6 +304,15 @@ impl DefCollector {
&mut def_collector.collected_traits_impls,
));

// Before we resolve any function symbols we must go through our impls and
// re-collect the methods within into their proper module. This cannot be
// done before resolution since we need to be able to resolve the type of the
// impl since that determines the module we should collect into.
//
// These are resolved after trait impls so that struct methods are chosen
// over trait methods if there are name conflicts.
errors.extend(collect_impls(context, crate_id, &def_collector.collected_impls));

// Lower each function in the crate. This is now possible since imports have been resolved
let file_func_ids = resolve_free_functions(
&mut context.def_interner,
Expand Down Expand Up @@ -377,7 +380,6 @@ fn collect_impls(

if let Some(struct_type) = get_struct_type(&typ) {
let struct_type = struct_type.borrow();
let type_module = struct_type.id.local_module_id();

// `impl`s are only allowed on types defined within the current crate
if struct_type.id.krate() != crate_id {
Expand All @@ -391,7 +393,7 @@ fn collect_impls(
// Grab the module defined by the struct type. Note that impls are a case
// where the module the methods are added to is not the same as the module
// they are resolved in.
let module = &mut def_maps.get_mut(&crate_id).unwrap().modules[type_module.0];
let module = get_module_mut(def_maps, struct_type.id.module_id());

for (_, method_id, method) in &unresolved.functions {
// If this method was already declared, remove it from the module so it cannot
Expand All @@ -413,6 +415,13 @@ fn collect_impls(
errors
}

fn get_module_mut(
def_maps: &mut BTreeMap<CrateId, CrateDefMap>,
module: ModuleId,
) -> &mut ModuleData {
&mut def_maps.get_mut(&module.krate).unwrap().modules[module.local_id.0]
}

fn collect_trait_impl_methods(
interner: &mut NodeInterner,
def_maps: &BTreeMap<CrateId, CrateDefMap>,
Expand Down Expand Up @@ -494,25 +503,6 @@ fn collect_trait_impl_methods(
errors
}

fn add_method_to_struct_namespace(
current_def_map: &mut CrateDefMap,
struct_type: &Shared<StructType>,
func_id: FuncId,
name_ident: &Ident,
trait_id: TraitId,
) -> Result<(), DefCollectorErrorKind> {
let struct_type = struct_type.borrow();
let type_module = struct_type.id.local_module_id();
let module = &mut current_def_map.modules[type_module.0];
module.declare_trait_function(name_ident.clone(), func_id, trait_id).map_err(
|(first_def, second_def)| DefCollectorErrorKind::Duplicate {
typ: DuplicateType::TraitImplementation,
first_def,
second_def,
},
)
}

fn collect_trait_impl(
context: &mut Context,
crate_id: CrateId,
Expand All @@ -535,28 +525,24 @@ fn collect_trait_impl(
if let Some(trait_id) = trait_impl.trait_id {
errors
.extend(collect_trait_impl_methods(interner, def_maps, crate_id, trait_id, trait_impl));
for (_, func_id, ast) in &trait_impl.methods.functions {
let file = def_maps[&crate_id].file_id(trait_impl.module_id);

let path_resolver = StandardPathResolver::new(module);
let mut resolver = Resolver::new(interner, &path_resolver, def_maps, file);
resolver.add_generics(&ast.def.generics);
let typ = resolver.resolve_type(unresolved_type.clone());

if let Some(struct_type) = get_struct_type(&typ) {
errors.extend(take_errors(trait_impl.file_id, resolver));
let current_def_map = def_maps.get_mut(&struct_type.borrow().id.krate()).unwrap();
match add_method_to_struct_namespace(
current_def_map,
struct_type,
*func_id,
ast.name_ident(),
trait_id,
) {
Ok(()) => {}
Err(err) => {
errors.push((err.into(), trait_impl.file_id));
}
let path_resolver = StandardPathResolver::new(module);
let file = def_maps[&crate_id].file_id(trait_impl.module_id);
let mut resolver = Resolver::new(interner, &path_resolver, def_maps, file);
let typ = resolver.resolve_type(unresolved_type);
errors.extend(take_errors(trait_impl.file_id, resolver));

if let Some(struct_type) = get_struct_type(&typ) {
let struct_type = struct_type.borrow();
let module = get_module_mut(def_maps, struct_type.id.module_id());

for (_, method_id, method) in &trait_impl.methods.functions {
// If this method was already declared, remove it from the module so it cannot
// be accessed with the `TypeName::method` syntax. We'll check later whether the
// object types in each method overlap or not. If they do, we issue an error.
// If not, that is specialization which is allowed.
if module.declare_function(method.name_ident().clone(), *method_id).is_err() {
module.remove_function(method.name_ident());
}
}
}
Expand Down Expand Up @@ -841,7 +827,7 @@ fn take_errors_filter_self_not_resolved(
}

fn take_errors(file_id: FileId, resolver: Resolver<'_>) -> Vec<(CompilationError, FileId)> {
resolver.take_errors().iter().cloned().map(|e| (e.into(), file_id)).collect()
vecmap(resolver.take_errors(), |e| (e.into(), file_id))
}

/// Create the mappings from TypeId -> TraitType
Expand Down
15 changes: 13 additions & 2 deletions compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use noirc_errors::Location;
use crate::{
graph::CrateId,
hir::def_collector::dc_crate::{UnresolvedStruct, UnresolvedTrait},
node_interner::{TraitId, TypeAliasId},
node_interner::{FunctionModifiers, TraitId, TypeAliasId},
parser::{SortedModule, SortedSubModule},
FunctionDefinition, Ident, LetStatement, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl,
NoirTypeAlias, TraitImplItem, TraitItem, TypeImpl,
Expand Down Expand Up @@ -378,11 +378,22 @@ impl<'a> ModCollector<'a> {
body,
} => {
let func_id = context.def_interner.push_empty_fn();
let modifiers = FunctionModifiers {
name: name.to_string(),
visibility: crate::FunctionVisibility::Public,
// TODO(Maddiaa): Investigate trait implementations with attributes see: https://github.com/noir-lang/noir/issues/2629
attributes: crate::token::Attributes::empty(),
is_unconstrained: false,
contract_function_type: None,
is_internal: None,
};

context.def_interner.push_function_definition(func_id, modifiers, id.0);

match self.def_collector.def_map.modules[id.0.local_id.0]
.declare_function(name.clone(), func_id)
{
Ok(()) => {
// TODO(Maddiaa): Investigate trait implementations with attributes see: https://github.com/noir-lang/noir/issues/2629
if let Some(body) = body {
let impl_method =
NoirFunction::normal(FunctionDefinition::normal(
Expand Down
10 changes: 5 additions & 5 deletions compiler/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -618,9 +618,10 @@ impl NodeInterner {
#[cfg(test)]
pub fn push_test_function_definition(&mut self, name: String) -> FuncId {
let id = self.push_fn(HirFunction::empty());
let modifiers = FunctionModifiers::new();
let mut modifiers = FunctionModifiers::new();
modifiers.name = name;
let module = ModuleId::dummy_id();
self.push_function_definition(name, id, modifiers, module);
self.push_function_definition(id, modifiers, module);
id
}

Expand All @@ -631,7 +632,6 @@ impl NodeInterner {
module: ModuleId,
) -> DefinitionId {
use ContractFunctionType::*;
let name = function.name.0.contents.clone();

// We're filling in contract_function_type and is_internal now, but these will be verified
// later during name resolution.
Expand All @@ -643,16 +643,16 @@ impl NodeInterner {
contract_function_type: Some(if function.is_open { Open } else { Secret }),
is_internal: Some(function.is_internal),
};
self.push_function_definition(name, id, modifiers, module)
self.push_function_definition(id, modifiers, module)
}

pub fn push_function_definition(
&mut self,
name: String,
func: FuncId,
modifiers: FunctionModifiers,
module: ModuleId,
) -> DefinitionId {
let name = modifiers.name.clone();
self.function_modifiers.insert(func, modifiers);
self.function_modules.insert(func, module);
self.push_definition(name, false, DefinitionKind::Function(func))
Expand Down
28 changes: 17 additions & 11 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ mod test {

for (err, _file_id) in errors {
match &err {
CompilationError::ResolveError(ResolverError::PathResolutionError(
CompilationError::ResolverError(ResolverError::PathResolutionError(
PathResolutionError::Unresolved(ident),
)) => {
assert_eq!(ident, "NotAType");
Expand Down Expand Up @@ -533,19 +533,24 @@ mod test {
}
}
fn main() {
}
fn main() {}
";
let errors = get_program_errors(src);
assert!(!has_parser_error(&errors));
assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors);
assert!(errors.len() == 2, "Expected 2 errors, got: {:?}", errors);
for (err, _file_id) in errors {
match &err {
CompilationError::DefinitionError(
DefCollectorErrorKind::TraitImplNotAllowedFor { trait_path, span: _ },
) => {
assert_eq!(trait_path.as_string(), "Default");
}
CompilationError::ResolverError(ResolverError::Expected {
expected, got, ..
}) => {
assert_eq!(expected, "type");
assert_eq!(got, "function");
}
_ => {
panic!("No other errors are expected! Found = {:?}", err);
}
Expand Down Expand Up @@ -810,7 +815,7 @@ mod test {
assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors);
// It should be regarding the unused variable
match &errors[0].0 {
CompilationError::ResolveError(ResolverError::UnusedVariable { ident }) => {
CompilationError::ResolverError(ResolverError::UnusedVariable { ident }) => {
assert_eq!(&ident.0.contents, "y");
}
_ => unreachable!("we should only have an unused var error"),
Expand All @@ -829,7 +834,7 @@ mod test {
assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors);
// It should be regarding the unresolved var `z` (Maybe change to undeclared and special case)
match &errors[0].0 {
CompilationError::ResolveError(ResolverError::VariableNotDeclared {
CompilationError::ResolverError(ResolverError::VariableNotDeclared {
name,
span: _,
}) => assert_eq!(name, "z"),
Expand All @@ -848,7 +853,7 @@ mod test {
assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors);
for (compilation_error, _file_id) in errors {
match compilation_error {
CompilationError::ResolveError(err) => {
CompilationError::ResolverError(err) => {
match err {
ResolverError::PathResolutionError(PathResolutionError::Unresolved(
name,
Expand Down Expand Up @@ -892,7 +897,7 @@ mod test {
// `foo::bar` does not exist
for (compilation_error, _file_id) in errors {
match compilation_error {
CompilationError::ResolveError(err) => {
CompilationError::ResolverError(err) => {
match err {
ResolverError::UnusedVariable { ident } => {
assert_eq!(&ident.0.contents, "z");
Expand Down Expand Up @@ -1069,12 +1074,13 @@ mod test {

for (err, _file_id) in errors {
match &err {
CompilationError::ResolveError(ResolverError::VariableNotDeclared {
name, ..
CompilationError::ResolverError(ResolverError::VariableNotDeclared {
name,
..
}) => {
assert_eq!(name, "i");
}
CompilationError::ResolveError(ResolverError::NumericConstantInFormatString {
CompilationError::ResolverError(ResolverError::NumericConstantInFormatString {
name,
..
}) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "trait_generics"
type = "bin"
authors = [""]
compiler_version = "0.10.5"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

struct Empty<T> {}

trait Foo {
fn foo(self) -> u32;
}

impl Foo for Empty<u32> {
fn foo(_self: Self) -> u32 { 32 }
}

impl Foo for Empty<u64> {
fn foo(_self: Self) -> u32 { 64 }
}

fn main() {
let x: Empty<u32> = Empty {};
let y: Empty<u64> = Empty {};
let z = Empty {};

assert(x.foo() == 32);
assert(y.foo() == 64);

// Types matching multiple impls will currently choose
// the first matching one instead of erroring
assert(z.foo() == 32);
}

0 comments on commit 593f3c1

Please sign in to comment.