From e5773e47c212c7c8fa1a7d7456893b508cdb400c Mon Sep 17 00:00:00 2001 From: jfecher Date: Wed, 5 Jul 2023 15:38:56 +0200 Subject: [PATCH] feat(ssa refactor): Implement first-class references (#1849) * Explore work on references * Cleanup * Implement first-class references * Fix frontend test * Remove 'Mutability' struct, it is no longer needed * Remove some extra lines * Remove another function * Revert another line * Fix test again * Fix a bug in mem2reg for nested references * Fix inconsistent .eval during ssa-gen on assign statements * Revert some code * Add check for mutating immutable self objects --- .../references/Nargo.toml | 5 + .../references/Prover.toml | 0 .../references/src/main.nr | 56 +++++ .../test_data_ssa_refactor/tuples/src/main.nr | 2 +- crates/noirc_evaluator/src/ssa/context.rs | 1 + crates/noirc_evaluator/src/ssa/ssa_gen.rs | 10 + crates/noirc_evaluator/src/ssa/value.rs | 3 +- .../src/ssa_refactor/ir/function.rs | 1 + .../src/ssa_refactor/opt/mem2reg.rs | 40 ++-- .../src/ssa_refactor/ssa_gen/context.rs | 40 +++- .../src/ssa_refactor/ssa_gen/mod.rs | 27 ++- .../src/ssa_refactor/ssa_gen/value.rs | 30 +++ crates/noirc_frontend/src/ast/expression.rs | 4 + crates/noirc_frontend/src/ast/mod.rs | 4 + crates/noirc_frontend/src/ast/statement.rs | 8 + .../src/hir/resolution/errors.rs | 10 + .../src/hir/resolution/resolver.rs | 43 +++- .../src/hir/type_check/errors.rs | 4 + .../noirc_frontend/src/hir/type_check/expr.rs | 192 +++++++++++++----- .../noirc_frontend/src/hir/type_check/stmt.rs | 52 ++++- crates/noirc_frontend/src/hir_def/stmt.rs | 6 +- crates/noirc_frontend/src/hir_def/types.rs | 36 ++++ .../src/monomorphization/ast.rs | 4 + .../src/monomorphization/mod.rs | 29 ++- .../src/monomorphization/printer.rs | 4 + crates/noirc_frontend/src/node_interner.rs | 1 + crates/noirc_frontend/src/parser/parser.rs | 102 ++++++++-- 27 files changed, 611 insertions(+), 103 deletions(-) create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/references/Nargo.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/references/Prover.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/references/src/main.nr diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/references/Nargo.toml b/crates/nargo_cli/tests/test_data_ssa_refactor/references/Nargo.toml new file mode 100644 index 00000000000..09a8dd8e8c8 --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/references/Nargo.toml @@ -0,0 +1,5 @@ +[package] +authors = [""] +compiler_version = "0.5.1" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/references/Prover.toml b/crates/nargo_cli/tests/test_data_ssa_refactor/references/Prover.toml new file mode 100644 index 00000000000..e69de29bb2d diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/references/src/main.nr b/crates/nargo_cli/tests/test_data_ssa_refactor/references/src/main.nr new file mode 100644 index 00000000000..6e5eddd8057 --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/references/src/main.nr @@ -0,0 +1,56 @@ +fn main() { + let mut x = 2; + add1(&mut x); + assert(x == 3); + + let mut s = S { y: x }; + s.add2(); + assert(s.y == 5); + + // Test that normal mutable variables are still copied + let mut a = 0; + mutate_copy(a); + assert(a == 0); + + // Test something 3 allocations deep + let mut nested_allocations = Nested { y: &mut &mut 0 }; + add1(*nested_allocations.y); + assert(**nested_allocations.y == 1); + + // Test nested struct allocations with a mutable reference to an array. + let mut c = C { + foo: 0, + bar: &mut C2 { + array: &mut [1, 2], + }, + }; + *c.bar.array = [3, 4]; + assert(*c.bar.array == [3, 4]); +} + +fn add1(x: &mut Field) { + *x += 1; +} + +struct S { y: Field } + +struct Nested { y: &mut &mut Field } + +struct C { + foo: Field, + bar: &mut C2, +} + +struct C2 { + array: &mut [Field; 2] +} + +impl S { + fn add2(&mut self) { + self.y += 2; + } +} + +fn mutate_copy(mut a: Field) { + a = 7; +} diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/tuples/src/main.nr b/crates/nargo_cli/tests/test_data_ssa_refactor/tuples/src/main.nr index b1d310b1412..45d8380372b 100644 --- a/crates/nargo_cli/tests/test_data_ssa_refactor/tuples/src/main.nr +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/tuples/src/main.nr @@ -19,7 +19,7 @@ fn main(x: Field, y: Field) { // Test mutating tuples let mut mutable = ((0, 0), 1, 2, 3); - mutable.0 = pair; + mutable.0 = (x, y); mutable.2 = 7; assert(mutable.0.0 == 1); assert(mutable.0.1 == 0); diff --git a/crates/noirc_evaluator/src/ssa/context.rs b/crates/noirc_evaluator/src/ssa/context.rs index 2efdd8ff304..ea9a062ec10 100644 --- a/crates/noirc_evaluator/src/ssa/context.rs +++ b/crates/noirc_evaluator/src/ssa/context.rs @@ -1207,6 +1207,7 @@ impl SsaContext { } } Type::Array(..) => panic!("Cannot convert an array type {t} into an ObjectType since it is unknown which array it refers to"), + Type::MutableReference(..) => panic!("Mutable reference types are unimplemented in the old ssa backend"), Type::Unit => ObjectType::NotAnObject, Type::Function(..) => ObjectType::Function, Type::Tuple(_) => todo!("Conversion to ObjectType is unimplemented for tuples"), diff --git a/crates/noirc_evaluator/src/ssa/ssa_gen.rs b/crates/noirc_evaluator/src/ssa/ssa_gen.rs index 082758468a6..11fc134215d 100644 --- a/crates/noirc_evaluator/src/ssa/ssa_gen.rs +++ b/crates/noirc_evaluator/src/ssa/ssa_gen.rs @@ -208,6 +208,9 @@ impl IrGenerator { self.context.new_instruction(op, rhs_type) } UnaryOp::Not => self.context.new_instruction(Operation::Not(rhs), rhs_type), + UnaryOp::MutableReference | UnaryOp::Dereference => { + unimplemented!("Mutable references are unimplemented in the old ssa backend") + } } } @@ -248,6 +251,9 @@ impl IrGenerator { let val = self.find_variable(ident_def).unwrap(); val.get_field_member(*field_index) } + LValue::Dereference { .. } => { + unreachable!("Mutable references are unsupported in the old ssa backend") + } } } @@ -256,6 +262,7 @@ impl IrGenerator { LValue::Ident(ident) => &ident.definition, LValue::Index { array, .. } => Self::lvalue_ident_def(array.as_ref()), LValue::MemberAccess { object, .. } => Self::lvalue_ident_def(object.as_ref()), + LValue::Dereference { reference, .. } => Self::lvalue_ident_def(reference.as_ref()), } } @@ -462,6 +469,9 @@ impl IrGenerator { let value = val.get_field_member(*field_index).clone(); self.assign_pattern(&value, rhs)?; } + LValue::Dereference { .. } => { + unreachable!("Mutable references are unsupported in the old ssa backend") + } } Ok(Value::dummy()) } diff --git a/crates/noirc_evaluator/src/ssa/value.rs b/crates/noirc_evaluator/src/ssa/value.rs index 915effe480b..b640369cd56 100644 --- a/crates/noirc_evaluator/src/ssa/value.rs +++ b/crates/noirc_evaluator/src/ssa/value.rs @@ -100,7 +100,8 @@ impl Value { | Type::String(..) | Type::Integer(..) | Type::Bool - | Type::Field => Value::Node(*iter.next().unwrap()), + | Type::Field + | Type::MutableReference(_) => Value::Node(*iter.next().unwrap()), } } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs index f5be19dd88e..8fe2fe745ff 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs @@ -13,6 +13,7 @@ pub(crate) enum RuntimeType { // Unconstrained function, to be compiled to brillig and executed by the Brillig VM Brillig, } + /// A function holds a list of instructions. /// These instructions are further grouped into Basic blocks /// diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/mem2reg.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/mem2reg.rs index b65b42cb83e..145ba25f5a5 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/mem2reg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/mem2reg.rs @@ -64,20 +64,39 @@ impl PerBlockContext { dfg: &mut DataFlowGraph, ) -> HashSet { let mut protected_allocations = HashSet::new(); - let mut loads_to_substitute = HashMap::new(); let block = &dfg[self.block_id]; + // Maps Load instruction id -> value to replace the result of the load with + let mut loads_to_substitute = HashMap::new(); + + // Maps Load result id -> value to replace the result of the load with + let mut load_values_to_substitute = HashMap::new(); + for instruction_id in block.instructions() { match &dfg[*instruction_id] { - Instruction::Store { address, value } => { - self.last_stores.insert(*address, *value); + Instruction::Store { mut address, value } => { + if let Some(value) = load_values_to_substitute.get(&address) { + address = *value; + } + + self.last_stores.insert(address, *value); self.store_ids.push(*instruction_id); } - Instruction::Load { address } => { - if let Some(last_value) = self.last_stores.get(address) { + Instruction::Load { mut address } => { + if let Some(value) = load_values_to_substitute.get(&address) { + address = *value; + } + + if let Some(last_value) = self.last_stores.get(&address) { + let result_value = *dfg + .instruction_results(*instruction_id) + .first() + .expect("ICE: Load instructions should have single result"); + loads_to_substitute.insert(*instruction_id, *last_value); + load_values_to_substitute.insert(result_value, *last_value); } else { - protected_allocations.insert(*address); + protected_allocations.insert(address); } } Instruction::Call { arguments, .. } => { @@ -103,12 +122,9 @@ impl PerBlockContext { } // Substitute load result values - for (instruction_id, new_value) in &loads_to_substitute { - let result_value = *dfg - .instruction_results(*instruction_id) - .first() - .expect("ICE: Load instructions should have single result"); - dfg.set_value_from_id(result_value, *new_value); + for (result_value, new_value) in load_values_to_substitute { + let result_value = dfg.resolve(result_value); + dfg.set_value_from_id(result_value, new_value); } // Delete load instructions diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs index a154af45ecd..2f13733a2dc 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs @@ -165,12 +165,17 @@ impl<'a> FunctionContext<'a> { // This helper is needed because we need to take f by mutable reference, // otherwise we cannot move it multiple times each loop of vecmap. - fn map_type_helper(typ: &ast::Type, f: &mut impl FnMut(Type) -> T) -> Tree { + fn map_type_helper(typ: &ast::Type, f: &mut dyn FnMut(Type) -> T) -> Tree { match typ { ast::Type::Tuple(fields) => { Tree::Branch(vecmap(fields, |field| Self::map_type_helper(field, f))) } ast::Type::Unit => Tree::empty(), + // A mutable reference wraps each element into a reference. + // This can be multiple values if the element type is a tuple. + ast::Type::MutableReference(element) => { + Self::map_type_helper(element, &mut |_| f(Type::Reference)) + } other => Tree::Leaf(f(Self::convert_non_tuple_type(other))), } } @@ -201,6 +206,11 @@ impl<'a> FunctionContext<'a> { ast::Type::Unit => panic!("convert_non_tuple_type called on a unit type"), ast::Type::Tuple(_) => panic!("convert_non_tuple_type called on a tuple: {typ}"), ast::Type::Function(_, _) => Type::Function, + ast::Type::MutableReference(element) => { + // Recursive call to panic if element is a tuple + Self::convert_non_tuple_type(element); + Type::Reference + } // How should we represent Vecs? // Are they a struct of array + length + capacity? @@ -473,9 +483,21 @@ impl<'a> FunctionContext<'a> { let object_lvalue = Box::new(object_lvalue); LValue::MemberAccess { old_object, object_lvalue, index: *field_index } } + ast::LValue::Dereference { reference, .. } => { + let (reference, _) = self.extract_current_value_recursive(reference); + LValue::Dereference { reference } + } } } + pub(super) fn dereference(&mut self, values: &Values, element_type: &ast::Type) -> Values { + let element_types = Self::convert_type(element_type); + values.map_both(element_types, |value, element_type| { + let reference = value.eval(self); + self.builder.insert_load(reference, element_type).into() + }) + } + /// Compile the given identifier as a reference - ie. avoid calling .eval() fn ident_lvalue(&self, ident: &ast::Ident) -> Values { match &ident.definition { @@ -516,16 +538,19 @@ impl<'a> FunctionContext<'a> { let element = Self::get_field_ref(&old_object, *index).clone(); (element, LValue::MemberAccess { old_object, object_lvalue, index: *index }) } + ast::LValue::Dereference { reference, element_type } => { + let (reference, _) = self.extract_current_value_recursive(reference); + let dereferenced = self.dereference(&reference, element_type); + (dereferenced, LValue::Dereference { reference }) + } } } /// Assigns a new value to the given LValue. /// The LValue can be created via a previous call to extract_current_value. /// This method recurs on the given LValue to create a new value to assign an allocation - /// instruction within an LValue::Ident - see the comment on `extract_current_value` for more - /// details. - /// When first-class references are supported the nearest reference may be in any LValue - /// variant rather than just LValue::Ident. + /// instruction within an LValue::Ident or LValue::Dereference - see the comment on + /// `extract_current_value` for more details. pub(super) fn assign_new_value(&mut self, lvalue: LValue, new_value: Values) { match lvalue { LValue::Ident(references) => self.assign(references, new_value), @@ -538,6 +563,9 @@ impl<'a> FunctionContext<'a> { let new_object = Self::replace_field(old_object, index, new_value); self.assign_new_value(*object_lvalue, new_object); } + LValue::Dereference { reference } => { + self.assign(reference, new_value); + } } } @@ -705,8 +733,10 @@ impl SharedContext { } /// Used to remember the results of each step of extracting a value from an ast::LValue +#[derive(Debug)] pub(super) enum LValue { Ident(Values), Index { old_array: ValueId, index: ValueId, array_lvalue: Box }, MemberAccess { old_object: Values, index: usize, object_lvalue: Box }, + Dereference { reference: Values }, } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/mod.rs index a7fcf2d664b..ac89575ecea 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/mod.rs @@ -99,7 +99,7 @@ impl<'a> FunctionContext<'a> { /// Codegen for identifiers fn codegen_ident(&mut self, ident: &ast::Ident) -> Values { match &ident.definition { - ast::Definition::Local(id) => self.lookup(*id).map(|value| value.eval(self).into()), + ast::Definition::Local(id) => self.lookup(*id), ast::Definition::Function(id) => self.get_or_queue_function(*id), ast::Definition::Oracle(name) => self.builder.import_foreign_function(name).into(), ast::Definition::Builtin(name) | ast::Definition::LowLevel(name) => { @@ -165,14 +165,33 @@ impl<'a> FunctionContext<'a> { } fn codegen_unary(&mut self, unary: &ast::Unary) -> Values { - let rhs = self.codegen_non_tuple_expression(&unary.rhs); + let rhs = self.codegen_expression(&unary.rhs); match unary.operator { - noirc_frontend::UnaryOp::Not => self.builder.insert_not(rhs).into(), + noirc_frontend::UnaryOp::Not => { + let rhs = rhs.into_leaf().eval(self); + self.builder.insert_not(rhs).into() + } noirc_frontend::UnaryOp::Minus => { + let rhs = rhs.into_leaf().eval(self); let typ = self.builder.type_of_value(rhs); let zero = self.builder.numeric_constant(0u128, typ); self.builder.insert_binary(zero, BinaryOp::Sub, rhs).into() } + noirc_frontend::UnaryOp::MutableReference => { + rhs.map(|rhs| { + match rhs { + value::Value::Normal(value) => { + let alloc = self.builder.insert_allocate(); + self.builder.insert_store(alloc, value); + Tree::Leaf(value::Value::Normal(alloc)) + } + // NOTE: The `.into()` here converts the Value::Mutable into + // a Value::Normal so it is no longer automatically dereferenced. + value::Value::Mutable(reference, _) => reference.into(), + } + }) + } + noirc_frontend::UnaryOp::Dereference => self.dereference(&rhs, &unary.result_type), } } @@ -343,13 +362,13 @@ impl<'a> FunctionContext<'a> { /// Generate SSA for a function call. Note that calls to built-in functions /// and intrinsics are also represented by the function call instruction. fn codegen_call(&mut self, call: &ast::Call) -> Values { + let function = self.codegen_non_tuple_expression(&call.func); let arguments = call .arguments .iter() .flat_map(|argument| self.codegen_expression(argument).into_value_list(self)) .collect(); - let function = self.codegen_non_tuple_expression(&call.func); self.insert_call(function, arguments, &call.return_type) } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/value.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/value.rs index c50abb9ca30..2d209635610 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/value.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/value.rs @@ -123,6 +123,36 @@ impl Tree { } } + /// Map two trees alongside each other. + /// This asserts each tree has the same internal structure. + pub(super) fn map_both( + &self, + other: Tree, + mut f: impl FnMut(T, U) -> Tree, + ) -> Tree + where + T: std::fmt::Debug + Clone, + U: std::fmt::Debug, + { + self.map_both_helper(other, &mut f) + } + + fn map_both_helper(&self, other: Tree, f: &mut impl FnMut(T, U) -> Tree) -> Tree + where + T: std::fmt::Debug + Clone, + U: std::fmt::Debug, + { + match (self, other) { + (Tree::Branch(self_trees), Tree::Branch(other_trees)) => { + assert_eq!(self_trees.len(), other_trees.len()); + let trees = self_trees.iter().zip(other_trees); + Tree::Branch(vecmap(trees, |(l, r)| l.map_both_helper(r, f))) + } + (Tree::Leaf(self_value), Tree::Leaf(other_value)) => f(self_value.clone(), other_value), + other => panic!("Found unexpected tree combination during SSA: {other:?}"), + } + } + /// Unwraps this Tree into the value of the leaf node. Panics if /// this Tree is a Branch pub(super) fn into_leaf(self) -> T { diff --git a/crates/noirc_frontend/src/ast/expression.rs b/crates/noirc_frontend/src/ast/expression.rs index 4d19d47d484..f3788919d85 100644 --- a/crates/noirc_frontend/src/ast/expression.rs +++ b/crates/noirc_frontend/src/ast/expression.rs @@ -261,6 +261,8 @@ impl BinaryOpKind { pub enum UnaryOp { Minus, Not, + MutableReference, + Dereference, } impl UnaryOp { @@ -479,6 +481,8 @@ impl Display for UnaryOp { match self { UnaryOp::Minus => write!(f, "-"), UnaryOp::Not => write!(f, "!"), + UnaryOp::MutableReference => write!(f, "&mut"), + UnaryOp::Dereference => write!(f, "*"), } } } diff --git a/crates/noirc_frontend/src/ast/mod.rs b/crates/noirc_frontend/src/ast/mod.rs index 24004e34ffa..2f11ecc1564 100644 --- a/crates/noirc_frontend/src/ast/mod.rs +++ b/crates/noirc_frontend/src/ast/mod.rs @@ -47,6 +47,9 @@ pub enum UnresolvedType { /// generic argument is not given. Vec(Vec, Span), + /// &mut T + MutableReference(Box), + // Note: Tuples have no visibility, instead each of their elements may have one. Tuple(Vec), @@ -116,6 +119,7 @@ impl std::fmt::Display for UnresolvedType { let args = vecmap(args, ToString::to_string); write!(f, "Vec<{}>", args.join(", ")) } + MutableReference(element) => write!(f, "&mut {element}"), Unit => write!(f, "()"), Error => write!(f, "error"), Unspecified => write!(f, "unspecified"), diff --git a/crates/noirc_frontend/src/ast/statement.rs b/crates/noirc_frontend/src/ast/statement.rs index 7f77716c5e4..f8093221cf8 100644 --- a/crates/noirc_frontend/src/ast/statement.rs +++ b/crates/noirc_frontend/src/ast/statement.rs @@ -403,6 +403,7 @@ pub enum LValue { Ident(Ident), MemberAccess { object: Box, field_name: Ident }, Index { array: Box, index: Expression }, + Dereference(Box), } #[derive(Debug, PartialEq, Eq, Clone)] @@ -445,6 +446,12 @@ impl LValue { collection: array.as_expression(span), index: index.clone(), })), + LValue::Dereference(lvalue) => { + ExpressionKind::Prefix(Box::new(crate::PrefixExpression { + operator: crate::UnaryOp::Dereference, + rhs: lvalue.as_expression(span), + })) + } }; Expression::new(kind, span) } @@ -487,6 +494,7 @@ impl Display for LValue { LValue::Ident(ident) => ident.fmt(f), LValue::MemberAccess { object, field_name } => write!(f, "{object}.{field_name}"), LValue::Index { array, index } => write!(f, "{array}[{index}]"), + LValue::Dereference(lvalue) => write!(f, "*{lvalue}"), } } } diff --git a/crates/noirc_frontend/src/hir/resolution/errors.rs b/crates/noirc_frontend/src/hir/resolution/errors.rs index 87257cbb842..80638897a59 100644 --- a/crates/noirc_frontend/src/hir/resolution/errors.rs +++ b/crates/noirc_frontend/src/hir/resolution/errors.rs @@ -60,6 +60,10 @@ pub enum ResolverError { ParserError(Box), #[error("Function is not defined in a contract yet sets its contract visibility")] ContractFunctionTypeInNormalFunction { span: Span }, + #[error("Cannot create a mutable reference to {variable}, it was declared to be immutable")] + MutableReferenceToImmutableVariable { variable: String, span: Span }, + #[error("Mutable references to array indices are unsupported")] + MutableReferenceToArrayElement { span: Span }, } impl ResolverError { @@ -258,6 +262,12 @@ impl From for Diagnostic { "Non-contract functions cannot be 'open'".into(), span, ), + ResolverError::MutableReferenceToImmutableVariable { variable, span } => { + Diagnostic::simple_error(format!("Cannot mutably reference the immutable variable {variable}"), format!("{variable} is immutable"), span) + }, + ResolverError::MutableReferenceToArrayElement { span } => { + Diagnostic::simple_error("Mutable references to array elements are currently unsupported".into(), "Try storing the element in a fresh variable first".into(), span) + }, } } } diff --git a/crates/noirc_frontend/src/hir/resolution/resolver.rs b/crates/noirc_frontend/src/hir/resolution/resolver.rs index 883b323ce66..72b03689ec2 100644 --- a/crates/noirc_frontend/src/hir/resolution/resolver.rs +++ b/crates/noirc_frontend/src/hir/resolution/resolver.rs @@ -34,7 +34,7 @@ use crate::{ }; use crate::{ ArrayLiteral, ContractFunctionType, Generics, LValue, NoirStruct, Path, Pattern, Shared, - StructType, Type, TypeBinding, TypeVariable, UnresolvedGenerics, UnresolvedType, + StructType, Type, TypeBinding, TypeVariable, UnaryOp, UnresolvedGenerics, UnresolvedType, UnresolvedTypeExpression, ERROR_IDENT, }; use fm::FileId; @@ -362,6 +362,9 @@ impl<'a> Resolver<'a> { }; Type::Vec(Box::new(arg)) } + UnresolvedType::MutableReference(element) => { + Type::MutableReference(Box::new(self.resolve_type_inner(*element, new_variables))) + } } } @@ -618,6 +621,7 @@ impl<'a> Resolver<'a> { let pattern = self.resolve_pattern(pattern, DefinitionKind::Local(None)); let typ = self.resolve_type_inner(typ, &mut generics); + parameters.push(Param(pattern, typ.clone(), visibility)); parameter_types.push(typ); } @@ -781,6 +785,7 @@ impl<'a> Resolver<'a> { } } Type::Vec(element) => Self::find_numeric_generics_in_type(element, found), + Type::MutableReference(element) => Self::find_numeric_generics_in_type(element, found), } } @@ -841,6 +846,10 @@ impl<'a> Resolver<'a> { let index = self.resolve_expression(index); HirLValue::Index { array, index, typ: Type::Error } } + LValue::Dereference(lvalue) => { + let lvalue = Box::new(self.resolve_lvalue(*lvalue)); + HirLValue::Dereference { lvalue, element_type: Type::Error } + } } } @@ -880,6 +889,13 @@ impl<'a> Resolver<'a> { ExpressionKind::Prefix(prefix) => { let operator = prefix.operator; let rhs = self.resolve_expression(prefix.rhs); + + if operator == UnaryOp::MutableReference { + if let Err(error) = verify_mutable_reference(self.interner, rhs) { + self.errors.push(error); + } + } + HirExpression::Prefix(HirPrefixExpression { operator, rhs }) } ExpressionKind::Infix(infix) => { @@ -1246,6 +1262,31 @@ impl<'a> Resolver<'a> { } } +/// Gives an error if a user tries to create a mutable reference +/// to an immutable variable. +pub fn verify_mutable_reference(interner: &NodeInterner, rhs: ExprId) -> Result<(), ResolverError> { + match interner.expression(&rhs) { + HirExpression::MemberAccess(member_access) => { + verify_mutable_reference(interner, member_access.lhs) + } + HirExpression::Index(_) => { + let span = interner.expr_span(&rhs); + Err(ResolverError::MutableReferenceToArrayElement { span }) + } + HirExpression::Ident(ident) => { + let definition = interner.definition(ident.id); + if !definition.mutable { + let span = interner.expr_span(&rhs); + let variable = definition.name.clone(); + Err(ResolverError::MutableReferenceToImmutableVariable { span, variable }) + } else { + Ok(()) + } + } + _ => Ok(()), + } +} + // XXX: These tests repeat a lot of code // what we should do is have test cases which are passed to a test harness // A test harness will allow for more expressive and readable tests diff --git a/crates/noirc_frontend/src/hir/type_check/errors.rs b/crates/noirc_frontend/src/hir/type_check/errors.rs index 83aac128470..b8d4cc66c18 100644 --- a/crates/noirc_frontend/src/hir/type_check/errors.rs +++ b/crates/noirc_frontend/src/hir/type_check/errors.rs @@ -2,6 +2,7 @@ use noirc_errors::CustomDiagnostic as Diagnostic; use noirc_errors::Span; use thiserror::Error; +use crate::hir::resolution::errors::ResolverError; use crate::hir_def::expr::HirBinaryOp; use crate::hir_def::types::Type; @@ -34,6 +35,8 @@ pub enum TypeCheckError { }, #[error("Cannot infer type of expression, type annotations needed before this point")] TypeAnnotationsNeeded { span: Span }, + #[error("{0}")] + ResolverError(ResolverError), } impl TypeCheckError { @@ -103,6 +106,7 @@ impl From for Diagnostic { "Type must be known at this point".to_string(), span, ), + TypeCheckError::ResolverError(error) => error.into(), } } } diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 031adaf75ec..03f4ef09a92 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -2,12 +2,16 @@ use iter_extended::vecmap; use noirc_errors::Span; use crate::{ + hir::resolution::resolver::verify_mutable_reference, hir_def::{ - expr::{self, HirArrayLiteral, HirBinaryOp, HirExpression, HirLiteral}, + expr::{ + self, HirArrayLiteral, HirBinaryOp, HirExpression, HirLiteral, HirMethodCallExpression, + HirPrefixExpression, + }, types::Type, }, node_interner::{ExprId, FuncId}, - CompTime, Shared, TypeBinding, + CompTime, Shared, TypeBinding, UnaryOp, }; use super::{errors::TypeCheckError, TypeChecker}; @@ -74,13 +78,7 @@ impl<'interner> TypeChecker<'interner> { Type::Array(Box::new(length), Box::new(elem_type)) } HirLiteral::Bool(_) => Type::Bool(CompTime::new(self.interner)), - HirLiteral::Integer(_) => { - let id = self.interner.next_type_variable_id(); - Type::PolymorphicInteger( - CompTime::new(self.interner), - Shared::new(TypeBinding::Unbound(id)), - ) - } + HirLiteral::Integer(_) => Type::polymorphic_integer(self.interner), HirLiteral::Str(string) => { let len = Type::Constant(string.len() as u64); Type::String(Box::new(len)) @@ -112,13 +110,14 @@ impl<'interner> TypeChecker<'interner> { let span = self.interner.expr_span(expr_id); self.bind_function_type(function, args, span) } - HirExpression::MethodCall(method_call) => { + HirExpression::MethodCall(mut method_call) => { let object_type = self.check_expression(&method_call.object).follow_bindings(); let method_name = method_call.method.0.contents.as_str(); match self.lookup_method(object_type.clone(), method_name, expr_id) { Some(method_id) => { let mut args = vec![(object_type, self.interner.expr_span(&method_call.object))]; + let mut arg_types = vecmap(&method_call.arguments, |arg| { let typ = self.check_expression(arg); (typ, self.interner.expr_span(arg)) @@ -128,6 +127,18 @@ impl<'interner> TypeChecker<'interner> { // Desugar the method call into a normal, resolved function call // so that the backend doesn't need to worry about methods let location = method_call.location; + + // Automatically add `&mut` if the method expects a mutable reference and + // the object is not already one. + if method_id != FuncId::dummy_id() { + let func_meta = self.interner.function_meta(&method_id); + self.try_add_mutable_reference_to_object( + &mut method_call, + &func_meta.typ, + &mut args, + ); + } + let (function_id, function_call) = method_call.into_function_call(method_id, location, self.interner); @@ -216,14 +227,8 @@ impl<'interner> TypeChecker<'interner> { } HirExpression::Prefix(prefix_expr) => { let rhs_type = self.check_expression(&prefix_expr.rhs); - match prefix_operand_type_rules(&prefix_expr.operator, &rhs_type) { - Ok(typ) => typ, - Err(msg) => { - let rhs_span = self.interner.expr_span(&prefix_expr.rhs); - self.errors.push(TypeCheckError::Unstructured { msg, span: rhs_span }); - Type::Error - } - } + let span = self.interner.expr_span(&prefix_expr.rhs); + self.type_check_prefix_operand(&prefix_expr.operator, &rhs_type, span) } HirExpression::If(if_expr) => self.check_if_expr(&if_expr, expr_id), HirExpression::Constructor(constructor) => self.check_constructor(constructor, expr_id), @@ -256,6 +261,48 @@ impl<'interner> TypeChecker<'interner> { typ } + /// Check if the given method type requires a mutable reference to the object type, and check + /// if the given object type is already a mutable reference. If not, add one. + /// This is used to automatically transform a method call: `foo.bar()` into a function + /// call: `bar(&mut foo)`. + fn try_add_mutable_reference_to_object( + &mut self, + method_call: &mut HirMethodCallExpression, + function_type: &Type, + argument_types: &mut [(Type, noirc_errors::Span)], + ) { + let expected_object_type = match function_type { + Type::Function(args, _) => args.get(0), + Type::Forall(_, typ) => match typ.as_ref() { + Type::Function(args, _) => args.get(0), + typ => unreachable!("Unexpected type for function: {typ}"), + }, + typ => unreachable!("Unexpected type for function: {typ}"), + }; + + if let Some(expected_object_type) = expected_object_type { + if matches!(expected_object_type.follow_bindings(), Type::MutableReference(_)) { + let actual_type = argument_types[0].0.follow_bindings(); + + if let Err(error) = verify_mutable_reference(self.interner, method_call.object) { + self.errors.push(TypeCheckError::ResolverError(error)); + } + + if !matches!(actual_type, Type::MutableReference(_)) { + let new_type = Type::MutableReference(Box::new(actual_type)); + + argument_types[0].0 = new_type.clone(); + method_call.object = + self.interner.push_expr(HirExpression::Prefix(HirPrefixExpression { + operator: UnaryOp::MutableReference, + rhs: method_call.object, + })); + self.interner.push_expr_type(&method_call.object, new_type); + } + } + } + } + fn check_index_expression(&mut self, index_expr: expr::HirIndexExpression) -> Type { let index_type = self.check_expression(&index_expr.index); let span = self.interner.expr_span(&index_expr.index); @@ -462,13 +509,26 @@ impl<'interner> TypeChecker<'interner> { Type::Struct(typ, generics) } - fn check_member_access(&mut self, access: expr::HirMemberAccess, expr_id: ExprId) -> Type { + fn check_member_access(&mut self, mut access: expr::HirMemberAccess, expr_id: ExprId) -> Type { let lhs_type = self.check_expression(&access.lhs).follow_bindings(); let span = self.interner.expr_span(&expr_id); + let access_lhs = &mut access.lhs; + + let dereference_lhs = |this: &mut Self, lhs_type, element| { + let old_lhs = *access_lhs; + *access_lhs = this.interner.push_expr(HirExpression::Prefix(HirPrefixExpression { + operator: crate::UnaryOp::Dereference, + rhs: old_lhs, + })); + this.interner.push_expr_type(access_lhs, lhs_type); + this.interner.push_expr_type(&old_lhs, element); + }; - match self.check_field_access(&lhs_type, &access.rhs.0.contents, span) { + match self.check_field_access(&lhs_type, &access.rhs.0.contents, span, dereference_lhs) { Some((element_type, index)) => { self.interner.set_field_index(expr_id, index); + // We must update `access` in case we added any dereferences to it + self.interner.replace_expr(&expr_id, HirExpression::MemberAccess(access)); element_type } None => Type::Error, @@ -481,32 +541,50 @@ impl<'interner> TypeChecker<'interner> { /// /// This function is abstracted from check_member_access so that it can be shared between /// there and the HirLValue::MemberAccess case of check_lvalue. + /// + /// `dereference_lhs` is called when the lhs type is a Type::MutableReference that should be + /// automatically dereferenced so its field can be extracted. This function is expected to + /// perform any mutations necessary to wrap the lhs in a UnaryOp::Dereference prefix + /// expression. The second parameter of this function represents the lhs_type (which should + /// always be a Type::MutableReference if `dereference_lhs` is called) and the third + /// represents the element type. pub(super) fn check_field_access( &mut self, lhs_type: &Type, field_name: &str, span: Span, + mut dereference_lhs: impl FnMut(&mut Self, Type, Type), ) -> Option<(Type, usize)> { let lhs_type = lhs_type.follow_bindings(); - if let Type::Struct(s, args) = &lhs_type { - let s = s.borrow(); - if let Some((field, index)) = s.get_field(field_name, args) { - return Some((field, index)); + match &lhs_type { + Type::Struct(s, args) => { + let s = s.borrow(); + if let Some((field, index)) = s.get_field(field_name, args) { + return Some((field, index)); + } } - } else if let Type::Tuple(elements) = &lhs_type { - if let Ok(index) = field_name.parse::() { - let length = elements.len(); - if index < length { - return Some((elements[index].clone(), index)); - } else { - self.errors.push(TypeCheckError::Unstructured { - msg: format!("Index {index} is out of bounds for this tuple {lhs_type} of length {length}"), - span, - }); - return None; + Type::Tuple(elements) => { + if let Ok(index) = field_name.parse::() { + let length = elements.len(); + if index < length { + return Some((elements[index].clone(), index)); + } else { + self.errors.push(TypeCheckError::Unstructured { + msg: format!("Index {index} is out of bounds for this tuple {lhs_type} of length {length}"), + span, + }); + return None; + } } } + // If the lhs is a mutable reference we automatically transform + // lhs.field into (*lhs).field + Type::MutableReference(element) => { + dereference_lhs(self, lhs_type.clone(), element.as_ref().clone()); + return self.check_field_access(element, field_name, span, dereference_lhs); + } + _ => (), } // If we get here the type has no field named 'access.rhs'. @@ -825,22 +903,44 @@ impl<'interner> TypeChecker<'interner> { (lhs, rhs) => Err(make_error(format!("Unsupported types for binary operation: {lhs} and {rhs}"))), } } -} -fn prefix_operand_type_rules(op: &crate::UnaryOp, rhs_type: &Type) -> Result { - match op { - crate::UnaryOp::Minus => { - if !matches!(rhs_type, Type::Integer(..) | Type::Error) { - return Err("Only Integers can be used in a Minus expression".to_string()); + fn type_check_prefix_operand( + &mut self, + op: &crate::UnaryOp, + rhs_type: &Type, + span: Span, + ) -> Type { + let mut unify = |expected| { + rhs_type.unify(&expected, span, &mut self.errors, || TypeCheckError::TypeMismatch { + expr_typ: rhs_type.to_string(), + expected_typ: expected.to_string(), + expr_span: span, + }); + expected + }; + + match op { + crate::UnaryOp::Minus => unify(Type::polymorphic_integer(self.interner)), + crate::UnaryOp::Not => { + let rhs_type = rhs_type.follow_bindings(); + + // `!` can work on booleans or integers + if matches!(rhs_type, Type::Integer(..)) { + return rhs_type; + } + + unify(Type::Bool(CompTime::new(self.interner))) } - } - crate::UnaryOp::Not => { - if !matches!(rhs_type, Type::Integer(..) | Type::Bool(_) | Type::Error) { - return Err("Only Integers or Bool can be used in a Not expression".to_string()); + crate::UnaryOp::MutableReference => { + Type::MutableReference(Box::new(rhs_type.follow_bindings())) + } + crate::UnaryOp::Dereference => { + let element_type = Type::type_variable(self.interner.next_type_variable_id()); + unify(Type::MutableReference(Box::new(element_type.clone()))); + element_type } } } - Ok(rhs_type.clone()) } /// Taken from: https://stackoverflow.com/a/47127500 diff --git a/crates/noirc_frontend/src/hir/type_check/stmt.rs b/crates/noirc_frontend/src/hir/type_check/stmt.rs index e6ec71bfc10..5606547a568 100644 --- a/crates/noirc_frontend/src/hir/type_check/stmt.rs +++ b/crates/noirc_frontend/src/hir/type_check/stmt.rs @@ -1,5 +1,6 @@ -use noirc_errors::Span; +use noirc_errors::{Location, Span}; +use crate::hir_def::expr::HirIdent; use crate::hir_def::stmt::{ HirAssignStatement, HirConstrainStatement, HirLValue, HirLetStatement, HirPattern, HirStatement, }; @@ -123,8 +124,12 @@ impl<'interner> TypeChecker<'interner> { let typ = if ident.id == DefinitionId::dummy_id() { Type::Error } else { + // Do we need to store TypeBindings here? + let typ = self.interner.id_type(ident.id).instantiate(self.interner).0; + let typ = typ.follow_bindings(); + let definition = self.interner.definition(ident.id); - if !definition.mutable { + if !definition.mutable && !matches!(typ, Type::MutableReference(_)) { self.errors.push(TypeCheckError::Unstructured { msg: format!( "Variable {} must be mutable to be assigned to", @@ -133,19 +138,36 @@ impl<'interner> TypeChecker<'interner> { span: ident.location.span, }); } - // Do we need to store TypeBindings here? - self.interner.id_type(ident.id).instantiate(self.interner).0 + + typ }; (typ.clone(), HirLValue::Ident(ident, typ)) } HirLValue::MemberAccess { object, field_name, .. } => { let (lhs_type, object) = self.check_lvalue(*object, assign_span); - let object = Box::new(object); - + let mut object = Box::new(object); let span = field_name.span(); + + let object_ref = &mut object; + let (typ, field_index) = self - .check_field_access(&lhs_type, &field_name.0.contents, span) + .check_field_access( + &lhs_type, + &field_name.0.contents, + span, + move |_, _, element_type| { + // We must create a temporary value first to move out of object_ref before + // we eventually reassign to it. + let id = DefinitionId::dummy_id(); + let location = Location::new(span, fm::FileId::dummy()); + let tmp_value = + HirLValue::Ident(HirIdent { location, id }, Type::Error); + + let lvalue = std::mem::replace(object_ref, Box::new(tmp_value)); + *object_ref = Box::new(HirLValue::Dereference { lvalue, element_type }); + }, + ) .unwrap_or((Type::Error, 0)); let field_index = Some(field_index); @@ -185,6 +207,22 @@ impl<'interner> TypeChecker<'interner> { (typ.clone(), HirLValue::Index { array, index, typ }) } + HirLValue::Dereference { lvalue, element_type: _ } => { + let (reference_type, lvalue) = self.check_lvalue(*lvalue, assign_span); + let lvalue = Box::new(lvalue); + + let element_type = Type::type_variable(self.interner.next_type_variable_id()); + let expected_type = Type::MutableReference(Box::new(element_type.clone())); + reference_type.unify(&expected_type, assign_span, &mut self.errors, || { + TypeCheckError::TypeMismatch { + expected_typ: expected_type.to_string(), + expr_typ: reference_type.to_string(), + expr_span: assign_span, + } + }); + + (element_type.clone(), HirLValue::Dereference { lvalue, element_type }) + } } } diff --git a/crates/noirc_frontend/src/hir_def/stmt.rs b/crates/noirc_frontend/src/hir_def/stmt.rs index 04a6d9770fa..8cf9d82f580 100644 --- a/crates/noirc_frontend/src/hir_def/stmt.rs +++ b/crates/noirc_frontend/src/hir_def/stmt.rs @@ -60,7 +60,7 @@ impl HirPattern { pub fn field_count(&self) -> usize { match self { HirPattern::Identifier(_) => 0, - HirPattern::Mutable(_, _) => 0, + HirPattern::Mutable(pattern, _) => pattern.field_count(), HirPattern::Tuple(fields, _) => fields.len(), HirPattern::Struct(_, fields, _) => fields.len(), } @@ -97,4 +97,8 @@ pub enum HirLValue { index: ExprId, typ: Type, }, + Dereference { + lvalue: Box, + element_type: Type, + }, } diff --git a/crates/noirc_frontend/src/hir_def/types.rs b/crates/noirc_frontend/src/hir_def/types.rs index 7769b0b1153..adc239976c5 100644 --- a/crates/noirc_frontend/src/hir_def/types.rs +++ b/crates/noirc_frontend/src/hir_def/types.rs @@ -76,6 +76,9 @@ pub enum Type { /// .pop, and similar methods. Vec(Box), + /// &mut T + MutableReference(Box), + /// A type generic over the given type variables. /// Storing both the TypeVariableId and TypeVariable isn't necessary /// but it makes handling them both easier. The TypeVariableId should @@ -524,6 +527,11 @@ impl Type { Type::TypeVariable(Shared::new(TypeBinding::Unbound(id))) } + pub fn polymorphic_integer(interner: &mut NodeInterner) -> Type { + let id = interner.next_type_variable_id(); + Type::PolymorphicInteger(CompTime::new(interner), Shared::new(TypeBinding::Unbound(id))) + } + /// A bit of an awkward name for this function - this function returns /// true for type variables or polymorphic integers which are unbound. /// NamedGenerics will always be false as although they are bindable, @@ -593,6 +601,7 @@ impl Type { }) } Type::Vec(element) => element.contains_numeric_typevar(target_id), + Type::MutableReference(element) => element.contains_numeric_typevar(target_id), } } @@ -662,6 +671,9 @@ impl std::fmt::Display for Type { Type::Vec(element) => { write!(f, "Vec<{element}>") } + Type::MutableReference(element) => { + write!(f, "&mut {element}") + } } } } @@ -994,6 +1006,8 @@ impl Type { (Vec(elem_a), Vec(elem_b)) => elem_a.try_unify(elem_b, span), + (MutableReference(elem_a), MutableReference(elem_b)) => elem_a.try_unify(elem_b, span), + (other_a, other_b) => { if other_a == other_b { Ok(()) @@ -1127,6 +1141,22 @@ impl Type { (Vec(elem_a), Vec(elem_b)) => elem_a.is_subtype_of(elem_b, span), + // `T <: U => &mut T <: &mut U` would be unsound(*), so mutable + // references are never subtypes of each other. + // + // (*) Consider: + // ``` + // // Assume Dog <: Animal and Cat <: Animal + // let x: &mut Dog = ...; + // + // fn set_to_cat(y: &mut Animal) { + // *y = Cat; + // } + // + // set_to_cat(x); // uh-oh: x: Dog, yet it now holds a Cat + // ``` + (MutableReference(elem_a), MutableReference(elem_b)) => elem_a.try_unify(elem_b, span), + (other_a, other_b) => { if other_a == other_b { Ok(()) @@ -1197,6 +1227,7 @@ impl Type { Type::NamedGeneric(..) => unreachable!(), Type::Forall(..) => unreachable!(), Type::Function(_, _) => unreachable!(), + Type::MutableReference(_) => unreachable!("&mut cannot be used in the abi"), Type::Vec(_) => unreachable!("Vecs cannot be used in the abi"), } } @@ -1312,6 +1343,9 @@ impl Type { Type::Function(args, ret) } Type::Vec(element) => Type::Vec(Box::new(element.substitute(type_bindings))), + Type::MutableReference(element) => { + Type::MutableReference(Box::new(element.substitute(type_bindings))) + } Type::FieldElement(_) | Type::Integer(_, _, _) @@ -1342,6 +1376,7 @@ impl Type { args.iter().any(|arg| arg.occurs(target_id)) || ret.occurs(target_id) } Type::Vec(element) => element.occurs(target_id), + Type::MutableReference(element) => element.occurs(target_id), Type::FieldElement(_) | Type::Integer(_, _, _) @@ -1384,6 +1419,7 @@ impl Type { Function(args, ret) } Vec(element) => Vec(Box::new(element.follow_bindings())), + MutableReference(element) => MutableReference(Box::new(element.follow_bindings())), // Expect that this function should only be called on instantiated types Forall(..) => unreachable!(), diff --git a/crates/noirc_frontend/src/monomorphization/ast.rs b/crates/noirc_frontend/src/monomorphization/ast.rs index aaaf7c5bb2f..305e6635dfc 100644 --- a/crates/noirc_frontend/src/monomorphization/ast.rs +++ b/crates/noirc_frontend/src/monomorphization/ast.rs @@ -89,6 +89,7 @@ pub enum Literal { pub struct Unary { pub operator: crate::UnaryOp, pub rhs: Box, + pub result_type: Type, } pub type BinaryOp = BinaryOpKind; @@ -176,6 +177,7 @@ pub enum LValue { Ident(Ident), Index { array: Box, index: Box, element_type: Type, location: Location }, MemberAccess { object: Box, field_index: usize }, + Dereference { reference: Box, element_type: Type }, } pub type Parameters = Vec<(LocalId, /*mutable:*/ bool, /*name:*/ String, Type)>; @@ -208,6 +210,7 @@ pub enum Type { Unit, Tuple(Vec), Vec(Box), + MutableReference(Box), Function(/*args:*/ Vec, /*ret:*/ Box), } @@ -321,6 +324,7 @@ impl std::fmt::Display for Type { write!(f, "fn({}) -> {}", args.join(", "), ret) } Type::Vec(element) => write!(f, "Vec<{element}>"), + Type::MutableReference(element) => write!(f, "&mut {element}"), } } } diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index 8c588acbf90..7412a4124db 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -220,7 +220,6 @@ impl<'interner> Monomorphizer<'interner> { ) { match param { HirPattern::Identifier(ident) => { - //let value = self.expand_parameter(typ, new_params); let new_id = self.next_local_id(); let definition = self.interner.definition(ident.id); let name = definition.name.clone(); @@ -277,6 +276,7 @@ impl<'interner> Monomorphizer<'interner> { HirExpression::Prefix(prefix) => ast::Expression::Unary(ast::Unary { operator: prefix.operator, rhs: Box::new(self.expr(prefix.rhs)), + result_type: Self::convert_type(&self.interner.id_type(expr)), }), HirExpression::Infix(infix) => { @@ -382,7 +382,8 @@ impl<'interner> Monomorphizer<'interner> { | ast::Type::Integer(_, _) | ast::Type::Bool | ast::Type::Unit - | ast::Type::Function(_, _) => { + | ast::Type::Function(_, _) + | ast::Type::MutableReference(_) => { ast::Expression::Literal(ast::Literal::Array(ast::ArrayLiteral { contents: array_contents, element_type, @@ -428,7 +429,8 @@ impl<'interner> Monomorphizer<'interner> { | ast::Type::Integer(_, _) | ast::Type::Bool | ast::Type::Unit - | ast::Type::Function(_, _) => { + | ast::Type::Function(_, _) + | ast::Type::MutableReference(_) => { ast::Expression::Index(ast::Index { collection, index, element_type, location }) } @@ -700,6 +702,11 @@ impl<'interner> Monomorphizer<'interner> { ast::Type::Vec(Box::new(element)) } + HirType::MutableReference(element) => { + let element = Self::convert_type(element); + ast::Type::MutableReference(Box::new(element)) + } + HirType::Forall(_, _) | HirType::Constant(_) | HirType::Error => { unreachable!("Unexpected type {} found", typ) } @@ -714,7 +721,8 @@ impl<'interner> Monomorphizer<'interner> { | ast::Type::Integer(_, _) | ast::Type::Bool | ast::Type::Unit - | ast::Type::Function(_, _) => ast::Type::Array(length, Box::new(element)), + | ast::Type::Function(_, _) + | ast::Type::MutableReference(_) => ast::Type::Array(length, Box::new(element)), ast::Type::Tuple(elements) => { ast::Type::Tuple(vecmap(elements, |typ| Self::aos_to_soa_type(length, typ))) @@ -882,6 +890,13 @@ impl<'interner> Monomorphizer<'interner> { let element_type = Self::convert_type(&typ); (array, Some((index, element_type, location))) } + HirLValue::Dereference { lvalue, element_type } => { + let (reference, index) = self.lvalue(*lvalue); + let reference = Box::new(reference); + let element_type = Self::convert_type(&element_type); + let lvalue = ast::LValue::Dereference { reference, element_type }; + (lvalue, index) + } } } @@ -979,6 +994,12 @@ impl<'interner> Monomorphizer<'interner> { self.create_zeroed_function(parameter_types, ret_type) } ast::Type::Vec(_) => panic!("Cannot create a zeroed Vec value. This type is currently unimplemented and meant to be unusable outside of unconstrained functions"), + ast::Type::MutableReference(element) => { + use crate::UnaryOp::MutableReference; + let rhs = Box::new(self.zeroed_value_of_type(element)); + let result_type = typ.clone(); + ast::Expression::Unary(ast::Unary { rhs, result_type, operator: MutableReference }) + }, } } diff --git a/crates/noirc_frontend/src/monomorphization/printer.rs b/crates/noirc_frontend/src/monomorphization/printer.rs index 39c6db8734b..929a14e07da 100644 --- a/crates/noirc_frontend/src/monomorphization/printer.rs +++ b/crates/noirc_frontend/src/monomorphization/printer.rs @@ -262,6 +262,10 @@ impl AstPrinter { self.print_lvalue(object, f)?; write!(f, ".{field_index}") } + LValue::Dereference { reference, .. } => { + write!(f, "*")?; + self.print_lvalue(reference, f) + } } } } diff --git a/crates/noirc_frontend/src/node_interner.rs b/crates/noirc_frontend/src/node_interner.rs index b61886a170a..b2bbc0df0d0 100644 --- a/crates/noirc_frontend/src/node_interner.rs +++ b/crates/noirc_frontend/src/node_interner.rs @@ -621,6 +621,7 @@ fn get_type_method_key(typ: &Type) -> Option { Type::Tuple(_) => Some(Tuple), Type::Function(_, _) => Some(Function), Type::Vec(_) => Some(Vec), + Type::MutableReference(element) => get_type_method_key(element), // We do not support adding methods to these types Type::TypeVariable(_) diff --git a/crates/noirc_frontend/src/parser/parser.rs b/crates/noirc_frontend/src/parser/parser.rs index deaa045ccf0..ec3357c56d4 100644 --- a/crates/noirc_frontend/src/parser/parser.rs +++ b/crates/noirc_frontend/src/parser/parser.rs @@ -308,15 +308,35 @@ fn nothing() -> impl NoirParser { } fn self_parameter() -> impl NoirParser<(Pattern, UnresolvedType, AbiVisibility)> { - filter_map(move |span, found: Token| match found { - Token::Ident(ref word) if word == "self" => { - let ident = Ident::from_token(found, span); + let refmut_pattern = just(Token::Ampersand).then_ignore(keyword(Keyword::Mut)); + let mut_pattern = keyword(Keyword::Mut); + + refmut_pattern + .or(mut_pattern) + .map_with_span(|token, span| (token, span)) + .or_not() + .then(filter_map(move |span, found: Token| match found { + Token::Ident(ref word) if word == "self" => Ok(span), + _ => Err(ParserError::expected_label(ParsingRuleLabel::Parameter, found, span)), + })) + .map(|(pattern_keyword, span)| { + let ident = Ident::new("self".to_string(), span); let path = Path::from_single("Self".to_owned(), span); - let self_type = UnresolvedType::Named(path, vec![]); - Ok((Pattern::Identifier(ident), self_type, AbiVisibility::Private)) - } - _ => Err(ParserError::expected_label(ParsingRuleLabel::Parameter, found, span)), - }) + let mut self_type = UnresolvedType::Named(path, vec![]); + let mut pattern = Pattern::Identifier(ident); + + match pattern_keyword { + Some((Token::Ampersand, _)) => { + self_type = UnresolvedType::MutableReference(Box::new(self_type)); + } + Some((Token::Keyword(_), span)) => { + pattern = Pattern::Mutable(Box::new(pattern), span); + } + _ => (), + } + + (pattern, self_type, AbiVisibility::Private) + }) } fn implementation() -> impl NoirParser { @@ -598,12 +618,17 @@ where .delimited_by(just(Token::LeftBracket), just(Token::RightBracket)) .map(LValueRhs::Index); - l_ident.then(l_member_rhs.or(l_index).repeated()).foldl(|lvalue, rhs| match rhs { - LValueRhs::MemberAccess(field_name) => { - LValue::MemberAccess { object: Box::new(lvalue), field_name } - } - LValueRhs::Index(index) => LValue::Index { array: Box::new(lvalue), index }, - }) + let dereferences = just(Token::Star).repeated(); + + let lvalues = + l_ident.then(l_member_rhs.or(l_index).repeated()).foldl(|lvalue, rhs| match rhs { + LValueRhs::MemberAccess(field_name) => { + LValue::MemberAccess { object: Box::new(lvalue), field_name } + } + LValueRhs::Index(index) => LValue::Index { array: Box::new(lvalue), index }, + }); + + dereferences.then(lvalues).foldr(|_, lvalue| LValue::Dereference(Box::new(lvalue))) } fn parse_type<'a>() -> impl NoirParser + 'a { @@ -622,7 +647,8 @@ fn parse_type_inner( array_type(recursive_type_parser.clone()), tuple_type(recursive_type_parser.clone()), vec_type(recursive_type_parser.clone()), - function_type(recursive_type_parser), + function_type(recursive_type_parser.clone()), + mutable_reference_type(recursive_type_parser), )) } @@ -738,6 +764,16 @@ where .map(|(args, ret)| UnresolvedType::Function(args, Box::new(ret))) } +fn mutable_reference_type(type_parser: T) -> impl NoirParser +where + T: NoirParser, +{ + just(Token::Ampersand) + .ignore_then(keyword(Keyword::Mut)) + .ignore_then(type_parser) + .map(|element| UnresolvedType::MutableReference(Box::new(element))) +} + fn expression() -> impl ExprParser { recursive(|expr| expression_with_precedence(Precedence::Lowest, expr, false)) .labelled(ParsingRuleLabel::Expression) @@ -820,12 +856,17 @@ where P: ExprParser + 'a, { recursive(move |term_parser| { - choice((not(term_parser.clone()), negation(term_parser))) - .map_with_span(Expression::new) - // right-unary operators like a[0] or a.f bind more tightly than left-unary - // operators like - or !, so that !a[0] is parsed as !(a[0]). This is a bit - // awkward for casts so -a as i32 actually binds as -(a as i32). - .or(atom_or_right_unary(expr_parser)) + choice(( + not(term_parser.clone()), + negation(term_parser.clone()), + mutable_reference(term_parser.clone()), + dereference(term_parser), + )) + .map_with_span(Expression::new) + // right-unary operators like a[0] or a.f bind more tightly than left-unary + // operators like - or !, so that !a[0] is parsed as !(a[0]). This is a bit + // awkward for casts so -a as i32 actually binds as -(a as i32). + .or(atom_or_right_unary(expr_parser)) }) } @@ -1003,6 +1044,25 @@ where .map(|rhs| ExpressionKind::prefix(UnaryOp::Minus, rhs)) } +fn mutable_reference

(term_parser: P) -> impl NoirParser +where + P: ExprParser, +{ + just(Token::Ampersand) + .ignore_then(keyword(Keyword::Mut)) + .ignore_then(term_parser) + .map(|rhs| ExpressionKind::prefix(UnaryOp::MutableReference, rhs)) +} + +fn dereference

(term_parser: P) -> impl NoirParser +where + P: ExprParser, +{ + just(Token::Star) + .ignore_then(term_parser) + .map(|rhs| ExpressionKind::prefix(UnaryOp::Dereference, rhs)) +} + fn atom<'a, P>(expr_parser: P) -> impl NoirParser + 'a where P: ExprParser + 'a,