Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(experimental elaborator): Fix impl Trait when --use-elaborator is selected #5138

Merged
merged 1 commit into from
May 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 41 additions & 48 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@
self.trait_id = None;
}

fn elaborate_function(&mut self, mut function: NoirFunction, id: FuncId) {
fn elaborate_function(&mut self, function: NoirFunction, id: FuncId) {
self.current_function = Some(id);

// Without this, impl methods can accidentally be placed in contracts. See #3254
Expand Down Expand Up @@ -303,7 +303,6 @@
}

self.generics = func_meta.all_generics.clone();
self.desugar_impl_trait_args(&mut function, id);
self.declare_numeric_generics(&func_meta.parameters, func_meta.return_type());
self.add_trait_constraints_to_scope(&func_meta);

Expand All @@ -330,7 +329,7 @@
// when multiple impls are available. Instead we default first to choose the Field or u64 impl.
for typ in &self.type_variables {
if let Type::TypeVariable(variable, kind) = typ.follow_bindings() {
let msg = "TypeChecker should only track defaultable type vars";

Check warning on line 332 in compiler/noirc_frontend/src/elaborator/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (defaultable)
variable.bind(kind.default_type().expect(msg));
}
}
Expand Down Expand Up @@ -373,49 +372,31 @@
}

/// This turns function parameters of the form:
/// fn foo(x: impl Bar)
/// `fn foo(x: impl Bar)`
///
/// into
/// fn foo<T0_impl_Bar>(x: T0_impl_Bar) where T0_impl_Bar: Bar
fn desugar_impl_trait_args(&mut self, func: &mut NoirFunction, func_id: FuncId) {
let mut impl_trait_generics = HashSet::default();
let mut counter: usize = 0;
for parameter in func.def.parameters.iter_mut() {
if let UnresolvedTypeData::TraitAsType(path, args) = &parameter.typ.typ {
let mut new_generic_ident: Ident =
format!("T{}_impl_{}", func_id, path.as_string()).into();
let mut new_generic_path = Path::from_ident(new_generic_ident.clone());
while impl_trait_generics.contains(&new_generic_ident)
|| self.lookup_generic_or_global_type(&new_generic_path).is_some()
{
new_generic_ident =
format!("T{}_impl_{}_{}", func_id, path.as_string(), counter).into();
new_generic_path = Path::from_ident(new_generic_ident.clone());
counter += 1;
}
impl_trait_generics.insert(new_generic_ident.clone());

let is_synthesized = true;
let new_generic_type_data =
UnresolvedTypeData::Named(new_generic_path, vec![], is_synthesized);
let new_generic_type =
UnresolvedType { typ: new_generic_type_data.clone(), span: None };
let new_trait_bound = TraitBound {
trait_path: path.clone(),
trait_id: None,
trait_generics: args.to_vec(),
};
let new_trait_constraint = UnresolvedTraitConstraint {
typ: new_generic_type,
trait_bound: new_trait_bound,
};

parameter.typ.typ = new_generic_type_data;
func.def.generics.push(new_generic_ident);
func.def.where_clause.push(new_trait_constraint);
}
/// `fn foo<T0_impl_Bar>(x: T0_impl_Bar) where T0_impl_Bar: Bar`
/// although the fresh type variable is not named internally.
fn desugar_impl_trait_arg(
&mut self,
trait_path: Path,
trait_generics: Vec<UnresolvedType>,
generics: &mut Vec<TypeVariable>,
trait_constraints: &mut Vec<TraitConstraint>,
) -> Type {
let new_generic_id = self.interner.next_type_variable_id();
let new_generic = TypeVariable::unbound(new_generic_id);
generics.push(new_generic.clone());

let name = format!("impl {trait_path}");
let generic_type = Type::NamedGeneric(new_generic, Rc::new(name));
let trait_bound = TraitBound { trait_path, trait_id: None, trait_generics };

if let Some(new_constraint) = self.resolve_trait_bound(&trait_bound, generic_type.clone()) {
trait_constraints.push(new_constraint);
}
self.add_generics(&impl_trait_generics.into_iter().collect());

generic_type
}

/// Add the given generics to scope.
Expand Down Expand Up @@ -491,11 +472,14 @@
constraint: &UnresolvedTraitConstraint,
) -> Option<TraitConstraint> {
let typ = self.resolve_type(constraint.typ.clone());
let trait_generics =
vecmap(&constraint.trait_bound.trait_generics, |typ| self.resolve_type(typ.clone()));
self.resolve_trait_bound(&constraint.trait_bound, typ)
}

fn resolve_trait_bound(&mut self, bound: &TraitBound, typ: Type) -> Option<TraitConstraint> {
let trait_generics = vecmap(&bound.trait_generics, |typ| self.resolve_type(typ.clone()));

let span = constraint.trait_bound.trait_path.span();
let the_trait = self.lookup_trait_or_error(constraint.trait_bound.trait_path.clone())?;
let span = bound.trait_path.span();
let the_trait = self.lookup_trait_or_error(bound.trait_path.clone())?;
let trait_id = the_trait.id;

let expected_generics = the_trait.generics.len();
Expand Down Expand Up @@ -561,6 +545,8 @@

self.add_generics(&func.def.generics);

let mut trait_constraints = self.resolve_trait_constraints(&func.def.where_clause);

let mut generics = vecmap(&self.generics, |(_, typevar, _)| typevar.clone());
let mut parameters = Vec::new();
let mut parameter_types = Vec::new();
Expand All @@ -575,7 +561,14 @@
}

let type_span = typ.span.unwrap_or_else(|| pattern.span());
let typ = self.resolve_type_inner(typ, &mut generics);

let typ = match typ.typ {
UnresolvedTypeData::TraitAsType(path, args) => {
self.desugar_impl_trait_arg(path, args, &mut generics, &mut trait_constraints)
}
_ => self.resolve_type_inner(typ, &mut generics),
};

self.check_if_type_is_valid_for_program_input(
&typ,
is_entry_point,
Expand Down Expand Up @@ -660,7 +653,7 @@
return_type: func.def.return_type.clone(),
return_visibility: func.def.return_visibility,
has_body: !func.def.body.is_empty(),
trait_constraints: self.resolve_trait_constraints(&func.def.where_clause),
trait_constraints,
is_entry_point,
is_trait_function,
has_inline_attribute,
Expand Down
Loading