Skip to content

Commit

Permalink
feat(ssa refactor): Implement first-class references (#1849)
Browse files Browse the repository at this point in the history
* Explore work on references

* Cleanup

* Implement first-class references

* Fix frontend test

* Remove 'Mutability' struct, it is no longer needed

* Remove some extra lines

* Remove another function

* Revert another line

* Fix test again

* Fix a bug in mem2reg for nested references

* Fix inconsistent .eval during ssa-gen on assign statements

* Revert some code

* Add check for mutating immutable self objects
  • Loading branch information
jfecher authored Jul 5, 2023
1 parent d0894ad commit e5773e4
Show file tree
Hide file tree
Showing 27 changed files with 611 additions and 103 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
authors = [""]
compiler_version = "0.5.1"

[dependencies]
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
fn main() {
let mut x = 2;
add1(&mut x);
assert(x == 3);

let mut s = S { y: x };
s.add2();
assert(s.y == 5);

// Test that normal mutable variables are still copied
let mut a = 0;
mutate_copy(a);
assert(a == 0);

// Test something 3 allocations deep
let mut nested_allocations = Nested { y: &mut &mut 0 };
add1(*nested_allocations.y);
assert(**nested_allocations.y == 1);

// Test nested struct allocations with a mutable reference to an array.
let mut c = C {
foo: 0,
bar: &mut C2 {
array: &mut [1, 2],
},
};
*c.bar.array = [3, 4];
assert(*c.bar.array == [3, 4]);
}

fn add1(x: &mut Field) {
*x += 1;
}

struct S { y: Field }

struct Nested { y: &mut &mut Field }

struct C {
foo: Field,
bar: &mut C2,
}

struct C2 {
array: &mut [Field; 2]
}

impl S {
fn add2(&mut self) {
self.y += 2;
}
}

fn mutate_copy(mut a: Field) {
a = 7;
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fn main(x: Field, y: Field) {

// Test mutating tuples
let mut mutable = ((0, 0), 1, 2, 3);
mutable.0 = pair;
mutable.0 = (x, y);
mutable.2 = 7;
assert(mutable.0.0 == 1);
assert(mutable.0.1 == 0);
Expand Down
1 change: 1 addition & 0 deletions crates/noirc_evaluator/src/ssa/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,7 @@ impl SsaContext {
}
}
Type::Array(..) => panic!("Cannot convert an array type {t} into an ObjectType since it is unknown which array it refers to"),
Type::MutableReference(..) => panic!("Mutable reference types are unimplemented in the old ssa backend"),
Type::Unit => ObjectType::NotAnObject,
Type::Function(..) => ObjectType::Function,
Type::Tuple(_) => todo!("Conversion to ObjectType is unimplemented for tuples"),
Expand Down
10 changes: 10 additions & 0 deletions crates/noirc_evaluator/src/ssa/ssa_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ impl IrGenerator {
self.context.new_instruction(op, rhs_type)
}
UnaryOp::Not => self.context.new_instruction(Operation::Not(rhs), rhs_type),
UnaryOp::MutableReference | UnaryOp::Dereference => {
unimplemented!("Mutable references are unimplemented in the old ssa backend")
}
}
}

Expand Down Expand Up @@ -248,6 +251,9 @@ impl IrGenerator {
let val = self.find_variable(ident_def).unwrap();
val.get_field_member(*field_index)
}
LValue::Dereference { .. } => {
unreachable!("Mutable references are unsupported in the old ssa backend")
}
}
}

Expand All @@ -256,6 +262,7 @@ impl IrGenerator {
LValue::Ident(ident) => &ident.definition,
LValue::Index { array, .. } => Self::lvalue_ident_def(array.as_ref()),
LValue::MemberAccess { object, .. } => Self::lvalue_ident_def(object.as_ref()),
LValue::Dereference { reference, .. } => Self::lvalue_ident_def(reference.as_ref()),
}
}

Expand Down Expand Up @@ -462,6 +469,9 @@ impl IrGenerator {
let value = val.get_field_member(*field_index).clone();
self.assign_pattern(&value, rhs)?;
}
LValue::Dereference { .. } => {
unreachable!("Mutable references are unsupported in the old ssa backend")
}
}
Ok(Value::dummy())
}
Expand Down
3 changes: 2 additions & 1 deletion crates/noirc_evaluator/src/ssa/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ impl Value {
| Type::String(..)
| Type::Integer(..)
| Type::Bool
| Type::Field => Value::Node(*iter.next().unwrap()),
| Type::Field
| Type::MutableReference(_) => Value::Node(*iter.next().unwrap()),
}
}

Expand Down
1 change: 1 addition & 0 deletions crates/noirc_evaluator/src/ssa_refactor/ir/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub(crate) enum RuntimeType {
// Unconstrained function, to be compiled to brillig and executed by the Brillig VM
Brillig,
}

/// A function holds a list of instructions.
/// These instructions are further grouped into Basic blocks
///
Expand Down
40 changes: 28 additions & 12 deletions crates/noirc_evaluator/src/ssa_refactor/opt/mem2reg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,39 @@ impl PerBlockContext {
dfg: &mut DataFlowGraph,
) -> HashSet<AllocId> {
let mut protected_allocations = HashSet::new();
let mut loads_to_substitute = HashMap::new();
let block = &dfg[self.block_id];

// Maps Load instruction id -> value to replace the result of the load with
let mut loads_to_substitute = HashMap::new();

// Maps Load result id -> value to replace the result of the load with
let mut load_values_to_substitute = HashMap::new();

for instruction_id in block.instructions() {
match &dfg[*instruction_id] {
Instruction::Store { address, value } => {
self.last_stores.insert(*address, *value);
Instruction::Store { mut address, value } => {
if let Some(value) = load_values_to_substitute.get(&address) {
address = *value;
}

self.last_stores.insert(address, *value);
self.store_ids.push(*instruction_id);
}
Instruction::Load { address } => {
if let Some(last_value) = self.last_stores.get(address) {
Instruction::Load { mut address } => {
if let Some(value) = load_values_to_substitute.get(&address) {
address = *value;
}

if let Some(last_value) = self.last_stores.get(&address) {
let result_value = *dfg
.instruction_results(*instruction_id)
.first()
.expect("ICE: Load instructions should have single result");

loads_to_substitute.insert(*instruction_id, *last_value);
load_values_to_substitute.insert(result_value, *last_value);
} else {
protected_allocations.insert(*address);
protected_allocations.insert(address);
}
}
Instruction::Call { arguments, .. } => {
Expand All @@ -103,12 +122,9 @@ impl PerBlockContext {
}

// Substitute load result values
for (instruction_id, new_value) in &loads_to_substitute {
let result_value = *dfg
.instruction_results(*instruction_id)
.first()
.expect("ICE: Load instructions should have single result");
dfg.set_value_from_id(result_value, *new_value);
for (result_value, new_value) in load_values_to_substitute {
let result_value = dfg.resolve(result_value);
dfg.set_value_from_id(result_value, new_value);
}

// Delete load instructions
Expand Down
40 changes: 35 additions & 5 deletions crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,17 @@ impl<'a> FunctionContext<'a> {

// This helper is needed because we need to take f by mutable reference,
// otherwise we cannot move it multiple times each loop of vecmap.
fn map_type_helper<T>(typ: &ast::Type, f: &mut impl FnMut(Type) -> T) -> Tree<T> {
fn map_type_helper<T>(typ: &ast::Type, f: &mut dyn FnMut(Type) -> T) -> Tree<T> {
match typ {
ast::Type::Tuple(fields) => {
Tree::Branch(vecmap(fields, |field| Self::map_type_helper(field, f)))
}
ast::Type::Unit => Tree::empty(),
// A mutable reference wraps each element into a reference.
// This can be multiple values if the element type is a tuple.
ast::Type::MutableReference(element) => {
Self::map_type_helper(element, &mut |_| f(Type::Reference))
}
other => Tree::Leaf(f(Self::convert_non_tuple_type(other))),
}
}
Expand Down Expand Up @@ -201,6 +206,11 @@ impl<'a> FunctionContext<'a> {
ast::Type::Unit => panic!("convert_non_tuple_type called on a unit type"),
ast::Type::Tuple(_) => panic!("convert_non_tuple_type called on a tuple: {typ}"),
ast::Type::Function(_, _) => Type::Function,
ast::Type::MutableReference(element) => {
// Recursive call to panic if element is a tuple
Self::convert_non_tuple_type(element);
Type::Reference
}

// How should we represent Vecs?
// Are they a struct of array + length + capacity?
Expand Down Expand Up @@ -473,9 +483,21 @@ impl<'a> FunctionContext<'a> {
let object_lvalue = Box::new(object_lvalue);
LValue::MemberAccess { old_object, object_lvalue, index: *field_index }
}
ast::LValue::Dereference { reference, .. } => {
let (reference, _) = self.extract_current_value_recursive(reference);
LValue::Dereference { reference }
}
}
}

pub(super) fn dereference(&mut self, values: &Values, element_type: &ast::Type) -> Values {
let element_types = Self::convert_type(element_type);
values.map_both(element_types, |value, element_type| {
let reference = value.eval(self);
self.builder.insert_load(reference, element_type).into()
})
}

/// Compile the given identifier as a reference - ie. avoid calling .eval()
fn ident_lvalue(&self, ident: &ast::Ident) -> Values {
match &ident.definition {
Expand Down Expand Up @@ -516,16 +538,19 @@ impl<'a> FunctionContext<'a> {
let element = Self::get_field_ref(&old_object, *index).clone();
(element, LValue::MemberAccess { old_object, object_lvalue, index: *index })
}
ast::LValue::Dereference { reference, element_type } => {
let (reference, _) = self.extract_current_value_recursive(reference);
let dereferenced = self.dereference(&reference, element_type);
(dereferenced, LValue::Dereference { reference })
}
}
}

/// Assigns a new value to the given LValue.
/// The LValue can be created via a previous call to extract_current_value.
/// This method recurs on the given LValue to create a new value to assign an allocation
/// instruction within an LValue::Ident - see the comment on `extract_current_value` for more
/// details.
/// When first-class references are supported the nearest reference may be in any LValue
/// variant rather than just LValue::Ident.
/// instruction within an LValue::Ident or LValue::Dereference - see the comment on
/// `extract_current_value` for more details.
pub(super) fn assign_new_value(&mut self, lvalue: LValue, new_value: Values) {
match lvalue {
LValue::Ident(references) => self.assign(references, new_value),
Expand All @@ -538,6 +563,9 @@ impl<'a> FunctionContext<'a> {
let new_object = Self::replace_field(old_object, index, new_value);
self.assign_new_value(*object_lvalue, new_object);
}
LValue::Dereference { reference } => {
self.assign(reference, new_value);
}
}
}

Expand Down Expand Up @@ -705,8 +733,10 @@ impl SharedContext {
}

/// Used to remember the results of each step of extracting a value from an ast::LValue
#[derive(Debug)]
pub(super) enum LValue {
Ident(Values),
Index { old_array: ValueId, index: ValueId, array_lvalue: Box<LValue> },
MemberAccess { old_object: Values, index: usize, object_lvalue: Box<LValue> },
Dereference { reference: Values },
}
27 changes: 23 additions & 4 deletions crates/noirc_evaluator/src/ssa_refactor/ssa_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ impl<'a> FunctionContext<'a> {
/// Codegen for identifiers
fn codegen_ident(&mut self, ident: &ast::Ident) -> Values {
match &ident.definition {
ast::Definition::Local(id) => self.lookup(*id).map(|value| value.eval(self).into()),
ast::Definition::Local(id) => self.lookup(*id),
ast::Definition::Function(id) => self.get_or_queue_function(*id),
ast::Definition::Oracle(name) => self.builder.import_foreign_function(name).into(),
ast::Definition::Builtin(name) | ast::Definition::LowLevel(name) => {
Expand Down Expand Up @@ -165,14 +165,33 @@ impl<'a> FunctionContext<'a> {
}

fn codegen_unary(&mut self, unary: &ast::Unary) -> Values {
let rhs = self.codegen_non_tuple_expression(&unary.rhs);
let rhs = self.codegen_expression(&unary.rhs);
match unary.operator {
noirc_frontend::UnaryOp::Not => self.builder.insert_not(rhs).into(),
noirc_frontend::UnaryOp::Not => {
let rhs = rhs.into_leaf().eval(self);
self.builder.insert_not(rhs).into()
}
noirc_frontend::UnaryOp::Minus => {
let rhs = rhs.into_leaf().eval(self);
let typ = self.builder.type_of_value(rhs);
let zero = self.builder.numeric_constant(0u128, typ);
self.builder.insert_binary(zero, BinaryOp::Sub, rhs).into()
}
noirc_frontend::UnaryOp::MutableReference => {
rhs.map(|rhs| {
match rhs {
value::Value::Normal(value) => {
let alloc = self.builder.insert_allocate();
self.builder.insert_store(alloc, value);
Tree::Leaf(value::Value::Normal(alloc))
}
// NOTE: The `.into()` here converts the Value::Mutable into
// a Value::Normal so it is no longer automatically dereferenced.
value::Value::Mutable(reference, _) => reference.into(),
}
})
}
noirc_frontend::UnaryOp::Dereference => self.dereference(&rhs, &unary.result_type),
}
}

Expand Down Expand Up @@ -343,13 +362,13 @@ impl<'a> FunctionContext<'a> {
/// Generate SSA for a function call. Note that calls to built-in functions
/// and intrinsics are also represented by the function call instruction.
fn codegen_call(&mut self, call: &ast::Call) -> Values {
let function = self.codegen_non_tuple_expression(&call.func);
let arguments = call
.arguments
.iter()
.flat_map(|argument| self.codegen_expression(argument).into_value_list(self))
.collect();

let function = self.codegen_non_tuple_expression(&call.func);
self.insert_call(function, arguments, &call.return_type)
}

Expand Down
30 changes: 30 additions & 0 deletions crates/noirc_evaluator/src/ssa_refactor/ssa_gen/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,36 @@ impl<T> Tree<T> {
}
}

/// Map two trees alongside each other.
/// This asserts each tree has the same internal structure.
pub(super) fn map_both<U, R>(
&self,
other: Tree<U>,
mut f: impl FnMut(T, U) -> Tree<R>,
) -> Tree<R>
where
T: std::fmt::Debug + Clone,
U: std::fmt::Debug,
{
self.map_both_helper(other, &mut f)
}

fn map_both_helper<U, R>(&self, other: Tree<U>, f: &mut impl FnMut(T, U) -> Tree<R>) -> Tree<R>
where
T: std::fmt::Debug + Clone,
U: std::fmt::Debug,
{
match (self, other) {
(Tree::Branch(self_trees), Tree::Branch(other_trees)) => {
assert_eq!(self_trees.len(), other_trees.len());
let trees = self_trees.iter().zip(other_trees);
Tree::Branch(vecmap(trees, |(l, r)| l.map_both_helper(r, f)))
}
(Tree::Leaf(self_value), Tree::Leaf(other_value)) => f(self_value.clone(), other_value),
other => panic!("Found unexpected tree combination during SSA: {other:?}"),
}
}

/// Unwraps this Tree into the value of the leaf node. Panics if
/// this Tree is a Branch
pub(super) fn into_leaf(self) -> T {
Expand Down
Loading

0 comments on commit e5773e4

Please sign in to comment.