diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 1d66b4b952a..a0ff3c12ae3 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -273,7 +273,7 @@ impl<'context> Elaborator<'context> { self.trait_id = None; } - fn elaborate_function(&mut self, mut function: NoirFunction, id: FuncId) { + fn elaborate_function(&mut self, function: NoirFunction, id: FuncId) { self.current_function = Some(id); // Without this, impl methods can accidentally be placed in contracts. See #3254 @@ -303,7 +303,6 @@ impl<'context> Elaborator<'context> { } self.generics = func_meta.all_generics.clone(); - self.desugar_impl_trait_args(&mut function, id); self.declare_numeric_generics(&func_meta.parameters, func_meta.return_type()); self.add_trait_constraints_to_scope(&func_meta); @@ -373,49 +372,31 @@ impl<'context> Elaborator<'context> { } /// This turns function parameters of the form: - /// fn foo(x: impl Bar) + /// `fn foo(x: impl Bar)` /// /// into - /// fn foo(x: T0_impl_Bar) where T0_impl_Bar: Bar - fn desugar_impl_trait_args(&mut self, func: &mut NoirFunction, func_id: FuncId) { - let mut impl_trait_generics = HashSet::default(); - let mut counter: usize = 0; - for parameter in func.def.parameters.iter_mut() { - if let UnresolvedTypeData::TraitAsType(path, args) = ¶meter.typ.typ { - let mut new_generic_ident: Ident = - format!("T{}_impl_{}", func_id, path.as_string()).into(); - let mut new_generic_path = Path::from_ident(new_generic_ident.clone()); - while impl_trait_generics.contains(&new_generic_ident) - || self.lookup_generic_or_global_type(&new_generic_path).is_some() - { - new_generic_ident = - format!("T{}_impl_{}_{}", func_id, path.as_string(), counter).into(); - new_generic_path = Path::from_ident(new_generic_ident.clone()); - counter += 1; - } - impl_trait_generics.insert(new_generic_ident.clone()); - - let is_synthesized = true; - let new_generic_type_data = - UnresolvedTypeData::Named(new_generic_path, vec![], is_synthesized); - let new_generic_type = - UnresolvedType { typ: new_generic_type_data.clone(), span: None }; - let new_trait_bound = TraitBound { - trait_path: path.clone(), - trait_id: None, - trait_generics: args.to_vec(), - }; - let new_trait_constraint = UnresolvedTraitConstraint { - typ: new_generic_type, - trait_bound: new_trait_bound, - }; - - parameter.typ.typ = new_generic_type_data; - func.def.generics.push(new_generic_ident); - func.def.where_clause.push(new_trait_constraint); - } + /// `fn foo(x: T0_impl_Bar) where T0_impl_Bar: Bar` + /// although the fresh type variable is not named internally. + fn desugar_impl_trait_arg( + &mut self, + trait_path: Path, + trait_generics: Vec, + generics: &mut Vec, + trait_constraints: &mut Vec, + ) -> Type { + let new_generic_id = self.interner.next_type_variable_id(); + let new_generic = TypeVariable::unbound(new_generic_id); + generics.push(new_generic.clone()); + + let name = format!("impl {trait_path}"); + let generic_type = Type::NamedGeneric(new_generic, Rc::new(name)); + let trait_bound = TraitBound { trait_path, trait_id: None, trait_generics }; + + if let Some(new_constraint) = self.resolve_trait_bound(&trait_bound, generic_type.clone()) { + trait_constraints.push(new_constraint); } - self.add_generics(&impl_trait_generics.into_iter().collect()); + + generic_type } /// Add the given generics to scope. @@ -491,11 +472,14 @@ impl<'context> Elaborator<'context> { constraint: &UnresolvedTraitConstraint, ) -> Option { let typ = self.resolve_type(constraint.typ.clone()); - let trait_generics = - vecmap(&constraint.trait_bound.trait_generics, |typ| self.resolve_type(typ.clone())); + self.resolve_trait_bound(&constraint.trait_bound, typ) + } + + fn resolve_trait_bound(&mut self, bound: &TraitBound, typ: Type) -> Option { + let trait_generics = vecmap(&bound.trait_generics, |typ| self.resolve_type(typ.clone())); - let span = constraint.trait_bound.trait_path.span(); - let the_trait = self.lookup_trait_or_error(constraint.trait_bound.trait_path.clone())?; + let span = bound.trait_path.span(); + let the_trait = self.lookup_trait_or_error(bound.trait_path.clone())?; let trait_id = the_trait.id; let expected_generics = the_trait.generics.len(); @@ -561,6 +545,8 @@ impl<'context> Elaborator<'context> { self.add_generics(&func.def.generics); + let mut trait_constraints = self.resolve_trait_constraints(&func.def.where_clause); + let mut generics = vecmap(&self.generics, |(_, typevar, _)| typevar.clone()); let mut parameters = Vec::new(); let mut parameter_types = Vec::new(); @@ -575,7 +561,14 @@ impl<'context> Elaborator<'context> { } let type_span = typ.span.unwrap_or_else(|| pattern.span()); - let typ = self.resolve_type_inner(typ, &mut generics); + + let typ = match typ.typ { + UnresolvedTypeData::TraitAsType(path, args) => { + self.desugar_impl_trait_arg(path, args, &mut generics, &mut trait_constraints) + } + _ => self.resolve_type_inner(typ, &mut generics), + }; + self.check_if_type_is_valid_for_program_input( &typ, is_entry_point, @@ -660,7 +653,7 @@ impl<'context> Elaborator<'context> { return_type: func.def.return_type.clone(), return_visibility: func.def.return_visibility, has_body: !func.def.body.is_empty(), - trait_constraints: self.resolve_trait_constraints(&func.def.where_clause), + trait_constraints, is_entry_point, is_trait_function, has_inline_attribute,