Skip to content

Commit

Permalink
feat: Allow comptime attributes on traits & functions (#5496)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #5477

## Summary\*

Adds:
- The ability to run `comptime` attribute functions on traits &
functions in the program
- The `TraitDefinition` type
- The `FunctionDefinition` type
- The `Module` type - the only one of the new types which you still
can't run attributes on. See:
#5495. Running these on modules
is a bit more difficult since modules don't have an entry in
`CollectedItems` to run them on. So I'm delaying this for a later PR.

## Additional Context



## Documentation\*

Check one:
- [ ] No documentation needed.
- [ ] Documentation included in this PR.
- [x] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
jfecher authored Jul 12, 2024
1 parent 141ecdd commit b59a29e
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 25 deletions.
2 changes: 2 additions & 0 deletions compiler/noirc_frontend/src/ast/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::ast::{
BlockExpression, Expression, FunctionReturnType, Ident, NoirFunction, Path, UnresolvedGenerics,
UnresolvedType,
};
use crate::macros_api::SecondaryAttribute;
use crate::node_interner::TraitId;

/// AST node for trait definitions:
Expand All @@ -18,6 +19,7 @@ pub struct NoirTrait {
pub where_clause: Vec<UnresolvedTraitConstraint>,
pub span: Span,
pub items: Vec<TraitItem>,
pub attributes: Vec<SecondaryAttribute>,
}

/// Any declaration inside the body of a trait that a user is required to
Expand Down
50 changes: 37 additions & 13 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,11 @@ impl<'context> Elaborator<'context> {
}

// Must resolve structs before we resolve globals.
let generated_items = self.collect_struct_definitions(items.types);
let mut generated_items = self.collect_struct_definitions(items.types);

self.define_function_metas(&mut items.functions, &mut items.impls, &mut items.trait_impls);

self.collect_traits(items.traits);
self.collect_traits(items.traits, &mut generated_items);

// Before we resolve any function symbols we must go through our impls and
// re-collect the methods within into their proper module. This cannot be
Expand All @@ -278,6 +278,10 @@ impl<'context> Elaborator<'context> {
self.elaborate_global(global);
}

// We have to run any comptime attributes on functions before the function is elaborated
// since the generated items are checked beforehand as well.
self.run_attributes_on_functions(&items.functions, &mut generated_items);

// After everything is collected, we can elaborate our generated items.
// It may be better to inline these within `items` entirely since elaborating them
// all here means any globals will not see these. Inlining them completely within `items`
Expand Down Expand Up @@ -1242,7 +1246,8 @@ impl<'context> Elaborator<'context> {
.add_definition_location(ReferenceId::StructMember(type_id, field_index), None);
}

self.run_comptime_attributes_on_struct(attributes, type_id, span, &mut generated_items);
let item = Value::StructDefinition(type_id);
self.run_comptime_attributes_on_item(&attributes, item, span, &mut generated_items);
}

// Check whether the struct fields have nested slices
Expand All @@ -1268,34 +1273,34 @@ impl<'context> Elaborator<'context> {
generated_items
}

fn run_comptime_attributes_on_struct(
fn run_comptime_attributes_on_item(
&mut self,
attributes: Vec<SecondaryAttribute>,
struct_id: StructId,
attributes: &[SecondaryAttribute],
item: Value,
span: Span,
generated_items: &mut CollectedItems,
) {
for attribute in attributes {
if let SecondaryAttribute::Custom(name) = attribute {
if let Err(error) =
self.run_comptime_attribute_on_struct(name, struct_id, span, generated_items)
self.run_comptime_attribute_on_item(name, item.clone(), span, generated_items)
{
self.errors.push(error);
}
}
}
}

fn run_comptime_attribute_on_struct(
fn run_comptime_attribute_on_item(
&mut self,
attribute: String,
struct_id: StructId,
attribute: &str,
item: Value,
span: Span,
generated_items: &mut CollectedItems,
) -> Result<(), (CompilationError, FileId)> {
let location = Location::new(span, self.file);
let (function_name, mut arguments) =
Self::parse_attribute(&attribute, location).unwrap_or((attribute, Vec::new()));
let (function_name, mut arguments) = Self::parse_attribute(attribute, location)
.unwrap_or_else(|| (attribute.to_string(), Vec::new()));

let id = self
.lookup_global(Path::from_single(function_name, span))
Expand All @@ -1307,7 +1312,7 @@ impl<'context> Elaborator<'context> {
};

self.handle_varargs_attribute(function, &mut arguments, location);
arguments.insert(0, (Value::StructDefinition(struct_id), location));
arguments.insert(0, (item, location));

let mut interpreter_errors = vec![];
let mut interpreter = self.setup_interpreter(&mut interpreter_errors);
Expand Down Expand Up @@ -1741,4 +1746,23 @@ impl<'context> Elaborator<'context> {
));
}
}

fn run_attributes_on_functions(
&mut self,
function_sets: &[UnresolvedFunctions],
generated_items: &mut CollectedItems,
) {
for function_set in function_sets {
self.file = function_set.file_id;
self.self_type = function_set.self_type.clone();

for (local_module, function_id, function) in &function_set.functions {
self.local_module = *local_module;
let attributes = function.secondary_attributes();
let item = Value::FunctionDefinition(*function_id);
let span = function.span();
self.run_comptime_attributes_on_item(attributes, item, span, generated_items);
}
}
}
}
13 changes: 11 additions & 2 deletions compiler/noirc_frontend/src/elaborator/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
ast::{
FunctionKind, TraitItem, UnresolvedGeneric, UnresolvedGenerics, UnresolvedTraitConstraint,
},
hir::def_collector::dc_crate::UnresolvedTrait,
hir::def_collector::dc_crate::{CollectedItems, UnresolvedTrait},
hir_def::traits::{TraitConstant, TraitFunction, TraitType},
macros_api::{
BlockExpression, FunctionDefinition, FunctionReturnType, Ident, ItemVisibility,
Expand All @@ -21,7 +21,11 @@ use crate::{
use super::Elaborator;

impl<'context> Elaborator<'context> {
pub fn collect_traits(&mut self, traits: BTreeMap<TraitId, UnresolvedTrait>) {
pub fn collect_traits(
&mut self,
traits: BTreeMap<TraitId, UnresolvedTrait>,
generated_items: &mut CollectedItems,
) {
for (trait_id, unresolved_trait) in traits {
self.recover_generics(|this| {
let resolved_generics = this.interner.get_trait(trait_id).generics.clone();
Expand All @@ -41,6 +45,11 @@ impl<'context> Elaborator<'context> {
this.interner.update_trait(trait_id, |trait_def| {
trait_def.set_methods(methods);
});

let attributes = &unresolved_trait.trait_def.attributes;
let item = crate::hir::comptime::Value::TraitDefinition(trait_id);
let span = unresolved_trait.trait_def.span;
this.run_comptime_attributes_on_item(attributes, item, span, generated_items);
});

// This check needs to be after the trait's methods are set since
Expand Down
24 changes: 21 additions & 3 deletions compiler/noirc_frontend/src/hir/comptime/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ use noirc_errors::Location;

use crate::{
ast::{ArrayLiteral, ConstructorExpression, Ident, IntegerBitSize, Signedness},
hir::def_map::ModuleId,
hir_def::expr::{HirArrayLiteral, HirConstructorExpression, HirIdent, HirLambda, ImplKind},
macros_api::{
Expression, ExpressionKind, HirExpression, HirLiteral, Literal, NodeInterner, Path,
StructId,
},
node_interner::{ExprId, FuncId},
node_interner::{ExprId, FuncId, TraitId},
parser::{self, NoirParser, TopLevelStatement},
token::{SpannedToken, Token, Tokens},
QuotedType, Shared, Type, TypeBindings,
Expand Down Expand Up @@ -45,6 +46,9 @@ pub enum Value {
Slice(Vector<Value>, Type),
Code(Rc<Tokens>),
StructDefinition(StructId),
TraitDefinition(TraitId),
FunctionDefinition(FuncId),
ModuleDefinition(ModuleId),
}

impl Value {
Expand Down Expand Up @@ -79,6 +83,9 @@ impl Value {
let element = element.borrow().get_type().into_owned();
Type::MutableReference(Box::new(element))
}
Value::TraitDefinition(_) => Type::Quoted(QuotedType::TraitDefinition),
Value::FunctionDefinition(_) => Type::Quoted(QuotedType::FunctionDefinition),
Value::ModuleDefinition(_) => Type::Quoted(QuotedType::Module),
})
}

Expand Down Expand Up @@ -192,7 +199,11 @@ impl Value {
}
};
}
Value::Pointer(_) | Value::StructDefinition(_) => {
Value::Pointer(_)
| Value::StructDefinition(_)
| Value::TraitDefinition(_)
| Value::FunctionDefinition(_)
| Value::ModuleDefinition(_) => {
return Err(InterpreterError::CannotInlineMacro { value: self, location })
}
};
Expand Down Expand Up @@ -298,7 +309,11 @@ impl Value {
HirExpression::Literal(HirLiteral::Slice(HirArrayLiteral::Standard(elements)))
}
Value::Code(block) => HirExpression::Unquote(unwrap_rc(block)),
Value::Pointer(_) | Value::StructDefinition(_) => {
Value::Pointer(_)
| Value::StructDefinition(_)
| Value::TraitDefinition(_)
| Value::FunctionDefinition(_)
| Value::ModuleDefinition(_) => {
return Err(InterpreterError::CannotInlineMacro { value: self, location })
}
};
Expand Down Expand Up @@ -402,6 +417,9 @@ impl Display for Value {
write!(f, " }}")
}
Value::StructDefinition(_) => write!(f, "(struct definition)"),
Value::TraitDefinition(_) => write!(f, "(trait definition)"),
Value::FunctionDefinition(_) => write!(f, "(function definition)"),
Value::ModuleDefinition(_) => write!(f, "(module)"),
}
}
}
6 changes: 6 additions & 0 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ pub enum QuotedType {
TopLevelItem,
Type,
StructDefinition,
TraitDefinition,
FunctionDefinition,
Module,
}

/// A list of TypeVariableIds to bind to a type. Storing the
Expand Down Expand Up @@ -677,6 +680,9 @@ impl std::fmt::Display for QuotedType {
QuotedType::TopLevelItem => write!(f, "TopLevelItem"),
QuotedType::Type => write!(f, "Type"),
QuotedType::StructDefinition => write!(f, "StructDefinition"),
QuotedType::TraitDefinition => write!(f, "TraitDefinition"),
QuotedType::FunctionDefinition => write!(f, "FunctionDefinition"),
QuotedType::Module => write!(f, "Module"),
}
}
}
Expand Down
13 changes: 11 additions & 2 deletions compiler/noirc_frontend/src/lexer/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -895,25 +895,28 @@ pub enum Keyword {
Fn,
For,
FormatString,
FunctionDefinition,
Global,
If,
Impl,
In,
Let,
Mod,
Module,
Mut,
Pub,
Quoted,
Return,
ReturnData,
String,
Struct,
StructDefinition,
Super,
TopLevelItem,
Trait,
TraitDefinition,
Type,
TypeType,
StructDefinition,
Unchecked,
Unconstrained,
Use,
Expand Down Expand Up @@ -943,25 +946,28 @@ impl fmt::Display for Keyword {
Keyword::Fn => write!(f, "fn"),
Keyword::For => write!(f, "for"),
Keyword::FormatString => write!(f, "fmtstr"),
Keyword::FunctionDefinition => write!(f, "FunctionDefinition"),
Keyword::Global => write!(f, "global"),
Keyword::If => write!(f, "if"),
Keyword::Impl => write!(f, "impl"),
Keyword::In => write!(f, "in"),
Keyword::Let => write!(f, "let"),
Keyword::Mod => write!(f, "mod"),
Keyword::Module => write!(f, "Module"),
Keyword::Mut => write!(f, "mut"),
Keyword::Pub => write!(f, "pub"),
Keyword::Quoted => write!(f, "Quoted"),
Keyword::Return => write!(f, "return"),
Keyword::ReturnData => write!(f, "return_data"),
Keyword::String => write!(f, "str"),
Keyword::Struct => write!(f, "struct"),
Keyword::StructDefinition => write!(f, "StructDefinition"),
Keyword::Super => write!(f, "super"),
Keyword::TopLevelItem => write!(f, "TopLevelItem"),
Keyword::Trait => write!(f, "trait"),
Keyword::TraitDefinition => write!(f, "TraitDefinition"),
Keyword::Type => write!(f, "type"),
Keyword::TypeType => write!(f, "Type"),
Keyword::StructDefinition => write!(f, "StructDefinition"),
Keyword::Unchecked => write!(f, "unchecked"),
Keyword::Unconstrained => write!(f, "unconstrained"),
Keyword::Use => write!(f, "use"),
Expand Down Expand Up @@ -994,12 +1000,14 @@ impl Keyword {
"fn" => Keyword::Fn,
"for" => Keyword::For,
"fmtstr" => Keyword::FormatString,
"FunctionDefinition" => Keyword::FunctionDefinition,
"global" => Keyword::Global,
"if" => Keyword::If,
"impl" => Keyword::Impl,
"in" => Keyword::In,
"let" => Keyword::Let,
"mod" => Keyword::Mod,
"Module" => Keyword::Module,
"mut" => Keyword::Mut,
"pub" => Keyword::Pub,
"Quoted" => Keyword::Quoted,
Expand All @@ -1010,6 +1018,7 @@ impl Keyword {
"super" => Keyword::Super,
"TopLevelItem" => Keyword::TopLevelItem,
"trait" => Keyword::Trait,
"TraitDefinition" => Keyword::TraitDefinition,
"type" => Keyword::Type,
"Type" => Keyword::TypeType,
"StructDefinition" => Keyword::StructDefinition,
Expand Down
18 changes: 14 additions & 4 deletions compiler/noirc_frontend/src/parser/parser/traits.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use chumsky::prelude::*;

use super::attributes::{attributes, validate_secondary_attributes};
use super::function::function_return_type;
use super::{block, expression, fresh_statement, function, function_declaration_parameters};

Expand All @@ -18,15 +19,24 @@ use crate::{
use super::{generic_type_args, parse_type, path, primitives::ident};

pub(super) fn trait_definition() -> impl NoirParser<TopLevelStatement> {
keyword(Keyword::Trait)
.ignore_then(ident())
attributes()
.then_ignore(keyword(Keyword::Trait))
.then(ident())
.then(function::generics())
.then(where_clause())
.then_ignore(just(Token::LeftBrace))
.then(trait_body())
.then_ignore(just(Token::RightBrace))
.map_with_span(|(((name, generics), where_clause), items), span| {
TopLevelStatement::Trait(NoirTrait { name, generics, where_clause, span, items })
.validate(|((((attributes, name), generics), where_clause), items), span, emit| {
let attributes = validate_secondary_attributes(attributes, span, emit);
TopLevelStatement::Trait(NoirTrait {
name,
generics,
where_clause,
span,
items,
attributes,
})
})
}

Expand Down
Loading

0 comments on commit b59a29e

Please sign in to comment.