From f719307bfb95193e30c0f8f8d81ab210b30f3780 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Wed, 3 May 2023 11:07:46 -0500 Subject: [PATCH 01/25] Start inlining pass --- crates/noirc_evaluator/src/ssa_refactor.rs | 1 + .../src/ssa_refactor/ir/basic_block.rs | 13 +- .../src/ssa_refactor/ir/dfg.rs | 44 ++++--- .../src/ssa_refactor/ir/map.rs | 12 ++ .../src/ssa_refactor/opt/inlining.rs | 124 ++++++++++++++++++ .../src/ssa_refactor/opt/mod.rs | 6 + .../src/ssa_refactor/ssa_gen/program.rs | 22 +++- 7 files changed, 196 insertions(+), 26 deletions(-) create mode 100644 crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs create mode 100644 crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor.rs b/crates/noirc_evaluator/src/ssa_refactor.rs index fc45071e579..a33393866df 100644 --- a/crates/noirc_evaluator/src/ssa_refactor.rs +++ b/crates/noirc_evaluator/src/ssa_refactor.rs @@ -8,5 +8,6 @@ #![allow(dead_code)] mod ir; +mod opt; mod ssa_builder; pub mod ssa_gen; diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs index 8a3f74c4a64..8bf17119029 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs @@ -29,10 +29,10 @@ pub(crate) struct BasicBlock { pub(crate) type BasicBlockId = Id; impl BasicBlock { - /// Create a new BasicBlock with the given parameters. + /// Create a new BasicBlock with the given instructions. 
/// Parameters can also be added later via BasicBlock::add_parameter - pub(crate) fn new(parameters: Vec) -> Self { - Self { parameters, instructions: Vec::new(), terminator: None } + pub(crate) fn new(instructions: Vec) -> Self { + Self { parameters: Vec::new(), instructions, terminator: None } } /// Returns the parameters of this block @@ -57,6 +57,11 @@ impl BasicBlock { &self.instructions } + /// Retrieve a mutable reference to all instructions in this block. + pub(crate) fn instructions_mut(&mut self) -> &mut Vec { + &mut self.instructions + } + /// Sets the terminator instruction of this block. /// /// A properly-constructed block will always terminate with a TerminatorInstruction - @@ -90,7 +95,7 @@ impl BasicBlock { /// Removes the given instruction from this block if present or panics otherwise. pub(crate) fn remove_instruction(&mut self, instruction: InstructionId) { let index = - self.instructions.iter().position(|id| *id == instruction).unwrap_or_else(|| { + self.instructions.iter().rev().position(|id| *id == instruction).unwrap_or_else(|| { panic!("remove_instruction: No such instruction {instruction:?} in block") }); self.instructions.remove(index); diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs index 67569c6a4c2..2ad6af991c4 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs @@ -13,7 +13,6 @@ use super::{ }; use acvm::FieldElement; -use iter_extended::vecmap; /// The DataFlowGraph contains most of the actual data in a function including /// its blocks, instructions, and values. This struct is largely responsible for @@ -69,22 +68,6 @@ impl DataFlowGraph { self.blocks.insert(BasicBlock::new(Vec::new())) } - /// Creates a new basic block with the given parameters. - /// After being created, the block is unreachable in the current function - /// until another block is made to jump to it. 
- pub(crate) fn make_block_with_parameters( - &mut self, - parameter_types: impl Iterator, - ) -> BasicBlockId { - self.blocks.insert_with_id(|entry_block| { - let parameters = vecmap(parameter_types.enumerate(), |(position, typ)| { - self.values.insert(Value::Param { block: entry_block, position, typ }) - }); - - BasicBlock::new(parameters) - }) - } - /// Get an iterator over references to each basic block within the dfg, paired with the basic /// block's id. /// @@ -279,6 +262,33 @@ impl DataFlowGraph { ) { self.blocks[block].set_terminator(terminator); } + + /// Splits the given block in two at the given instruction, returning the Id of the new block. + /// This will remove the given instruction and place every instruction after it into a new block + /// with the same terminator as the old block. The old block is modified to stop + /// before the instruction to remove and to unconditionally branch to the new block. + /// This function is useful during function inlining to remove the call instruction + /// while opening a spot at the end of the current block to insert instructions into. 
+ /// + /// Example (before): + /// block1: a; b; c; d; e; jmp block5 + /// + /// After self.split_block_at(block1, c): + /// block1: a; b; jmp block2 + /// block2: d; e; jmp block5 + pub(crate) fn split_block_at(&mut self, block: BasicBlockId, instruction_to_remove: InstructionId) -> BasicBlockId { + let split_block = &mut self.blocks[block]; + + let mut instructions = split_block.instructions().iter(); + let index = instructions.position(|id| *id == instruction_to_remove).unwrap_or_else(|| { + panic!("No instruction found with id {instruction_to_remove:?} in block {block:?}") + }); + + let instructions = split_block.instructions_mut().drain(index..).collect(); + split_block.remove_instruction(instruction_to_remove); + + self.blocks.insert(BasicBlock::new(instructions)) + } } impl std::ops::Index for DataFlowGraph { diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs index 14ea521359d..eee1b593edd 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs @@ -45,6 +45,18 @@ impl std::hash::Hash for Id { } } +impl PartialOrd for Id { + fn partial_cmp(&self, other: &Self) -> Option { + self.index.partial_cmp(&other.index) + } +} + +impl Ord for Id { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.index.cmp(&other.index) + } +} + impl Eq for Id {} impl PartialEq for Id { diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs new file mode 100644 index 00000000000..9ccf1e79943 --- /dev/null +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs @@ -0,0 +1,124 @@ +use std::collections::{HashMap, HashSet}; + +use iter_extended::vecmap; + +use crate::ssa_refactor::{ssa_gen::Ssa, ir::{instruction::Instruction, value::{ValueId, Value}, dfg::DataFlowGraph, function::FunctionId, basic_block::BasicBlockId}}; + +/// An arbitrary limit to the maximum number of 
recursive call +/// frames at any point in time. +const RECURSION_LIMIT: u32 = 1000; + +impl Ssa { + /// Inline all functions within the IR. + /// + /// In the case of recursive functions, this will attempt + /// to recursively inline until the RECURSION_LIMIT is reached. + /// + /// Functions are recursively inlined into main until either we finish + /// inlining all functions or we encounter a function whose function id is not known. + /// When the later happens, the call instruction is kept in addition to the function + /// it refers to. The function it refers to is kept unmodified without any inlining + /// changes. This is because if the function's id later becomes known by a later + /// pass, we would need to re-run all of inlining anyway to inline it, so we might + /// as well save the work for later instead of performing it twice. + pub(crate) fn inline_functions(&mut self) { + let main_function = self.main(); + let mut context = InlineContext::new(main_function.entry_block(), main_function.id()); + + let blocks = vecmap(main_function.dfg.basic_blocks_iter(), |(id, _)| id); + + for block in blocks { + let instructions = main_function.dfg[block].instructions(); + + let mut new_instructions = Vec::with_capacity(instructions.len()); + + for (index, instruction) in instructions.iter().copied().enumerate() { + match &main_function.dfg[instruction] { + Instruction::Call { func, arguments } => { + match context.get_function(*func, &main_function.dfg) { + Some(id) => { + main_function.dfg.split_block_at(block, instruction); + context.inline_function(self, id, arguments) + } + None => new_instructions.push(instruction), + } + }, + _ => new_instructions.push(instruction), + } + } + } + } +} + +struct InlineContext { + recursion_level: u32, + argument_values: Vec>, + current_block: BasicBlockId, + visited_blocks: HashSet, + functions_to_keep: HashSet, +} + +impl InlineContext { + /// Create a new context object for the function inlining pass. 
+ /// This starts off with an empty mapping of instructions for main's parameters. + fn new(current_block: BasicBlockId, main: FunctionId) -> InlineContext { + let mut visited_blocks = HashSet::new(); + visited_blocks.insert(current_block); + + let mut functions_to_keep = HashSet::new(); + functions_to_keep.insert(main); + + Self { + recursion_level: 0, + argument_values: vec![HashMap::new()], + current_block, + visited_blocks, + functions_to_keep, + } + } + + fn current_function_arguments(&self) -> &HashMap { + self.argument_values.last() + .expect("Expected there to always be argument values for the current function being inlined") + } + + fn get_function(&self, mut id: ValueId, dfg: &DataFlowGraph) -> Option { + if let Some(new_id) = self.current_function_arguments().get(&id) { + id = *new_id; + } + + match dfg[id] { + Value::Function(id) => Some(id), + _ => None, + } + } + + fn inline_function(&mut self, ssa: &Ssa, id: FunctionId, arguments: &[ValueId]) { + let target_function = &ssa.functions[&id]; + let current_block = target_function.entry_block(); + + let parameters = target_function.dfg.block_parameters(current_block); + assert_eq!(parameters.len(), arguments.len()); + + let argument_map = parameters.iter().copied().zip(arguments.iter().copied()).collect(); + self.argument_values.push(argument_map); + + let instructions = target_function.dfg[current_block].instructions(); + + let mut new_instructions = Vec::with_capacity(instructions.len()); + + for id in instructions { + match &target_function.dfg[*id] { + Instruction::Call { func, arguments } => { + match self.get_function(*func, &target_function.dfg) { + Some(id) => self.inline_function(ssa, id, arguments), + None => new_instructions.push(*id), + } + }, + _ => new_instructions.push(*id), + } + } + + self.argument_values.pop(); + } +} diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs new file mode 100644 index 00000000000..46ca7d443bc --- 
/dev/null +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs @@ -0,0 +1,6 @@ +//! This folder contains each optimization pass for the SSA IR. +//! +//! Each pass is generally expected to mutate the SSA IR into a gradually +//! simpler form until the IR only has a single function remaining with 1 block within it. +//! Generally, these passes are also expected to minimize the final amount of instructions. +mod inlining; diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs index 99d49456210..22bff853f08 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs @@ -1,22 +1,34 @@ -use std::fmt::Display; +use std::{collections::BTreeMap, fmt::Display}; -use crate::ssa_refactor::ir::function::Function; +use iter_extended::btree_map; + +use crate::ssa_refactor::ir::function::{Function, FunctionId}; /// Contains the entire SSA representation of the program. +/// +/// It is expected that the main function is always the first +/// function in the functions vector. 
pub struct Ssa { - functions: Vec, + pub functions: BTreeMap, } impl Ssa { /// Create a new Ssa object from the given SSA functions pub fn new(functions: Vec) -> Self { - Self { functions } + Self { functions: btree_map(functions, |f| (f.id(), f)) } + } + + pub fn main(&mut self) -> &mut Function { + self.functions + .first_entry() + .expect("Expected there to be at least 1 SSA function") + .into_mut() } } impl Display for Ssa { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - for function in &self.functions { + for (_, function) in &self.functions { writeln!(f, "{function}")?; } Ok(()) From 658b7649b3681c65dbf3d47aff9463c8b3d547b9 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Wed, 3 May 2023 14:47:33 -0500 Subject: [PATCH 02/25] Get most of pass working --- .../src/ssa_refactor/ir/instruction.rs | 39 ++- .../src/ssa_refactor/opt/inlining.rs | 266 +++++++++++++----- .../src/ssa_refactor/ssa_builder/mod.rs | 23 +- .../src/ssa_refactor/ssa_gen/program.rs | 2 +- 4 files changed, 259 insertions(+), 71 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs index 66f8b1e3b17..812d12b23a3 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs @@ -1,4 +1,5 @@ use acvm::acir::BlackBoxFunc; +use iter_extended::vecmap; use super::{basic_block::BasicBlockId, map::Id, types::Type, value::ValueId}; @@ -114,6 +115,42 @@ impl Instruction { Instruction::Load { .. } | Instruction::Call { .. } => InstructionResultType::Unknown, } } + + /// True if this instruction requires specifying the control type variables when + /// inserting this instruction into a DataFlowGraph. + pub(crate) fn requires_ctrl_typevars(&self) -> bool { + matches!(self.result_type(), InstructionResultType::Unknown) + } + + /// Maps each ValueId inside this instruction to a new ValueId, returning the new instruction. 
+ /// Note that the returned instruction is fresh and will not have an assigned InstructionId + /// until it is manually inserted in a DataFlowGraph later. + pub(crate) fn map_values(&self, mut f: impl FnMut(ValueId) -> ValueId) -> Instruction { + match self { + Instruction::Binary(binary) => Instruction::Binary(Binary { + lhs: f(binary.lhs), + rhs: f(binary.rhs), + operator: binary.operator, + }), + Instruction::Cast(value, typ) => Instruction::Cast(f(*value), *typ), + Instruction::Not(value) => Instruction::Not(f(*value)), + Instruction::Truncate { value, bit_size, max_bit_size } => Instruction::Truncate { + value: f(*value), + bit_size: *bit_size, + max_bit_size: *max_bit_size, + }, + Instruction::Constrain(value) => Instruction::Constrain(f(*value)), + Instruction::Call { func, arguments } => Instruction::Call { + func: f(*func), + arguments: vecmap(arguments.iter().copied(), f), + }, + Instruction::Allocate { size } => Instruction::Allocate { size: *size }, + Instruction::Load { address } => Instruction::Load { address: f(*address) }, + Instruction::Store { address, value } => { + Instruction::Store { address: f(*address), value: f(*value) } + } + } + } } /// The possible return values for Instruction::return_types @@ -191,7 +228,7 @@ impl Binary { /// All binary operators are also only for numeric types. To implement /// e.g. equality for a compound type like a struct, one must add a /// separate Eq operation for each field and combine them later with And. -#[derive(Debug, PartialEq, Eq, Hash, Clone)] +#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)] pub(crate) enum BinaryOp { /// Addition of lhs + rhs. 
Add, diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs index 9ccf1e79943..3bfcf934ad5 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs @@ -2,7 +2,16 @@ use std::collections::{HashMap, HashSet}; use iter_extended::vecmap; -use crate::ssa_refactor::{ssa_gen::Ssa, ir::{instruction::Instruction, value::{ValueId, Value}, dfg::DataFlowGraph, function::FunctionId, basic_block::BasicBlockId}}; +use crate::ssa_refactor::{ + ir::{ + basic_block::BasicBlockId, + function::{Function, FunctionId}, + instruction::{Instruction, InstructionId, TerminatorInstruction}, + value::{Value, ValueId}, + }, + ssa_builder::FunctionBuilder, + ssa_gen::Ssa, +}; /// An arbitrary limit to the maximum number of recursive call /// frames at any point in time. @@ -23,102 +32,229 @@ impl Ssa { /// as well save the work for later instead of performing it twice. pub(crate) fn inline_functions(&mut self) { let main_function = self.main(); - let mut context = InlineContext::new(main_function.entry_block(), main_function.id()); - - let blocks = vecmap(main_function.dfg.basic_blocks_iter(), |(id, _)| id); - - for block in blocks { - let instructions = main_function.dfg[block].instructions(); - - let mut new_instructions = Vec::with_capacity(instructions.len()); - - for (index, instruction) in instructions.iter().copied().enumerate() { - match &main_function.dfg[instruction] { - Instruction::Call { func, arguments } => { - match context.get_function(*func, &main_function.dfg) { - Some(id) => { - main_function.dfg.split_block_at(block, instruction); - context.inline_function(self, id, arguments) - } - None => new_instructions.push(instruction), - } - }, - _ => new_instructions.push(instruction), - } - } - } + let mut context = InlineContext::new(main_function); + let main_id = main_function.id(); + context.inline_function(self, main_id, &[]) } } +/// 
The context for the function inlining pass. +/// +/// This works using an internal FunctionBuilder to build a new main function from scratch. +/// Doing it this way properly handles importing instructions between functions and lets us +/// reuse the existing API at the cost of essentially cloning each of main's instructions. struct InlineContext { recursion_level: u32, - argument_values: Vec>, - current_block: BasicBlockId, - visited_blocks: HashSet, functions_to_keep: HashSet, + builder: FunctionBuilder, +} + +/// The per-function inlining context contains information that is only valid for one function. +/// For example, each function has its own DataFlowGraph, and thus each function needs a translation +/// layer to translate between BlockId to BlockId for the current function and the function to +/// inline into. The same goes for ValueIds, InstructionIds, and for storing other data like +/// parameter to argument mappings. +struct PerFunctionContext<'function> { + /// The source function is the function we're currently inlining into the function being built. + source_function: &'function Function, + + /// The shared inlining context for all functions. This notably contains the FunctionBuilder used + /// to build the function we're inlining into. + context: &'function mut InlineContext, + + /// Maps ValueIds in the function being inlined to the new ValueIds to use in the function + /// being inlined into. This mapping also contains the mapping from parameter values to + /// argument values. + values: HashMap, + + /// Maps BasicBlockIds in the function being inlined to the new BasicBlockIds to use in the + /// function being inlined into. + blocks: HashMap, + + /// Maps InstructionIds from the function being inlined to the function being inlined into. + instructions: HashMap, } impl InlineContext { /// Create a new context object for the function inlining pass. /// This starts off with an empty mapping of instructions for main's parameters. 
- fn new(current_block: BasicBlockId, main: FunctionId) -> InlineContext { - let mut visited_blocks = HashSet::new(); - visited_blocks.insert(current_block); + /// + /// main: The main function of the program to start inlining into. Only this function + /// and all functions still reachable from it will be returned when inlining is finished. + fn new(main: &Function) -> InlineContext { + Self { + recursion_level: 0, + builder: FunctionBuilder::new(main.name().to_owned(), main.id()), + functions_to_keep: HashSet::new(), + } + } - let mut functions_to_keep = HashSet::new(); - functions_to_keep.insert(main); + fn inline_function(&mut self, ssa: &Ssa, id: FunctionId, arguments: &[ValueId]) { + self.recursion_level += 1; + + if self.recursion_level > RECURSION_LIMIT { + panic!( + "Attempted to recur more than {RECURSION_LIMIT} times during function inlining." + ); + } + + let source_function = &ssa.functions[&id]; + let mut context = PerFunctionContext::new(self, source_function, arguments); + context.inline_blocks(ssa); + } +} + +impl<'function> PerFunctionContext<'function> { + /// Create a new PerFunctionContext from the source function. + /// The value and block mappings for this context are initially empty except + /// for containing the mapping between parameters in the source_function and + /// the arguments of the destination function. 
+ fn new( + context: &'function mut InlineContext, + source_function: &'function Function, + arguments: &[ValueId], + ) -> Self { + let entry = source_function.entry_block(); + let parameters = source_function.dfg.block_parameters(entry); + assert_eq!(parameters.len(), arguments.len()); Self { - recursion_level: 0, - argument_values: vec![HashMap::new()], - current_block, - visited_blocks, - functions_to_keep, + context, + source_function, + values: parameters.iter().copied().zip(arguments.iter().copied()).collect(), + blocks: HashMap::new(), + instructions: HashMap::new(), } } - fn current_function_arguments(&self) -> &HashMap { - self.argument_values.last() - .expect("Expected there to always be argument values for the current function being inlined") + /// Translates a ValueId from the function being inlined to a ValueId of the function + /// being inlined into. Note that this expects value ids for all Value::Instruction and + /// Value::Param values are already handled as a result of previous inlining of instructions + /// and blocks respectively. If these assertions trigger it means a value is being used before + /// the instruction or block that defines the value is inserted. + fn translate_value(&mut self, id: ValueId) -> ValueId { + if let Some(value) = self.values.get(&id) { + return *value; + } + + let new_value = match &self.source_function.dfg[id] { + Value::Instruction { .. } => { + unreachable!("All Value::Instructions should already be known during inlining after creating the original inlined instruction") + } + Value::Param { .. 
} => { + unreachable!("All Value::Params should already be known from previous calls to translate_block") + } + Value::NumericConstant { constant, typ } => { + let value = self.source_function.dfg[*constant].value(); + self.context.builder.numeric_constant(value, *typ) + } + Value::Function(function) => self.context.builder.import_function(*function), + Value::Intrinsic(intrinsic) => self.context.builder.import_intrinsic_id(*intrinsic), + }; + + self.values.insert(id, new_value); + new_value } - fn get_function(&self, mut id: ValueId, dfg: &DataFlowGraph) -> Option { - if let Some(new_id) = self.current_function_arguments().get(&id) { - id = *new_id; + /// Translate a block id from the source function to one of the target function. + /// + /// If the block isn't already known, this will insert a new block into the target function + /// with the same parameter types as the source block. + fn translate_block(&mut self, id: BasicBlockId) -> BasicBlockId { + if let Some(block) = self.blocks.get(&id) { + return *block; + } + + // The block is not already present in the function being inlined into so we must create it. + // The block's instructions are not copied over as they will be copied later in inlining. + let new_block = self.context.builder.insert_block(); + let original_parameters = self.source_function.dfg.block_parameters(id); + + for parameter in original_parameters { + let typ = self.source_function.dfg.type_of_value(*parameter); + let new_parameter = self.context.builder.add_block_parameter(new_block, typ); + self.values.insert(*parameter, new_parameter); } - match dfg[id] { + new_block + } + + /// Try to retrieve the function referred to by the given Id. + /// Expects that the given ValueId belongs to the source_function. + /// + /// Returns None if the id is not known to refer to a function. 
+ fn get_function(&mut self, mut id: ValueId) -> Option { + id = self.translate_value(id); + match self.context.builder[id] { Value::Function(id) => Some(id), _ => None, } } - fn inline_function(&mut self, ssa: &Ssa, id: FunctionId, arguments: &[ValueId]) { - let target_function = &ssa.functions[&id]; - let current_block = target_function.entry_block(); - - let parameters = target_function.dfg.block_parameters(current_block); - assert_eq!(parameters.len(), arguments.len()); + /// Inline all reachable blocks within the source_function into the destination function. + fn inline_blocks(&mut self, ssa: &Ssa) { + let mut seen_blocks = HashSet::new(); + let mut block_queue = vec![self.source_function.entry_block()]; - let argument_map = parameters.iter().copied().zip(arguments.iter().copied()).collect(); - self.argument_values.push(argument_map); + while let Some(block_id) = block_queue.pop() { + self.context.builder.switch_to_block(block_id); + seen_blocks.insert(block_id); - let instructions = target_function.dfg[current_block].instructions(); - - let mut new_instructions = Vec::with_capacity(instructions.len()); + self.inline_block(ssa, block_id); + self.handle_terminator_instruction(block_id); + } + } - for id in instructions { - match &target_function.dfg[*id] { - Instruction::Call { func, arguments } => { - match self.get_function(*func, &target_function.dfg) { - Some(id) => self.inline_function(ssa, id, arguments), - None => new_instructions.push(*id), - } + /// Inline each instruction in the given block into the function being inlined into. + /// This may recurse if it finds another function to inline if a call instruction is within this block. 
+ fn inline_block(&mut self, ssa: &Ssa, block_id: BasicBlockId) { + let block = &self.source_function.dfg[block_id]; + for id in block.instructions() { + match &self.source_function.dfg[*id] { + Instruction::Call { func, arguments } => match self.get_function(*func) { + Some(id) => self.context.inline_function(ssa, id, arguments), + None => self.push_instruction(*id), }, - _ => new_instructions.push(*id), + _ => self.push_instruction(*id), } } + } + + /// Push the given instruction from the source_function into the current block of the + /// function being inlined into. + fn push_instruction(&mut self, id: InstructionId) { + let instruction = self.source_function.dfg[id].map_values(|id| self.translate_value(id)); + let results = self.source_function.dfg.instruction_results(id); + + let ctrl_typevars = instruction + .requires_ctrl_typevars() + .then(|| vecmap(results, |result| self.source_function.dfg.type_of_value(*result))); + + let new_results = self.context.builder.insert_instruction(instruction, ctrl_typevars); + + assert_eq!(results.len(), new_results.len()); + for (result, new_result) in results.iter().zip(new_results) { + self.values.insert(*result, *new_result); + } + } - self.argument_values.pop(); + /// Handle the given terminator instruction from the given source function block. + /// This will push any new blocks to the destination function as needed, add them + /// to the block queue, and set the terminator instruction for the current block. 
+ fn handle_terminator_instruction(&mut self, block_id: BasicBlockId) { + match self.source_function.dfg[block_id].terminator() { + Some(TerminatorInstruction::Jmp { destination, arguments }) => { + let destination = self.translate_block(*destination); + let arguments = vecmap(arguments, |arg| self.translate_value(*arg)); + self.context.builder.terminate_with_jmp(destination, arguments); + } + Some(TerminatorInstruction::JmpIf { + condition, + then_destination, + else_destination, + }) => todo!(), + Some(TerminatorInstruction::Return { .. }) => (), + None => unreachable!("Block has no terminator instruction"), + } } } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs index aa67cbed583..840f2f1c4e7 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs @@ -9,7 +9,10 @@ use crate::ssa_refactor::ir::{ }; use super::{ - ir::instruction::{InstructionId, Intrinsic}, + ir::{ + basic_block::BasicBlock, + instruction::{InstructionId, Intrinsic}, + }, ssa_gen::Ssa, }; @@ -96,7 +99,7 @@ impl FunctionBuilder { } /// Inserts a new instruction at the end of the current block and returns its results - fn insert_instruction( + pub(crate) fn insert_instruction( &mut self, instruction: Instruction, ctrl_typevars: Option>, @@ -228,8 +231,12 @@ impl FunctionBuilder { /// Retrieve a value reference to the given intrinsic operation. /// Returns None if there is no intrinsic matching the given name. pub(crate) fn import_intrinsic(&mut self, name: &str) -> Option { - Intrinsic::lookup(name) - .map(|intrinsic| self.current_function.dfg.import_intrinsic(intrinsic)) + Intrinsic::lookup(name).map(|intrinsic| self.import_intrinsic_id(intrinsic)) + } + + /// Retrieve a value reference to the given intrinsic operation. 
+ pub(crate) fn import_intrinsic_id(&mut self, intrinsic: Intrinsic) -> ValueId { + self.current_function.dfg.import_intrinsic(intrinsic) } /// Removes the given instruction from the current block or panics otherwise. @@ -253,3 +260,11 @@ impl std::ops::Index for FunctionBuilder { &self.current_function.dfg[id] } } + +impl std::ops::Index for FunctionBuilder { + type Output = BasicBlock; + + fn index(&self, id: BasicBlockId) -> &Self::Output { + &self.current_function.dfg[id] + } +} diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs index 22bff853f08..3c9201100ae 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs @@ -8,7 +8,7 @@ use crate::ssa_refactor::ir::function::{Function, FunctionId}; /// /// It is expected that the main function is always the first /// function in the functions vector. -pub struct Ssa { +pub(crate) struct Ssa { pub functions: BTreeMap, } From a103b1e8a999d89ae9600a9b2eeb4f19b1a06e85 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Thu, 4 May 2023 10:39:56 -0500 Subject: [PATCH 03/25] Finish function inlining pass --- .../src/ssa_refactor/ir/function.rs | 7 + .../src/ssa_refactor/ir/map.rs | 6 + .../src/ssa_refactor/opt/inlining.rs | 152 +++++++++++------- .../src/ssa_refactor/ssa_builder/mod.rs | 5 + .../src/ssa_refactor/ssa_gen/program.rs | 25 ++- 5 files changed, 134 insertions(+), 61 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs index 97c84942bb5..f37448462b7 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs @@ -2,6 +2,7 @@ use super::basic_block::BasicBlockId; use super::dfg::DataFlowGraph; use super::map::Id; use super::types::Type; +use super::value::ValueId; /// A function holds a list of instructions. 
/// These instructions are further grouped into Basic blocks @@ -54,6 +55,12 @@ impl Function { pub(crate) fn entry_block(&self) -> BasicBlockId { self.entry_block } + + /// Returns the parameters of this function. + /// The parameters will always match that of this function's entry block. + pub(crate) fn parameters(&self) -> &[ValueId] { + self.dfg.block_parameters(self.entry_block) + } } /// FunctionId is a reference for a function diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs index eee1b593edd..43baf4430c7 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs @@ -284,6 +284,12 @@ pub(crate) struct AtomicCounter { } impl AtomicCounter { + /// Create a new counter starting after the given Id. + /// Use AtomicCounter::default() to start at zero. + pub(crate) fn starting_after(id: Id) -> Self { + Self { next: AtomicUsize::new(id.index + 1), _marker: Default::default() } + } + /// Return the next fresh id pub(crate) fn next(&self) -> Id { Id::new(self.next.fetch_add(1, Ordering::Relaxed)) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs index 44c7847b161..8a940ba53c5 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs @@ -1,3 +1,7 @@ +//! This module defines the function inlining pass for the SSA IR. +//! The purpose of this pass is to inline the instructions of each function call +//! within the function caller. If all function calls are known, there will only +//! be a single function remaining when the pass finishes. use std::collections::{HashMap, HashSet}; use iter_extended::vecmap; @@ -31,10 +35,7 @@ impl Ssa { /// pass, we would need to re-run all of inlining anyway to inline it, so we might /// as well save the work for later instead of performing it twice. 
pub(crate) fn inline_functions(self) -> Ssa { - let main_function = self.main(); - let mut context = InlineContext::new(main_function); - context.inline_main(&self); - context.finish(self) + InlineContext::new(&self).inline_all(self) } } @@ -45,8 +46,12 @@ impl Ssa { /// reuse the existing API at the cost of essentially cloning each of main's instructions. struct InlineContext { recursion_level: u32, - functions_to_keep: HashSet, builder: FunctionBuilder, + + /// True if we failed to inline at least one call. If this is still false when finishing + /// inlining we can remove all other functions from the resulting Ssa struct and keep only + /// the function that was inlined into. + failed_to_inline_a_call: bool, } /// The per-function inlining context contains information that is only valid for one function. @@ -77,26 +82,44 @@ struct PerFunctionContext<'function> { /// The TerminatorInstruction::Return in the source_function will be mapped to a jmp to /// this block in the destination function instead. return_destination: BasicBlockId, + + /// True if we're currently working on the main function. + inlining_main: bool, } impl InlineContext { /// Create a new context object for the function inlining pass. /// This starts off with an empty mapping of instructions for main's parameters. - /// - /// main: The main function of the program to start inlining into. Only this function - /// and all functions still reachable from it will be returned when inlining is finished. - fn new(main: &Function) -> InlineContext { - Self { - recursion_level: 0, - builder: FunctionBuilder::new(main.name().to_owned(), main.id()), - functions_to_keep: HashSet::new(), - } + /// The function being inlined into will always be the main function, although it is + /// actually a copy that is created in case the original main is still needed from a function + /// that could not be inlined calling it. 
+ fn new(ssa: &Ssa) -> InlineContext { + let main_name = ssa.main().name().to_owned(); + let builder = FunctionBuilder::new(main_name, ssa.next_id.next()); + Self { builder, recursion_level: 0, failed_to_inline_a_call: false } } - fn inline_main(&mut self, ssa: &Ssa) { + /// Start inlining the main function and all functions reachable from it. + fn inline_all(mut self, ssa: Ssa) -> Ssa { let main = ssa.main(); - let mut context = PerFunctionContext::new(self, main); - context.inline_blocks(ssa); + let mut context = PerFunctionContext::new(&mut self, main); + context.inlining_main = true; + + // The main block is already inserted so we have to add it to context.blocks and add + // its parameters here. Failing to do so would cause context.translate_block() to add + // a fresh block for the entry block rather than use the existing one. + let entry_block = context.context.builder.current_function.entry_block(); + let original_parameters = context.source_function.parameters(); + + for parameter in original_parameters { + let typ = context.source_function.dfg.type_of_value(*parameter); + let new_parameter = context.context.builder.add_block_parameter(entry_block, typ); + context.values.insert(*parameter, new_parameter); + } + + context.blocks.insert(context.source_function.entry_block(), entry_block); + context.inline_blocks(&ssa); + self.finish(ssa) } /// Inlines a function into the current function and returns the translated return values @@ -113,30 +136,36 @@ impl InlineContext { let source_function = &ssa.functions[&id]; let mut context = PerFunctionContext::new(self, source_function); - let entry = source_function.entry_block(); - let parameters = source_function.dfg.block_parameters(entry); + let parameters = source_function.parameters(); assert_eq!(parameters.len(), arguments.len()); context.values = parameters.iter().copied().zip(arguments.iter().copied()).collect(); + let current_block = context.context.builder.current_block(); + 
context.blocks.insert(source_function.entry_block(), current_block); + context.inline_blocks(ssa); let return_destination = context.return_destination; self.builder.block_parameters(return_destination) } + /// Finish inlining and return the new Ssa struct with the inlined version of main. + /// If any functions failed to inline, they are not removed from the final Ssa struct. fn finish(self, mut ssa: Ssa) -> Ssa { let mut new_ssa = self.builder.finish(); - - for function_id in self.functions_to_keep { - let function = ssa - .functions - .remove(&function_id) - .unwrap_or_else(|| panic!("Expected to remove function with id {function_id}")); - - let existing = new_ssa.functions.insert(function.id(), function); - assert!(existing.is_none()); + assert_eq!(new_ssa.functions.len(), 1); + + // If we failed to inline any call, any function may still be reachable so we + // don't remove any from the final program. We could be more precise here and + // do a reachability analysis but it should be fine to keep the extra functions + // around longer if they are not called. + if self.failed_to_inline_a_call { + let new_main = new_ssa.functions.pop_first().unwrap().1; + ssa.main_id = new_main.id(); + ssa.functions.insert(new_main.id(), new_main); + ssa + } else { + new_ssa } - - new_ssa } } @@ -148,15 +177,14 @@ impl<'function> PerFunctionContext<'function> { fn new(context: &'function mut InlineContext, source_function: &'function Function) -> Self { // Create the block to return to but don't insert its parameters until we // have the types of the actual return values later. 
- let return_destination = context.builder.insert_block(); - Self { + return_destination: context.builder.insert_block(), context, source_function, blocks: HashMap::new(), instructions: HashMap::new(), values: HashMap::new(), - return_destination, + inlining_main: false, } } @@ -193,15 +221,22 @@ impl<'function> PerFunctionContext<'function> { /// /// If the block isn't already known, this will insert a new block into the target function /// with the same parameter types as the source block. - fn translate_block(&mut self, id: BasicBlockId) -> BasicBlockId { - if let Some(block) = self.blocks.get(&id) { + fn translate_block( + &mut self, + source_block: BasicBlockId, + block_queue: &mut Vec, + ) -> BasicBlockId { + if let Some(block) = self.blocks.get(&source_block) { return *block; } + // The block is not yet inlined, queue it + block_queue.push(source_block); + // The block is not already present in the function being inlined into so we must create it. // The block's instructions are not copied over as they will be copied later in inlining. 
let new_block = self.context.builder.insert_block(); - let original_parameters = self.source_function.dfg.block_parameters(id); + let original_parameters = self.source_function.dfg.block_parameters(source_block); for parameter in original_parameters { let typ = self.source_function.dfg.type_of_value(*parameter); @@ -209,6 +244,7 @@ impl<'function> PerFunctionContext<'function> { self.values.insert(*parameter, new_parameter); } + self.blocks.insert(source_block, new_block); new_block } @@ -220,7 +256,11 @@ impl<'function> PerFunctionContext<'function> { id = self.translate_value(id); match self.context.builder[id] { Value::Function(id) => Some(id), - _ => None, + Value::Intrinsic(_) => None, + _ => { + self.context.failed_to_inline_a_call = true; + None + } } } @@ -230,13 +270,15 @@ impl<'function> PerFunctionContext<'function> { let mut block_queue = vec![self.source_function.entry_block()]; while let Some(source_block_id) = block_queue.pop() { - let translated_block_id = self.translate_block(source_block_id); + let translated_block_id = self.translate_block(source_block_id, &mut block_queue); self.context.builder.switch_to_block(translated_block_id); seen_blocks.insert(source_block_id); self.inline_block(ssa, source_block_id); self.handle_terminator_instruction(source_block_id, &mut block_queue); } + + self.context.builder.switch_to_block(self.return_destination); } /// Inline each instruction in the given block into the function being inlined into. 
@@ -302,16 +344,9 @@ impl<'function> PerFunctionContext<'function> { block_id: BasicBlockId, block_queue: &mut Vec, ) { - let mut translate_and_queue_block = |this: &mut Self, block| { - if this.blocks.get(block).is_none() { - block_queue.push(*block); - } - this.translate_block(*block) - }; - match self.source_function.dfg[block_id].terminator() { Some(TerminatorInstruction::Jmp { destination, arguments }) => { - let destination = translate_and_queue_block(self, destination); + let destination = self.translate_block(*destination, block_queue); let arguments = vecmap(arguments, |arg| self.translate_value(*arg)); self.context.builder.terminate_with_jmp(destination, arguments); } @@ -321,19 +356,24 @@ impl<'function> PerFunctionContext<'function> { else_destination, }) => { let condition = self.translate_value(*condition); - let then_block = translate_and_queue_block(self, then_destination); - let else_block = translate_and_queue_block(self, else_destination); + let then_block = self.translate_block(*then_destination, block_queue); + let else_block = self.translate_block(*else_destination, block_queue); self.context.builder.terminate_with_jmpif(condition, then_block, else_block); } Some(TerminatorInstruction::Return { return_values }) => { - let return_values = vecmap(return_values, |value| { - // Add the block parameters for the return block here since we don't do - // it when inserting the block in PerFunctionContext::new - let typ = self.source_function.dfg.type_of_value(*value); - self.context.builder.add_block_parameter(self.return_destination, typ); - self.translate_value(*value) - }); - self.context.builder.terminate_with_jmp(self.return_destination, return_values); + let return_values = vecmap(return_values, |value| self.translate_value(*value)); + + if self.inlining_main { + self.context.builder.terminate_with_return(return_values); + } else { + for value in &return_values { + // Add the block parameters for the return block here since we don't do + // it 
when inserting the block in PerFunctionContext::new + let typ = self.context.builder.current_function.dfg.type_of_value(*value); + self.context.builder.add_block_parameter(self.return_destination, typ); + } + self.context.builder.terminate_with_jmp(self.return_destination, return_values); + } } None => unreachable!("Block has no terminator instruction"), } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs index ffb1f7210e4..f621503e59a 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs @@ -121,6 +121,11 @@ impl FunctionBuilder { self.current_block = block; } + /// Returns the block currently being inserted into + pub(crate) fn current_block(&mut self) -> BasicBlockId { + self.current_block + } + /// Insert an allocate instruction at the end of the current block, allocating the /// given amount of field elements. Returns the result of the allocate instruction, /// which is always a Reference to the allocated data. diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs index 3304513b35b..7f4b9a8dd25 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs @@ -2,27 +2,42 @@ use std::{collections::BTreeMap, fmt::Display}; use iter_extended::btree_map; -use crate::ssa_refactor::ir::function::{Function, FunctionId}; +use crate::ssa_refactor::ir::{ + function::{Function, FunctionId}, + map::AtomicCounter, +}; /// Contains the entire SSA representation of the program. pub(crate) struct Ssa { pub(crate) functions: BTreeMap, + pub(crate) main_id: FunctionId, + pub(crate) next_id: AtomicCounter, } impl Ssa { - /// Create a new Ssa object from the given SSA functions + /// Create a new Ssa object from the given SSA functions. 
+ /// The first function in this vector is expected to be the main function. pub(crate) fn new(functions: Vec) -> Self { - Self { functions: btree_map(functions, |f| (f.id(), f)) } + let main_id = functions.first().expect("Expected at least 1 SSA function").id(); + let mut max_id = main_id; + + let functions = btree_map(functions, |f| { + max_id = std::cmp::max(max_id, f.id()); + (f.id(), f) + }); + + Self { functions, main_id, next_id: AtomicCounter::starting_after(max_id) } } + /// Returns the entry-point function of the program pub(crate) fn main(&self) -> &Function { - self.functions.first_key_value().expect("Expected there to be at least 1 SSA function").1 + &self.functions[&self.main_id] } } impl Display for Ssa { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - for (_, function) in &self.functions { + for function in self.functions.values() { writeln!(f, "{function}")?; } Ok(()) From 70608fb4b12d74a617493b5b6ed34f1060edfc8f Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Thu, 4 May 2023 13:02:39 -0500 Subject: [PATCH 04/25] Add basic test --- .../src/ssa_refactor/ir/dom.rs | 4 +- .../src/ssa_refactor/opt/inlining.rs | 39 +++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs index 588b25cf91f..dba656838b8 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs @@ -278,7 +278,7 @@ mod tests { builder.switch_to_block(block3_id); builder.terminate_with_return(vec![]); - let mut ssa = builder.finish(); + let ssa = builder.finish(); let func = ssa.main(); let block0_id = func.entry_block(); @@ -382,7 +382,7 @@ mod tests { builder.switch_to_block(block2_id); builder.terminate_with_jmp(block1_id, vec![]); - let mut ssa = builder.finish(); + let ssa = builder.finish(); let func = ssa.main(); let block0_id = func.entry_block(); diff --git 
a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs index 8a940ba53c5..6e7c9848748 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs @@ -379,3 +379,42 @@ impl<'function> PerFunctionContext<'function> { } } } + +#[cfg(test)] +mod test { + use crate::ssa_refactor::{ + ir::{map::Id, types::Type}, + ssa_builder::FunctionBuilder, + }; + + #[test] + fn basic_inlining() { + // fn foo { + // b0(): + // v0 = call bar() + // return v0 + // } + // fn bar { + // b0(): + // return 72 + // } + let foo_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("foo".into(), foo_id); + + let bar_id = Id::test_new(1); + let bar = builder.import_function(bar_id); + let results = builder.insert_call(bar, Vec::new(), vec![Type::field()]).to_vec(); + builder.terminate_with_return(results); + + builder.new_function("bar".into(), bar_id); + let expected_return = 72u128; + let seventy_two = builder.field_constant(expected_return); + builder.terminate_with_return(vec![seventy_two]); + + let ssa = builder.finish(); + assert_eq!(ssa.functions.len(), 2); + + let inlined = ssa.inline_functions(); + assert_eq!(inlined.functions.len(), 1); + } +} From 4ec74d6cce3077cc2b12c68b00603ca80a5a2da6 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Fri, 5 May 2023 13:49:00 -0500 Subject: [PATCH 05/25] Address PR comments --- .../src/ssa_refactor/ir/basic_block.rs | 2 ++ .../src/ssa_refactor/ir/dfg.rs | 31 ------------------- 2 files changed, 2 insertions(+), 31 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs index c110b722d6f..30526bc296e 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs @@ -96,6 +96,8 @@ impl BasicBlock { /// Removes the given instruction from this block if 
present or panics otherwise. pub(crate) fn remove_instruction(&mut self, instruction: InstructionId) { + // Iterate in reverse here as an optimization since remove_instruction is most + // often called to remove instructions at the end of a block. let index = self.instructions.iter().rev().position(|id| *id == instruction).unwrap_or_else(|| { panic!("remove_instruction: No such instruction {instruction:?} in block") diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs index 5ad96e1bc1a..3ab345f06b9 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs @@ -263,37 +263,6 @@ impl DataFlowGraph { ) { self.blocks[block].set_terminator(terminator); } - - /// Splits the given block in two at the given instruction, returning the Id of the new block. - /// This will remove the given instruction and place every instruction after it into a new block - /// with the same terminator as the old block. The old block is modified to stop - /// before the instruction to remove and to unconditionally branch to the new block. - /// This function is useful during function inlining to remove the call instruction - /// while opening a spot at the end of the current block to insert instructions into. 
- /// - /// Example (before): - /// block1: a; b; c; d; e; jmp block5 - /// - /// After self.split_block_at(block1, c): - /// block1: a; b; jmp block2 - /// block2: d; e; jmp block5 - pub(crate) fn split_block_at( - &mut self, - block: BasicBlockId, - instruction_to_remove: InstructionId, - ) -> BasicBlockId { - let split_block = &mut self.blocks[block]; - - let mut instructions = split_block.instructions().iter(); - let index = instructions.position(|id| *id == instruction_to_remove).unwrap_or_else(|| { - panic!("No instruction found with id {instruction_to_remove:?} in block {block:?}") - }); - - let instructions = split_block.instructions_mut().drain(index..).collect(); - split_block.remove_instruction(instruction_to_remove); - - self.blocks.insert(BasicBlock::new(instructions)) - } } impl std::ops::Index for DataFlowGraph { From 59d945997cc85064af70ec0bee0e9071538fa899 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Mon, 8 May 2023 15:29:07 -0500 Subject: [PATCH 06/25] Start block inlining --- .../src/ssa_refactor/ir/dfg.rs | 20 ++++----- .../src/ssa_refactor/opt/mod.rs | 1 + .../src/ssa_refactor/opt/simplify_cfg.rs | 44 +++++++++++++++++++ 3 files changed, 54 insertions(+), 11 deletions(-) create mode 100644 crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs index 3ab345f06b9..ad8f27951bb 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs @@ -100,23 +100,21 @@ impl DataFlowGraph { id } - /// Replace an instruction id with another. - /// - /// This function should generally be avoided if possible in favor of inserting new - /// instructions since it does not check whether the instruction results of the removed - /// instruction are still in use. 
Users of this function thus need to ensure the old - /// instruction's results are no longer in use or are otherwise compatible with the - /// new instruction's result count and types. - pub(crate) fn replace_instruction(&mut self, id: Id, instruction: Instruction) { - self.instructions[id] = instruction; - } - /// Insert a value into the dfg's storage and return an id to reference it. /// Until the value is used in an instruction it is unreachable. pub(crate) fn make_value(&mut self, value: Value) -> ValueId { self.values.insert(value) } + /// Replaces the value specified by the given ValueId with a new Value. + /// + /// This is the preferred method to call for optimizations simplifying + /// values since other instructions referring to the same ValueId need + /// not be modified to refer to a new ValueId. + pub(crate) fn set_value(&mut self, value_id: ValueId, new_value: Value) { + self.values[value_id] = new_value; + } + /// Creates a new constant value, or returns the Id to an existing one if /// one already exists. pub(crate) fn make_constant(&mut self, value: FieldElement, typ: Type) -> ValueId { diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs index 46ca7d443bc..2701b0bb73c 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs @@ -4,3 +4,4 @@ //! simpler form until the IR only has a single function remaining with 1 block within it. //! Generally, these passes are also expected to minimize the final amount of instructions. 
mod inlining; +mod simplify_cfg; diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs new file mode 100644 index 00000000000..d821fe81fbc --- /dev/null +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs @@ -0,0 +1,44 @@ +use crate::ssa_refactor::{ssa_gen::Ssa, ir::{function::Function, cfg::ControlFlowGraph, basic_block::BasicBlockId}}; + + +impl Ssa { + pub(crate) fn simplify_cfg(mut self) -> Ssa { + for function in self.functions.values_mut() { + simplify_function_cfg(function); + } + self + } +} + +fn simplify_function_cfg(function: &mut Function) { + let current_block = function.entry_block(); + simplify_function_cfg_recursive(function, current_block); +} + +fn simplify_function_cfg_recursive(function: &mut Function, current_block: BasicBlockId) { + let successors: Vec<_> = function.dfg[current_block].successors().collect(); + + if successors.len() == 1 { + let source_block = successors[0]; + inline_instructions_from_block(function, current_block, source_block); + simplify_function_cfg_recursive(function, current_block); + } else { + for block in successors { + simplify_function_cfg_recursive(function, block); + } + } +} + +/// TODO: Translate block parameters +fn inline_instructions_from_block(function: &mut Function, dest_block: BasicBlockId, source_block: BasicBlockId) { + let instructions = function.dfg[source_block].instructions().to_vec(); + + // We cannot directly append each instruction since we need to substitute the + // block parameter values. 
+ for instruction in instructions { + function.dfg.insert_instruction_in_block(dest_block, instruction); + } + + let terminator = function.dfg[source_block].terminator().expect("Expected each block during the simplify_cfg optimization to have a terminator instruction").clone(); + function.dfg.set_block_terminator(dest_block, terminator); +} From 4f426736a5f6999a0cf2138f2a4e2ca3b6169785 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Tue, 9 May 2023 13:51:56 -0500 Subject: [PATCH 07/25] Add basic instruction simplification --- crates/noirc_evaluator/src/ssa_refactor.rs | 4 +- .../src/ssa_refactor/ir/dfg.rs | 64 ++++++- .../src/ssa_refactor/ir/instruction.rs | 180 +++++++++++++++++- .../src/ssa_refactor/opt/inlining.rs | 18 +- .../src/ssa_refactor/ssa_builder/mod.rs | 20 +- 5 files changed, 259 insertions(+), 27 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor.rs b/crates/noirc_evaluator/src/ssa_refactor.rs index 34061227336..8257a2b0171 100644 --- a/crates/noirc_evaluator/src/ssa_refactor.rs +++ b/crates/noirc_evaluator/src/ssa_refactor.rs @@ -25,7 +25,9 @@ pub mod ssa_gen; /// form and performing optimizations there. When finished, /// convert the final SSA into ACIR and return it. pub fn optimize_into_acir(program: Program) -> Acir { - ssa_gen::generate_ssa(program).inline_functions().into_acir() + ssa_gen::generate_ssa(program) + .inline_functions() + .into_acir() } /// Compiles the Program into ACIR and applies optimizations to the arithmetic gates /// This is analogous to `ssa:create_circuit` and this method is called when one wants diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs index 3ab345f06b9..fe79bce4a92 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs @@ -100,15 +100,21 @@ impl DataFlowGraph { id } - /// Replace an instruction id with another. 
- /// - /// This function should generally be avoided if possible in favor of inserting new - /// instructions since it does not check whether the instruction results of the removed - /// instruction are still in use. Users of this function thus need to ensure the old - /// instruction's results are no longer in use or are otherwise compatible with the - /// new instruction's result count and types. - pub(crate) fn replace_instruction(&mut self, id: Id, instruction: Instruction) { - self.instructions[id] = instruction; + /// Inserts a new instruction at the end of the given block and returns its results + pub(crate) fn insert_instruction( + &mut self, + instruction: Instruction, + block: BasicBlockId, + ctrl_typevars: Option>, + ) -> InsertInstructionResult { + match instruction.simplify(self) { + Some(simplification) => InsertInstructionResult::SimplifiedTo(simplification), + None => { + let id = self.make_instruction(instruction, ctrl_typevars); + self.insert_instruction_in_block(block, id); + InsertInstructionResult::Results(self.instruction_results(id)) + } + } } /// Insert a value into the dfg's storage and return an id to reference it. @@ -300,6 +306,46 @@ impl std::ops::IndexMut for DataFlowGraph { } } +// The result of calling DataFlowGraph::insert_instruction can +// be a list of results or a single ValueId if the instruction was simplified +// to an existing value. +pub(crate) enum InsertInstructionResult<'dfg> { + Results(&'dfg [ValueId]), + SimplifiedTo(ValueId), + InstructionRemoved, +} + +impl<'dfg> InsertInstructionResult<'dfg> { + /// Retrieve the first (and expected to be the only) result. + pub(crate) fn first(&self) -> ValueId { + match self { + InsertInstructionResult::SimplifiedTo(value) => *value, + InsertInstructionResult::Results(results) => results[0], + InsertInstructionResult::InstructionRemoved => panic!("Instruction was removed, no results"), + } + } + + /// Return all the results contained in the internal results array. 
+ /// This is used for instructions returning multiple results that were + /// not simplified - like function calls. + pub(crate) fn results(&self) -> &'dfg [ValueId] { + match self { + InsertInstructionResult::Results(results) => results, + InsertInstructionResult::SimplifiedTo(_) => panic!("InsertInstructionResult::results called on a simplified instruction"), + InsertInstructionResult::InstructionRemoved => panic!("InsertInstructionResult::results called on a removed instruction"), + } + } + + /// Returns the amount of ValueIds contained + pub(crate) fn len(&self) -> usize { + match self { + InsertInstructionResult::SimplifiedTo(_) => 1, + InsertInstructionResult::Results(results) => results.len(), + InsertInstructionResult::InstructionRemoved => 0, + } + } +} + #[cfg(test)] mod tests { use super::DataFlowGraph; diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs index 812d12b23a3..b5dac74f248 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs @@ -1,7 +1,7 @@ -use acvm::acir::BlackBoxFunc; +use acvm::{acir::BlackBoxFunc, FieldElement}; use iter_extended::vecmap; -use super::{basic_block::BasicBlockId, map::Id, types::Type, value::ValueId}; +use super::{basic_block::BasicBlockId, map::Id, types::Type, value::{ValueId, Value}, dfg::DataFlowGraph}; /// Reference to an instruction /// @@ -151,6 +151,46 @@ impl Instruction { } } } + + pub(crate) fn simplify(&self, dfg: &mut DataFlowGraph) -> Option { + match self { + Instruction::Binary(binary) => binary.simplify(dfg), + Instruction::Cast(value, typ) => (*typ == dfg.type_of_value(*value)).then_some(*value), + Instruction::Not(value) => { + match &dfg[*value] { + // Limit optimizing ! on constants to only booleans. If we tried it on fields, + // there is no Not on FieldElement, so we'd need to convert between u128. 
This + // would be incorrect however since the extra bits on the field would not be flipped. + Value::NumericConstant { constant, typ } if *typ == Type::bool() => { + let value = dfg[*constant].value().is_zero() as u128; + Some(dfg.make_constant(value.into(), Type::bool())) + }, + Value::Instruction { instruction, .. } => { + // !!v => v + match &dfg[*instruction] { + Instruction::Not(value) => Some(*value), + _ => None, + } + }, + _ => None, + } + }, + Instruction::Constrain(value) => { + if let Some(constant) = dfg.get_numeric_constant(*value) { + if constant.is_one() { + // "simplify" to a unit literal that will just be thrown away anyway + return Some(dfg.make_constant(0u128.into(), Type::Unit)); + } + } + None + }, + Instruction::Truncate { .. } => None, + Instruction::Call { .. } => None, + Instruction::Allocate { .. } => None, + Instruction::Load { .. } => None, + Instruction::Store { .. } => None, + } + } } /// The possible return values for Instruction::return_types @@ -219,6 +259,142 @@ impl Binary { _ => InstructionResultType::Operand(self.lhs), } } + + fn simplify(&self, dfg: &mut DataFlowGraph) -> Option { + let lhs = dfg.get_numeric_constant(self.lhs); + let rhs = dfg.get_numeric_constant(self.rhs); + let operand_type = dfg.type_of_value(self.lhs); + + if let (Some(lhs), Some(rhs)) = (lhs, rhs) { + return self.eval_constants(dfg, lhs, rhs, operand_type); + } + + let lhs_is_zero = lhs.map_or(false, |lhs| lhs.is_zero()); + let rhs_is_zero = rhs.map_or(false, |rhs| rhs.is_zero()); + + let lhs_is_one = lhs.map_or(false, |lhs| lhs.is_one()); + let rhs_is_one = rhs.map_or(false, |rhs| rhs.is_one()); + + match self.operator { + BinaryOp::Add => { + if lhs_is_zero { + return Some(self.rhs); + } + if rhs_is_zero { + return Some(self.lhs); + } + }, + BinaryOp::Sub => { + if rhs_is_zero { + return Some(self.lhs); + } + }, + BinaryOp::Mul => { + if lhs_is_one { + return Some(self.rhs); + } + if rhs_is_one { + return Some(self.lhs); + } + }, + BinaryOp::Div => { + 
if rhs_is_one { + return Some(self.lhs); + } + }, + BinaryOp::Mod => { + if rhs_is_one { + return Some(self.lhs); + } + }, + BinaryOp::Eq => { + if self.lhs == self.rhs { + return Some(dfg.make_constant(FieldElement::one(), Type::bool())); + } + }, + BinaryOp::Lt => { + if self.lhs == self.rhs { + return Some(dfg.make_constant(FieldElement::zero(), Type::bool())); + } + }, + BinaryOp::And => { + if lhs_is_zero || rhs_is_zero { + return Some(dfg.make_constant(FieldElement::zero(), operand_type)); + } + }, + BinaryOp::Or => { + if lhs_is_zero { + return Some(self.rhs); + } + if rhs_is_zero { + return Some(self.lhs); + } + }, + BinaryOp::Xor => (), + BinaryOp::Shl => { + if rhs_is_zero { + return Some(self.lhs); + } + }, + BinaryOp::Shr => { + if rhs_is_zero { + return Some(self.lhs); + } + }, + } + None + } + + fn eval_constants(&self, dfg: &mut DataFlowGraph, lhs: FieldElement, rhs: FieldElement, operand_type: Type) -> Option> { + let value = match self.operator { + BinaryOp::Add => lhs + rhs, + BinaryOp::Sub => lhs - rhs, + BinaryOp::Mul => lhs * rhs, + BinaryOp::Div => lhs / rhs, + BinaryOp::Eq => (lhs == rhs).into(), + BinaryOp::Lt => (lhs < rhs).into(), + + // The rest of the operators we must try to convert to u128 first + BinaryOp::Mod => self.eval_constant_u128_operations(lhs, rhs)?, + BinaryOp::And => self.eval_constant_u128_operations(lhs, rhs)?, + BinaryOp::Or => self.eval_constant_u128_operations(lhs, rhs)?, + BinaryOp::Xor => self.eval_constant_u128_operations(lhs, rhs)?, + BinaryOp::Shl => self.eval_constant_u128_operations(lhs, rhs)?, + BinaryOp::Shr => self.eval_constant_u128_operations(lhs, rhs)?, + }; + // TODO: Keep original type of constant + Some(dfg.make_constant(value, operand_type)) + } + + /// Try to evaluate the given operands as u128s for operators that are only valid on u128s, + /// like the bitwise operators and modulus. 
+ fn eval_constant_u128_operations(&self, lhs: FieldElement, rhs: FieldElement) -> Option { + let lhs = lhs.try_into_u128()?; + let rhs = rhs.try_into_u128()?; + match self.operator { + BinaryOp::Mod => Some((lhs % rhs).into()), + BinaryOp::And => Some((lhs & rhs).into()), + BinaryOp::Or => Some((lhs | rhs).into()), + BinaryOp::Shr => Some((lhs >> rhs).into()), + // Check for overflow and return None if anything does overflow + BinaryOp::Shl => { + let rhs = rhs.try_into().ok()?; + lhs.checked_shl(rhs).map(Into::into) + } + + // Converting a field xor to a u128 xor would be incorrect since we wouldn't have the + // extra bits of the field. So we don't optimize it here. + BinaryOp::Xor => None, + + op @ (BinaryOp::Add + | BinaryOp::Sub + | BinaryOp::Mul + | BinaryOp::Div + | BinaryOp::Eq + | BinaryOp::Lt) => panic!("eval_constant_u128_operations invalid for {op:?} use eval_constants instead"), + + } + } } /// Binary Operations allowed in the IR. diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs index 6e7c9848748..69e25fda87f 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs @@ -11,7 +11,7 @@ use crate::ssa_refactor::{ basic_block::BasicBlockId, function::{Function, FunctionId}, instruction::{Instruction, InstructionId, TerminatorInstruction}, - value::{Value, ValueId}, + value::{Value, ValueId}, dfg::InsertInstructionResult, }, ssa_builder::FunctionBuilder, ssa_gen::Ssa, @@ -306,6 +306,7 @@ impl<'function> PerFunctionContext<'function> { ) { let old_results = self.source_function.dfg.instruction_results(call_id); let new_results = self.context.inline_function(ssa, function, arguments); + let new_results = InsertInstructionResult::Results(new_results); Self::insert_new_instruction_results(&mut self.values, old_results, new_results); } @@ -328,11 +329,20 @@ impl<'function> PerFunctionContext<'function> { fn 
insert_new_instruction_results( values: &mut HashMap, old_results: &[ValueId], - new_results: &[ValueId], + new_results: InsertInstructionResult, ) { assert_eq!(old_results.len(), new_results.len()); - for (old_result, new_result) in old_results.iter().zip(new_results) { - values.insert(*old_result, *new_result); + + match new_results { + InsertInstructionResult::SimplifiedTo(new_result) => { + values.insert(old_results[0], new_result); + }, + InsertInstructionResult::Results(new_results) => { + for (old_result, new_result) in old_results.iter().zip(new_results) { + values.insert(*old_result, *new_result); + } + }, + InsertInstructionResult::InstructionRemoved => (), } } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs index f621503e59a..4029757d7de 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs @@ -11,7 +11,7 @@ use crate::ssa_refactor::ir::{ use super::{ ir::{ basic_block::BasicBlock, - instruction::{InstructionId, Intrinsic}, + instruction::{InstructionId, Intrinsic}, dfg::InsertInstructionResult, }, ssa_gen::Ssa, }; @@ -108,10 +108,8 @@ impl FunctionBuilder { &mut self, instruction: Instruction, ctrl_typevars: Option>, - ) -> &[ValueId] { - let id = self.current_function.dfg.make_instruction(instruction, ctrl_typevars); - self.current_function.dfg.insert_instruction_in_block(self.current_block, id); - self.current_function.dfg.instruction_results(id) + ) -> InsertInstructionResult { + self.current_function.dfg.insert_instruction(instruction, self.current_block, ctrl_typevars) } /// Switch to inserting instructions in the given block. @@ -130,7 +128,7 @@ impl FunctionBuilder { /// given amount of field elements. Returns the result of the allocate instruction, /// which is always a Reference to the allocated data. 
pub(crate) fn insert_allocate(&mut self, size_to_allocate: u32) -> ValueId { - self.insert_instruction(Instruction::Allocate { size: size_to_allocate }, None)[0] + self.insert_instruction(Instruction::Allocate { size: size_to_allocate }, None).first() } /// Insert a Load instruction at the end of the current block, loading from the given offset @@ -147,7 +145,7 @@ impl FunctionBuilder { type_to_load: Type, ) -> ValueId { address = self.insert_binary(address, BinaryOp::Add, offset); - self.insert_instruction(Instruction::Load { address }, Some(vec![type_to_load]))[0] + self.insert_instruction(Instruction::Load { address }, Some(vec![type_to_load])).first() } /// Insert a Store instruction at the end of the current block, storing the given element @@ -166,19 +164,19 @@ impl FunctionBuilder { rhs: ValueId, ) -> ValueId { let instruction = Instruction::Binary(Binary { lhs, rhs, operator }); - self.insert_instruction(instruction, None)[0] + self.insert_instruction(instruction, None).first() } /// Insert a not instruction at the end of the current block. /// Returns the result of the instruction. pub(crate) fn insert_not(&mut self, rhs: ValueId) -> ValueId { - self.insert_instruction(Instruction::Not(rhs), None)[0] + self.insert_instruction(Instruction::Not(rhs), None).first() } /// Insert a cast instruction at the end of the current block. /// Returns the result of the cast instruction. pub(crate) fn insert_cast(&mut self, value: ValueId, typ: Type) -> ValueId { - self.insert_instruction(Instruction::Cast(value, typ), None)[0] + self.insert_instruction(Instruction::Cast(value, typ), None).first() } /// Insert a constrain instruction at the end of the current block. 
@@ -194,7 +192,7 @@ impl FunctionBuilder { arguments: Vec, result_types: Vec, ) -> &[ValueId] { - self.insert_instruction(Instruction::Call { func, arguments }, Some(result_types)) + self.insert_instruction(Instruction::Call { func, arguments }, Some(result_types)).results() } /// Terminates the current block with the given terminator instruction From 19f8c1b83f8cc2a72dde029671980e0df1eb5c64 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Tue, 9 May 2023 13:55:31 -0500 Subject: [PATCH 08/25] Cargo fmt --- crates/noirc_evaluator/src/ssa_refactor.rs | 4 +- .../src/ssa_refactor/ir/dfg.rs | 12 +++- .../src/ssa_refactor/ir/instruction.rs | 57 ++++++++++++------- .../src/ssa_refactor/opt/inlining.rs | 7 ++- .../src/ssa_refactor/ssa_builder/mod.rs | 3 +- 5 files changed, 53 insertions(+), 30 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor.rs b/crates/noirc_evaluator/src/ssa_refactor.rs index 8257a2b0171..34061227336 100644 --- a/crates/noirc_evaluator/src/ssa_refactor.rs +++ b/crates/noirc_evaluator/src/ssa_refactor.rs @@ -25,9 +25,7 @@ pub mod ssa_gen; /// form and performing optimizations there. When finished, /// convert the final SSA into ACIR and return it. 
pub fn optimize_into_acir(program: Program) -> Acir { - ssa_gen::generate_ssa(program) - .inline_functions() - .into_acir() + ssa_gen::generate_ssa(program).inline_functions().into_acir() } /// Compiles the Program into ACIR and applies optimizations to the arithmetic gates /// This is analogous to `ssa:create_circuit` and this method is called when one wants diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs index fe79bce4a92..fc15f3e2168 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs @@ -321,7 +321,9 @@ impl<'dfg> InsertInstructionResult<'dfg> { match self { InsertInstructionResult::SimplifiedTo(value) => *value, InsertInstructionResult::Results(results) => results[0], - InsertInstructionResult::InstructionRemoved => panic!("Instruction was removed, no results"), + InsertInstructionResult::InstructionRemoved => { + panic!("Instruction was removed, no results") + } } } @@ -331,8 +333,12 @@ impl<'dfg> InsertInstructionResult<'dfg> { pub(crate) fn results(&self) -> &'dfg [ValueId] { match self { InsertInstructionResult::Results(results) => results, - InsertInstructionResult::SimplifiedTo(_) => panic!("InsertInstructionResult::results called on a simplified instruction"), - InsertInstructionResult::InstructionRemoved => panic!("InsertInstructionResult::results called on a removed instruction"), + InsertInstructionResult::SimplifiedTo(_) => { + panic!("InsertInstructionResult::results called on a simplified instruction") + } + InsertInstructionResult::InstructionRemoved => { + panic!("InsertInstructionResult::results called on a removed instruction") + } } } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs index b5dac74f248..c2ee104d058 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs +++ 
b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs @@ -1,7 +1,13 @@ use acvm::{acir::BlackBoxFunc, FieldElement}; use iter_extended::vecmap; -use super::{basic_block::BasicBlockId, map::Id, types::Type, value::{ValueId, Value}, dfg::DataFlowGraph}; +use super::{ + basic_block::BasicBlockId, + dfg::DataFlowGraph, + map::Id, + types::Type, + value::{Value, ValueId}, +}; /// Reference to an instruction /// @@ -164,17 +170,17 @@ impl Instruction { Value::NumericConstant { constant, typ } if *typ == Type::bool() => { let value = dfg[*constant].value().is_zero() as u128; Some(dfg.make_constant(value.into(), Type::bool())) - }, + } Value::Instruction { instruction, .. } => { // !!v => v match &dfg[*instruction] { Instruction::Not(value) => Some(*value), _ => None, } - }, + } _ => None, } - }, + } Instruction::Constrain(value) => { if let Some(constant) = dfg.get_numeric_constant(*value) { if constant.is_one() { @@ -183,7 +189,7 @@ impl Instruction { } } None - }, + } Instruction::Truncate { .. } => None, Instruction::Call { .. } => None, Instruction::Allocate { .. 
} => None, @@ -283,12 +289,12 @@ impl Binary { if rhs_is_zero { return Some(self.lhs); } - }, + } BinaryOp::Sub => { if rhs_is_zero { return Some(self.lhs); } - }, + } BinaryOp::Mul => { if lhs_is_one { return Some(self.rhs); @@ -296,32 +302,32 @@ impl Binary { if rhs_is_one { return Some(self.lhs); } - }, + } BinaryOp::Div => { if rhs_is_one { return Some(self.lhs); } - }, + } BinaryOp::Mod => { if rhs_is_one { return Some(self.lhs); } - }, + } BinaryOp::Eq => { if self.lhs == self.rhs { return Some(dfg.make_constant(FieldElement::one(), Type::bool())); } - }, + } BinaryOp::Lt => { if self.lhs == self.rhs { return Some(dfg.make_constant(FieldElement::zero(), Type::bool())); } - }, + } BinaryOp::And => { if lhs_is_zero || rhs_is_zero { return Some(dfg.make_constant(FieldElement::zero(), operand_type)); } - }, + } BinaryOp::Or => { if lhs_is_zero { return Some(self.rhs); @@ -329,23 +335,29 @@ impl Binary { if rhs_is_zero { return Some(self.lhs); } - }, + } BinaryOp::Xor => (), BinaryOp::Shl => { if rhs_is_zero { return Some(self.lhs); } - }, + } BinaryOp::Shr => { if rhs_is_zero { return Some(self.lhs); } - }, + } } None } - fn eval_constants(&self, dfg: &mut DataFlowGraph, lhs: FieldElement, rhs: FieldElement, operand_type: Type) -> Option> { + fn eval_constants( + &self, + dfg: &mut DataFlowGraph, + lhs: FieldElement, + rhs: FieldElement, + operand_type: Type, + ) -> Option> { let value = match self.operator { BinaryOp::Add => lhs + rhs, BinaryOp::Sub => lhs - rhs, @@ -368,7 +380,11 @@ impl Binary { /// Try to evaluate the given operands as u128s for operators that are only valid on u128s, /// like the bitwise operators and modulus. 
- fn eval_constant_u128_operations(&self, lhs: FieldElement, rhs: FieldElement) -> Option { + fn eval_constant_u128_operations( + &self, + lhs: FieldElement, + rhs: FieldElement, + ) -> Option { let lhs = lhs.try_into_u128()?; let rhs = rhs.try_into_u128()?; match self.operator { @@ -391,8 +407,9 @@ impl Binary { | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Eq - | BinaryOp::Lt) => panic!("eval_constant_u128_operations invalid for {op:?} use eval_constants instead"), - + | BinaryOp::Lt) => panic!( + "eval_constant_u128_operations invalid for {op:?} use eval_constants instead" + ), } } } diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs index 69e25fda87f..9ed616026e5 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs @@ -9,9 +9,10 @@ use iter_extended::vecmap; use crate::ssa_refactor::{ ir::{ basic_block::BasicBlockId, + dfg::InsertInstructionResult, function::{Function, FunctionId}, instruction::{Instruction, InstructionId, TerminatorInstruction}, - value::{Value, ValueId}, dfg::InsertInstructionResult, + value::{Value, ValueId}, }, ssa_builder::FunctionBuilder, ssa_gen::Ssa, @@ -336,12 +337,12 @@ impl<'function> PerFunctionContext<'function> { match new_results { InsertInstructionResult::SimplifiedTo(new_result) => { values.insert(old_results[0], new_result); - }, + } InsertInstructionResult::Results(new_results) => { for (old_result, new_result) in old_results.iter().zip(new_results) { values.insert(*old_result, *new_result); } - }, + } InsertInstructionResult::InstructionRemoved => (), } } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs index 4029757d7de..60379097523 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs @@ -11,7 +11,8 @@ use 
crate::ssa_refactor::ir::{ use super::{ ir::{ basic_block::BasicBlock, - instruction::{InstructionId, Intrinsic}, dfg::InsertInstructionResult, + dfg::InsertInstructionResult, + instruction::{InstructionId, Intrinsic}, }, ssa_gen::Ssa, }; From 899e4cf3812235b5ba6f1d5b51a644e44ae66bc0 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Tue, 9 May 2023 13:59:45 -0500 Subject: [PATCH 09/25] Add comments --- crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs index c2ee104d058..42968568dee 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs @@ -158,6 +158,8 @@ impl Instruction { } } + /// Try to simplify this instruction. If the instruction can be simplified to a known value, + /// that value is returned. Otherwise None is returned. pub(crate) fn simplify(&self, dfg: &mut DataFlowGraph) -> Option { match self { Instruction::Binary(binary) => binary.simplify(dfg), @@ -266,6 +268,7 @@ impl Binary { } } + /// Try to simplify this binary instruction, returning the new value if possible. fn simplify(&self, dfg: &mut DataFlowGraph) -> Option { let lhs = dfg.get_numeric_constant(self.lhs); let rhs = dfg.get_numeric_constant(self.rhs); @@ -351,6 +354,8 @@ impl Binary { None } + /// Evaluate the two constants with the operation specified by self.operator. + /// Pushes the resulting value to the given DataFlowGraph's constants and returns it. 
fn eval_constants( &self, dfg: &mut DataFlowGraph, From 1ffdb0faf60f21c529d50f6e8968d1165f5e12fb Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Tue, 9 May 2023 15:35:18 -0500 Subject: [PATCH 10/25] Add context object --- .../src/ssa_refactor/opt/simplify_cfg.rs | 84 ++++++++++++++----- 1 file changed, 63 insertions(+), 21 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs index d821fe81fbc..ebc230b49da 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs @@ -1,4 +1,8 @@ -use crate::ssa_refactor::{ssa_gen::Ssa, ir::{function::Function, cfg::ControlFlowGraph, basic_block::BasicBlockId}}; +use std::collections::{HashSet, HashMap}; + +use acvm::FieldElement; + +use crate::ssa_refactor::{ssa_gen::Ssa, ir::{function::Function, cfg::ControlFlowGraph, basic_block::BasicBlockId, value::ValueId, instruction::TerminatorInstruction}}; impl Ssa { @@ -12,33 +16,71 @@ impl Ssa { fn simplify_function_cfg(function: &mut Function) { let current_block = function.entry_block(); - simplify_function_cfg_recursive(function, current_block); + let mut context = Context::new(function); + context.simplify_function_cfg(current_block); } -fn simplify_function_cfg_recursive(function: &mut Function, current_block: BasicBlockId) { - let successors: Vec<_> = function.dfg[current_block].successors().collect(); +struct Context<'f> { + visited_blocks: HashSet, + values: HashMap, + function: &'f mut Function, +} - if successors.len() == 1 { - let source_block = successors[0]; - inline_instructions_from_block(function, current_block, source_block); - simplify_function_cfg_recursive(function, current_block); - } else { - for block in successors { - simplify_function_cfg_recursive(function, block); +impl<'f> Context<'f> { + fn new(function: &'f mut Function) -> Self { + Self { + visited_blocks: HashSet::new(), + values: 
HashMap::new(), + function, + } + } + + fn simplify_function_cfg(&mut self, current_block: BasicBlockId) { + let block = &self.function.dfg[current_block]; + let successors: Vec<_> = block.successors().collect(); + self.visited_blocks.insert(current_block); + + if successors.len() == 1 { + let source_block = successors[0]; + self.inline_instructions_from_block(current_block, source_block); + self.simplify_function_cfg(current_block); + } else if successors.len() > 1 { + if let Some(TerminatorInstruction::JmpIf { condition, then_destination, else_destination }) = block.terminator() { + match self.get_constant(*condition) { + Some(constant) => { + let next_block = if constant.is_zero() { *else_destination } else { *then_destination }; + self.inline_instructions_from_block(current_block, next_block); + self.simplify_function_cfg(current_block); + }, + None => todo!(), + } + + } else { + unreachable!("Only JmpIf terminators should have more than 1 successor") + } } } -} -/// TODO: Translate block parameters -fn inline_instructions_from_block(function: &mut Function, dest_block: BasicBlockId, source_block: BasicBlockId) { - let instructions = function.dfg[source_block].instructions().to_vec(); + fn get_value(&self, value: ValueId) -> ValueId { + self.values.get(&value).copied().unwrap_or(value) + } - // We cannot directly append each instruction since we need to substitute the - // block parameter values. 
- for instruction in instructions { - function.dfg.insert_instruction_in_block(dest_block, instruction); + fn get_constant(&self, value: ValueId) -> Option { + let value = self.get_value(value); + self.function.dfg.get_numeric_constant(value) } - let terminator = function.dfg[source_block].terminator().expect("Expected each block during the simplify_cfg optimization to have a terminator instruction").clone(); - function.dfg.set_block_terminator(dest_block, terminator); + /// TODO: Translate block parameters + fn inline_instructions_from_block(&mut self, dest_block: BasicBlockId, source_block: BasicBlockId) { + let instructions = self.function.dfg[source_block].instructions().to_vec(); + + // We cannot directly append each instruction since we need to substitute the + // block parameter values. + for instruction in instructions { + self.function.dfg.insert_instruction_in_block(dest_block, instruction); + } + + let terminator = self.function.dfg[source_block].terminator().expect("Expected each block during the simplify_cfg optimization to have a terminator instruction").clone(); + self.function.dfg.set_block_terminator(dest_block, terminator); + } } From 99284d6568a494b3cacef7238a854c03fc37f440 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Wed, 10 May 2023 10:42:54 -0500 Subject: [PATCH 11/25] Add push_instruction --- crates/noirc_evaluator/src/ssa_refactor.rs | 22 ++++++++-- .../src/ssa_refactor/opt/simplify_cfg.rs | 41 ++++++++++++++++++- 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor.rs b/crates/noirc_evaluator/src/ssa_refactor.rs index 34061227336..c0394fde4af 100644 --- a/crates/noirc_evaluator/src/ssa_refactor.rs +++ b/crates/noirc_evaluator/src/ssa_refactor.rs @@ -13,7 +13,7 @@ use noirc_abi::Abi; use noirc_frontend::monomorphization::ast::Program; -use self::acir_gen::Acir; +use self::ssa_gen::Ssa; mod acir_gen; mod ir; @@ -24,9 +24,15 @@ pub mod ssa_gen; /// Optimize the given program by converting 
it into SSA /// form and performing optimizations there. When finished, /// convert the final SSA into ACIR and return it. -pub fn optimize_into_acir(program: Program) -> Acir { - ssa_gen::generate_ssa(program).inline_functions().into_acir() +pub fn optimize_into_acir(program: Program) { + ssa_gen::generate_ssa(program) + .print("Initial SSA:") + .inline_functions() + .print("After Inlining:") + .simplify_cfg() + .print("After Simplifying the CFG:"); } + /// Compiles the Program into ACIR and applies optimizations to the arithmetic gates /// This is analogous to `ssa:create_circuit` and this method is called when one wants /// to use the new ssa module to process Noir code. @@ -37,5 +43,13 @@ pub fn experimental_create_circuit( _enable_logging: bool, _show_output: bool, ) -> Result<(Circuit, Abi), RuntimeError> { - todo!("this is a stub function for the new SSA refactor module") + optimize_into_acir(_program); + std::process::exit(0); +} + +impl Ssa { + fn print(self, msg: &str) -> Ssa { + println!("{msg}\n{self}"); + self + } } diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs index ebc230b49da..7e467c2da31 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs @@ -1,8 +1,9 @@ use std::collections::{HashSet, HashMap}; use acvm::FieldElement; +use iter_extended::vecmap; -use crate::ssa_refactor::{ssa_gen::Ssa, ir::{function::Function, cfg::ControlFlowGraph, basic_block::BasicBlockId, value::ValueId, instruction::TerminatorInstruction}}; +use crate::ssa_refactor::{ssa_gen::Ssa, ir::{function::Function, cfg::ControlFlowGraph, basic_block::BasicBlockId, value::ValueId, instruction::{TerminatorInstruction, InstructionId}}}; impl Ssa { @@ -48,11 +49,22 @@ impl<'f> Context<'f> { if let Some(TerminatorInstruction::JmpIf { condition, then_destination, else_destination }) = block.terminator() { match 
self.get_constant(*condition) { Some(constant) => { + println!("jmpif destination known: {}", !constant.is_zero()); let next_block = if constant.is_zero() { *else_destination } else { *then_destination }; self.inline_instructions_from_block(current_block, next_block); self.simplify_function_cfg(current_block); }, - None => todo!(), + None => { + // We only allow dynamic branching if we're not going in a loop + assert!(!self.visited_blocks.contains(then_destination), "Dynamic loops are unsupported - block {then_destination} was already visited"); + assert!(!self.visited_blocks.contains(else_destination), "Dynamic loops are unsupported - block {else_destination} was already visited"); + let else_destination = *else_destination; + + self.inline_instructions_from_block(current_block, *then_destination); + self.inline_instructions_from_block(current_block, else_destination); + self.simplify_function_cfg(current_block); + self.simplify_function_cfg(current_block); + }, } } else { @@ -83,4 +95,29 @@ impl<'f> Context<'f> { let terminator = self.function.dfg[source_block].terminator().expect("Expected each block during the simplify_cfg optimization to have a terminator instruction").clone(); self.function.dfg.set_block_terminator(dest_block, terminator); } + + fn push_instruction(&mut self, id: InstructionId) { + let instruction = self.function.dfg[id].map_values(|id| self.get_value(id)); + let results = self.function.dfg.instruction_results(id); + + let ctrl_typevars = instruction + .requires_ctrl_typevars() + .then(|| vecmap(results, |result| self.function.dfg.type_of_value(*result))); + + // let new_results = self.function.dfg.insert_instruction_in_block(instruction, ctrl_typevars); + // Self::insert_new_instruction_results(&mut self.values, results, new_results); + } + + /// Modify the values HashMap to remember the mapping between an instruction result's previous + /// ValueId (from the source_function) and its new ValueId in the destination function. 
+ fn insert_new_instruction_results( + values: &mut HashMap, + old_results: &[ValueId], + new_results: &[ValueId], + ) { + assert_eq!(old_results.len(), new_results.len()); + for (old_result, new_result) in old_results.iter().zip(new_results) { + values.insert(*old_result, *new_result); + } + } } From 86ead4d30f017b81f6d5845c0108aeb67521a477 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Thu, 11 May 2023 09:21:54 -0500 Subject: [PATCH 12/25] Fix bug in inlining pass --- .../src/ssa_refactor/opt/inlining.rs | 116 +++++++++++++----- 1 file changed, 88 insertions(+), 28 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs index 288446bc702..28db5b6d550 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs @@ -80,10 +80,6 @@ struct PerFunctionContext<'function> { /// Maps InstructionIds from the function being inlined to the function being inlined into. instructions: HashMap, - /// The TerminatorInstruction::Return in the source_function will be mapped to a jmp to - /// this block in the destination function instead. - return_destination: BasicBlockId, - /// True if we're currently working on the main function. inlining_main: bool, } @@ -125,7 +121,7 @@ impl InlineContext { /// Inlines a function into the current function and returns the translated return values /// of the inlined function. 
- fn inline_function(&mut self, ssa: &Ssa, id: FunctionId, arguments: &[ValueId]) -> &[ValueId] { + fn inline_function(&mut self, ssa: &Ssa, id: FunctionId, arguments: &[ValueId]) -> Vec { self.recursion_level += 1; if self.recursion_level > RECURSION_LIMIT { @@ -144,10 +140,7 @@ impl InlineContext { let current_block = context.context.builder.current_block(); context.blocks.insert(source_function.entry_block(), current_block); - context.inline_blocks(ssa); - - let return_destination = context.return_destination; - self.builder.block_parameters(return_destination) + context.inline_blocks(ssa) } /// Finish inlining and return the new Ssa struct with the inlined version of main. @@ -177,10 +170,7 @@ impl<'function> PerFunctionContext<'function> { /// for containing the mapping between parameters in the source_function and /// the arguments of the destination function. fn new(context: &'function mut InlineContext, source_function: &'function Function) -> Self { - // Create the block to return to but don't insert its parameters until we - // have the types of the actual return values later. Self { - return_destination: context.builder.insert_block(), context, source_function, blocks: HashMap::new(), @@ -267,20 +257,27 @@ impl<'function> PerFunctionContext<'function> { } /// Inline all reachable blocks within the source_function into the destination function. 
- fn inline_blocks(&mut self, ssa: &Ssa) { + fn inline_blocks(&mut self, ssa: &Ssa) -> Vec { let mut seen_blocks = HashSet::new(); let mut block_queue = vec![self.source_function.entry_block()]; + let mut function_return = None; + while let Some(source_block_id) = block_queue.pop() { let translated_block_id = self.translate_block(source_block_id, &mut block_queue); self.context.builder.switch_to_block(translated_block_id); seen_blocks.insert(source_block_id); self.inline_block(ssa, source_block_id); - self.handle_terminator_instruction(source_block_id, &mut block_queue); + function_return = self.handle_terminator_instruction(source_block_id, &mut block_queue); } - self.context.builder.switch_to_block(self.return_destination); + if let Some((block, values)) = function_return { + self.context.builder.switch_to_block(block); + values + } else { + unreachable!("Inlined function had no return instruction") + } } /// Inline each instruction in the given block into the function being inlined into. @@ -309,7 +306,7 @@ impl<'function> PerFunctionContext<'function> { let old_results = self.source_function.dfg.instruction_results(call_id); let arguments = vecmap(arguments, |arg| self.translate_value(*arg)); let new_results = self.context.inline_function(ssa, function, &arguments); - let new_results = InsertInstructionResult::Results(new_results); + let new_results = InsertInstructionResult::Results(&new_results); Self::insert_new_instruction_results(&mut self.values, old_results, new_results); } @@ -352,16 +349,20 @@ impl<'function> PerFunctionContext<'function> { /// Handle the given terminator instruction from the given source function block. /// This will push any new blocks to the destination function as needed, add them /// to the block queue, and set the terminator instruction for the current block. + /// + /// If the terminator instruction was a Return, this will return the block this instruction + /// was in as well as the values that were returned. 
fn handle_terminator_instruction( &mut self, block_id: BasicBlockId, block_queue: &mut Vec, - ) { + ) -> Option<(BasicBlockId, Vec)> { match self.source_function.dfg[block_id].terminator() { Some(TerminatorInstruction::Jmp { destination, arguments }) => { let destination = self.translate_block(*destination, block_queue); let arguments2 = vecmap(arguments, |arg| self.translate_value(*arg)); self.context.builder.terminate_with_jmp(destination, arguments2); + None } Some(TerminatorInstruction::JmpIf { condition, @@ -372,21 +373,14 @@ impl<'function> PerFunctionContext<'function> { let then_block = self.translate_block(*then_destination, block_queue); let else_block = self.translate_block(*else_destination, block_queue); self.context.builder.terminate_with_jmpif(condition, then_block, else_block); + None } Some(TerminatorInstruction::Return { return_values }) => { let return_values = vecmap(return_values, |value| self.translate_value(*value)); - if self.inlining_main { - self.context.builder.terminate_with_return(return_values); - } else { - for value in &return_values { - // Add the block parameters for the return block here since we don't do - // it when inserting the block in PerFunctionContext::new - let typ = self.context.builder.current_function.dfg.type_of_value(*value); - self.context.builder.add_block_parameter(self.return_destination, typ); - } - self.context.builder.terminate_with_jmp(self.return_destination, return_values); + self.context.builder.terminate_with_return(return_values.clone()); } + Some((block_id, return_values)) } None => unreachable!("Block has no terminator instruction"), } @@ -396,7 +390,7 @@ impl<'function> PerFunctionContext<'function> { #[cfg(test)] mod test { use crate::ssa_refactor::{ - ir::{map::Id, types::Type}, + ir::{map::Id, types::Type, instruction::BinaryOp}, ssa_builder::FunctionBuilder, }; @@ -430,4 +424,70 @@ mod test { let inlined = ssa.inline_functions(); assert_eq!(inlined.functions.len(), 1); } + + #[test] + fn 
complex_inlining() { + // This SSA is from issue #1327 which previously failed to inline properly + // + // fn main f0 { + // b0(v0: Field): + // v7 = call f2(f1) + // v13 = call f3(v7) + // v16 = call v13(v0) + // return v16 + // } + // fn square f1 { + // b0(v0: Field): + // v2 = mul v0, v0 + // return v2 + // } + // fn id1 f2 { + // b0(v0: function): + // return v0 + // } + // fn id2 f3 { + // b0(v0: function): + // return v0 + // } + let main_id = Id::test_new(0); + let square_id = Id::test_new(1); + let id1_id = Id::test_new(2); + let id2_id = Id::test_new(3); + + // Compiling main + let mut builder = FunctionBuilder::new("main".into(), main_id); + let main_v0 = builder.add_parameter(Type::field()); + + let main_f1 = builder.import_function(square_id); + let main_f2 = builder.import_function(id1_id); + let main_f3 = builder.import_function(id2_id); + + let main_v7 = builder.insert_call(main_f2, vec![main_f1], vec![Type::Function])[0]; + let main_v13 = builder.insert_call(main_f3, vec![main_v7], vec![Type::Function])[0]; + let main_v16 = builder.insert_call(main_v13, vec![main_v0], vec![Type::field()])[0]; + builder.terminate_with_return(vec![main_v16]); + + // Compiling square f1 + builder.new_function("square".into(), square_id); + let square_v0 = builder.add_parameter(Type::field()); + let square_v2 = builder.insert_binary(square_v0, BinaryOp::Mul, square_v0); + builder.terminate_with_return(vec![square_v2]); + + // Compiling id1 f2 + builder.new_function("id1".into(), id1_id); + let id1_v0 = builder.add_parameter(Type::Function); + builder.terminate_with_return(vec![id1_v0]); + + // Compiling id2 f3 + builder.new_function("id2".into(), id2_id); + let id2_v0 = builder.add_parameter(Type::Function); + builder.terminate_with_return(vec![id2_v0]); + + // Done, now we test that we can successfully inline all functions. 
+ let ssa = builder.finish(); + assert_eq!(ssa.functions.len(), 2); + + let inlined = ssa.inline_functions(); + assert_eq!(inlined.functions.len(), 1); + } } From e4720c5661bc9533cb3aa59a3089eb9b65b3acf9 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Thu, 11 May 2023 09:22:11 -0500 Subject: [PATCH 13/25] Reorder loop unrolling pass --- crates/noirc_evaluator/src/ssa_refactor.rs | 2 +- .../src/ssa_refactor/ir/instruction.rs | 22 +- .../src/ssa_refactor/ir/printer.rs | 2 +- .../src/ssa_refactor/opt/mod.rs | 1 + .../src/ssa_refactor/opt/simplify_cfg.rs | 151 +---------- .../src/ssa_refactor/opt/unrolling.rs | 254 ++++++++++++++++++ 6 files changed, 275 insertions(+), 157 deletions(-) create mode 100644 crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor.rs b/crates/noirc_evaluator/src/ssa_refactor.rs index c0394fde4af..63d1454c303 100644 --- a/crates/noirc_evaluator/src/ssa_refactor.rs +++ b/crates/noirc_evaluator/src/ssa_refactor.rs @@ -29,7 +29,7 @@ pub fn optimize_into_acir(program: Program) { .print("Initial SSA:") .inline_functions() .print("After Inlining:") - .simplify_cfg() + .unroll() .print("After Simplifying the CFG:"); } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs index 69e7de7ea80..1fc799adc75 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs @@ -164,9 +164,11 @@ impl Instruction { use SimplifyResult::*; match self { Instruction::Binary(binary) => binary.simplify(dfg), - Instruction::Cast(value, typ) => match (*typ == dfg.type_of_value(*value)).then_some(*value) { - Some(value) => SimplifiedTo(value), - _ => None, + Instruction::Cast(value, typ) => { + match (*typ == dfg.type_of_value(*value)).then_some(*value) { + Some(value) => SimplifiedTo(value), + _ => None, + } } Instruction::Not(value) => { match &dfg[*value] { @@ 
-346,17 +348,23 @@ impl Binary { } BinaryOp::Eq => { if self.lhs == self.rhs { - return SimplifyResult::SimplifiedTo(dfg.make_constant(FieldElement::one(), Type::bool())); + return SimplifyResult::SimplifiedTo( + dfg.make_constant(FieldElement::one(), Type::bool()), + ); } } BinaryOp::Lt => { if self.lhs == self.rhs { - return SimplifyResult::SimplifiedTo(dfg.make_constant(FieldElement::zero(), Type::bool())); + return SimplifyResult::SimplifiedTo( + dfg.make_constant(FieldElement::zero(), Type::bool()), + ); } } BinaryOp::And => { if lhs_is_zero || rhs_is_zero { - return SimplifyResult::SimplifiedTo(dfg.make_constant(FieldElement::zero(), operand_type)); + return SimplifyResult::SimplifiedTo( + dfg.make_constant(FieldElement::zero(), operand_type), + ); } } BinaryOp::Or => { @@ -516,5 +524,5 @@ pub(crate) enum SimplifyResult { Remove, /// Instruction could not be simplified - None + None, } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs index 3eb6d43e17d..403a2087c91 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs @@ -59,7 +59,7 @@ pub(crate) fn display_block( /// Specialize displaying value ids so that if they refer to a numeric /// constant or a function we print those directly. -fn value(function: &Function, id: ValueId) -> String { +pub(crate) fn value(function: &Function, id: ValueId) -> String { use super::value::Value; match &function.dfg[id] { Value::NumericConstant { constant, typ } => { diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs index 2701b0bb73c..997472bde84 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs @@ -5,3 +5,4 @@ //! Generally, these passes are also expected to minimize the final amount of instructions. 
mod inlining; mod simplify_cfg; +mod unrolling; diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs index e5a09d2e8ea..f52c90fe27e 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs @@ -1,156 +1,11 @@ -use std::collections::{HashMap, HashSet}; - -use acvm::FieldElement; -use iter_extended::vecmap; - -use crate::ssa_refactor::{ - ir::{ - basic_block::BasicBlockId, - dfg::InsertInstructionResult, - function::Function, - instruction::{InstructionId, TerminatorInstruction}, - value::ValueId, - }, - ssa_gen::Ssa, -}; +use crate::ssa_refactor::Ssa; impl Ssa { + /// Simplifies the function's control flow graph by removing blocks pub(crate) fn simplify_cfg(mut self) -> Ssa { for function in self.functions.values_mut() { - simplify_function_cfg(function); + // Context::new(function).simplify_function_cfg(); } self } } - -fn simplify_function_cfg(function: &mut Function) { - let current_block = function.entry_block(); - let mut context = Context::new(function); - context.simplify_function_cfg(current_block); -} - -struct Context<'f> { - visited_blocks: HashSet, - values: HashMap, - function: &'f mut Function, -} - -impl<'f> Context<'f> { - fn new(function: &'f mut Function) -> Self { - Self { visited_blocks: HashSet::new(), values: HashMap::new(), function } - } - - fn simplify_function_cfg(&mut self, current_block: BasicBlockId) { - let block = &self.function.dfg[current_block]; - self.visited_blocks.insert(current_block); - - match block.terminator() { - Some(TerminatorInstruction::Jmp { destination, arguments }) => { - let source_block = *destination; - let arguments = arguments.clone(); // TODO Remove clone - self.inline_instructions_from_block(current_block, &arguments, source_block); - self.simplify_function_cfg(current_block); - }, - Some(TerminatorInstruction::JmpIf { condition, then_destination, 
else_destination }) => { - match self.get_constant(*condition) { - Some(constant) => { - let next_block = - if constant.is_zero() { *else_destination } else { *then_destination }; - self.inline_instructions_from_block(current_block, &[], next_block); - self.simplify_function_cfg(current_block); - } - None => { - // We only allow dynamic branching if we're not going in a loop - assert!(!self.visited_blocks.contains(then_destination), "Dynamic loops are unsupported - block {then_destination} was already visited"); - assert!(!self.visited_blocks.contains(else_destination), "Dynamic loops are unsupported - block {else_destination} was already visited"); - let else_destination = *else_destination; - - self.inline_instructions_from_block(current_block, &[], *then_destination); - self.simplify_function_cfg(current_block); - self.inline_instructions_from_block(current_block, &[], else_destination); - self.simplify_function_cfg(current_block); - } - } - }, - Some(TerminatorInstruction::Return { return_values: _ }) => (), - None => unreachable!("Block has no terminator"), - } - } - - fn get_value(&self, value: ValueId) -> ValueId { - self.values.get(&value).copied().unwrap_or(value) - } - - fn get_constant(&self, value: ValueId) -> Option { - let value = self.get_value(value); - self.function.dfg.get_numeric_constant(value) - } - - /// TODO: Translate block parameters - fn inline_instructions_from_block( - &mut self, - dest_block: BasicBlockId, - jmp_args: &[ValueId], - source_block_id: BasicBlockId, - ) { - let source_block = &self.function.dfg[source_block_id]; - assert_eq!(source_block.parameters().len(), jmp_args.len(), "Parameter len != arg len when inlining block {source_block_id} into {dest_block}"); - - // Map each parameter to its new value - for (param, arg) in source_block.parameters().iter().zip(jmp_args) { - self.values.insert(*param, *arg); - } - - let instructions = source_block.instructions().to_vec(); - - // We cannot directly append each instruction since 
we need to substitute the - // block parameter values. - for instruction in instructions { - self.push_instruction(dest_block, instruction); - } - - let terminator = self.function.dfg[source_block_id].terminator() - .expect("Expected each block during the simplify_cfg optimization to have a terminator instruction") - .map_values(|id| self.get_value(id)); - - self.function.dfg.set_block_terminator(dest_block, terminator); - } - - fn push_instruction(&mut self, current_block: BasicBlockId, id: InstructionId) { - let instruction = self.function.dfg[id].map_values(|id| self.get_value(id)); - let results = self.function.dfg.instruction_results(id).to_vec(); - - let ctrl_typevars = instruction - .requires_ctrl_typevars() - .then(|| vecmap(&results, |result| self.function.dfg.type_of_value(*result))); - - let new_results = self.function.dfg.insert_instruction_and_results( - instruction, - current_block, - ctrl_typevars, - ); - Self::insert_new_instruction_results(&mut self.values, &results, new_results); - } - - /// Modify the values HashMap to remember the mapping between an instruction result's previous - /// ValueId (from the source_function) and its new ValueId in the destination function. 
- fn insert_new_instruction_results( - values: &mut HashMap, - old_results: &[ValueId], - new_results: InsertInstructionResult, - ) { - assert_eq!(old_results.len(), new_results.len()); - - match new_results { - InsertInstructionResult::SimplifiedTo(new_result) => { - values.insert(old_results[0], new_result); - } - InsertInstructionResult::Results(new_results) => { - for (old_result, new_result) in old_results.iter().zip(new_results) { - values.insert(*old_result, *new_result); - } - } - InsertInstructionResult::InstructionRemoved => (), - } - } -} diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs new file mode 100644 index 00000000000..452c161316b --- /dev/null +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs @@ -0,0 +1,254 @@ +use std::collections::{HashMap, HashSet}; + +use acvm::FieldElement; +use iter_extended::vecmap; + +use crate::ssa_refactor::{ + ir::{ + basic_block::BasicBlockId, + cfg::ControlFlowGraph, + dfg::InsertInstructionResult, + function::Function, + instruction::{InstructionId, TerminatorInstruction}, + value::ValueId, + }, + ssa_gen::Ssa, +}; + +impl Ssa { + pub(crate) fn unroll(mut self) -> Ssa { + for function in self.functions.values_mut() { + Context::new(function).simplify_function_cfg(); + } + self + } +} + +struct Context<'f> { + visited_blocks: HashSet, + values: HashMap, + function: &'f mut Function, + + current_block: BasicBlockId, + + /// This CFG is the original CFG before the pass modifies each block + cfg: ControlFlowGraph, +} + +enum Job { + MergeIf { + block_with_brif: BasicBlockId, + merge_point: BasicBlockId, + + // The last block where the condition of the brif instruction can be assumed to be true. + // This block should always be a direct predecessor of merge_point + final_then_block: BasicBlockId, + + // The last block where the condition of the brif instruction can be assumed to be false. 
+ // This block should always be a direct predecessor of merge_point + final_else_block: BasicBlockId, + }, + UnrollLoop { + pre_loop: BasicBlockId, + loop_start: BasicBlockId, + loop_body: BasicBlockId, + loop_end: BasicBlockId, + }, + /// A transitive block is one that looks like B in A -> B -> {C, D, ..}. + /// That is, it is unconditionally branched to by exactly 1 predecessor. + /// We can merge these blocks and remove the unnecessary br to turn the + /// new cfg into AB -> {C, D, ..} + RemoveTransitiveBlock { + before: BasicBlockId, + transitive: BasicBlockId, + after: BasicBlockId, + } +} + +// fn main f2 { +// b0(v0: u1): +// jmpif v0 then: b2, else: b3 +// b2(): +// jmp b4(Field 6 (v17)) +// b4(v1: Field): +// v4 = eq v1, Field 6 (v3) +// constrain v4 +// jmp b6(Field 0 (v8)) +// b6(v7: Field): +// v9 = lt v7, Field 2 (v6) +// jmpif v9 then: b7, else: b8 +// b7(): +// v13 = call println(v7) +// v15 = add v7, Field 1 (v14) +// jmp b6(v15) +// b8(): +// jmp b5(unit 0 (v10)) +// b5(v11: unit): +// return unit 0 (v16) +// b3(): +// jmp b4(Field 7 (v2)) +// } +// +// +// b0 -> b2 v----| +// => b4 -> b6 -> b7 +// -> b3 +// -> b8 -> b5 +// +// MergeIf(b0, b2, b3, b4) +// UnrollLoop(b4, b6, b7, b8) +// RemoveTransitiveBlock(b8, b5) + +impl<'f> Context<'f> { + fn new(function: &'f mut Function) -> Self { + Self { + visited_blocks: HashSet::new(), + values: HashMap::new(), + cfg: ControlFlowGraph::with_function(function), + current_block: function.entry_block(), + function, + } + } + + fn simplify_function_cfg(&mut self) { + let block = &self.function.dfg[self.current_block]; + self.visited_blocks.insert(self.current_block); + + match block.terminator() { + // TODO Remove the clone + Some(TerminatorInstruction::Jmp { destination, arguments }) => { + self.handle_jmp(*destination, &arguments.clone()); + } + Some(TerminatorInstruction::JmpIf { + condition, + then_destination, + else_destination, + }) => { + self.handle_jmpif(*condition, *then_destination, 
*else_destination); + } + Some(TerminatorInstruction::Return { return_values: _ }) => (), + None => unreachable!("Block has no terminator"), + } + } + + fn handle_jmp( + &mut self, + destination: BasicBlockId, + arguments: &[ValueId], + ) { + self.inline_instructions_from_block(&arguments, destination); + self.simplify_function_cfg(); + } + + fn handle_jmpif( + &mut self, + condition: ValueId, + then_block: BasicBlockId, + else_block: BasicBlockId, + ) { + match self.get_constant(condition) { + Some(constant) => { + let next_block = if constant.is_zero() { else_block } else { then_block }; + self.handle_jmp(next_block, &[]); + } + None => { + // We only allow dynamic branching if we're not going in a loop + assert!( + !self.visited_blocks.contains(&then_block), + "Dynamic loops are unsupported - block {then_block} was already visited" + ); + assert!( + !self.visited_blocks.contains(&else_block), + "Dynamic loops are unsupported - block {else_block} was already visited" + ); + + self.current_block = then_block; + self.handle_jmp(then_block, &[]); + self.current_block = else_block; + self.handle_jmp(else_block, &[]); + } + } + } + + fn get_value(&self, value: ValueId) -> ValueId { + self.values.get(&value).copied().unwrap_or(value) + } + + fn get_constant(&self, value: ValueId) -> Option { + let value = self.get_value(value); + self.function.dfg.get_numeric_constant(value) + } + + fn inline_instructions_from_block( + &mut self, + jmp_args: &[ValueId], + source_block_id: BasicBlockId, + ) { + let dest_block = self.current_block; + let source_block = &self.function.dfg[source_block_id]; + assert_eq!( + source_block.parameters().len(), + jmp_args.len(), + "Parameter len != arg len when inlining block {source_block_id} into {dest_block}" + ); + + // Map each parameter to its new value + for (param, arg) in source_block.parameters().iter().zip(jmp_args) { + self.values.insert(*param, *arg); + } + + let instructions = source_block.instructions().to_vec(); + + // We cannot 
directly append each instruction since we need to substitute the + // block parameter values. + for instruction in instructions { + self.push_instruction(instruction); + } + + let terminator = self.function.dfg[source_block_id].terminator() + .expect("Expected each block during the simplify_cfg optimization to have a terminator instruction") + .map_values(|id| self.get_value(id)); + + self.function.dfg.set_block_terminator(dest_block, terminator); + } + + fn push_instruction(&mut self, id: InstructionId) { + let instruction = self.function.dfg[id].map_values(|id| self.get_value(id)); + let results = self.function.dfg.instruction_results(id).to_vec(); + + let ctrl_typevars = instruction + .requires_ctrl_typevars() + .then(|| vecmap(&results, |result| self.function.dfg.type_of_value(*result))); + + let new_results = self.function.dfg.insert_instruction_and_results( + instruction, + self.current_block, + ctrl_typevars, + ); + Self::insert_new_instruction_results(&mut self.values, &results, new_results); + } + + /// Modify the values HashMap to remember the mapping between an instruction result's previous + /// ValueId (from the source_function) and its new ValueId in the destination function. 
+ fn insert_new_instruction_results( + values: &mut HashMap, + old_results: &[ValueId], + new_results: InsertInstructionResult, + ) { + assert_eq!(old_results.len(), new_results.len()); + + match new_results { + InsertInstructionResult::SimplifiedTo(new_result) => { + println!("result {} -> {}", old_results[0], new_result); + values.insert(old_results[0], new_result); + } + InsertInstructionResult::Results(new_results) => { + for (old_result, new_result) in old_results.iter().zip(new_results) { + println!("result {} -> {}", old_result, new_result); + values.insert(*old_result, *new_result); + } + } + InsertInstructionResult::InstructionRemoved => (), + } + } +} From 57f0f031489ae3e774e3f402e9c785d1c7e1c67d Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Thu, 11 May 2023 15:19:38 -0500 Subject: [PATCH 14/25] Get it working for most loops. Still missing loops with if inside --- crates/noirc_evaluator/src/ssa_refactor.rs | 2 +- .../src/ssa_refactor/opt/inlining.rs | 14 +- .../src/ssa_refactor/opt/unrolling.rs | 160 +++++++++--------- 3 files changed, 89 insertions(+), 87 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor.rs b/crates/noirc_evaluator/src/ssa_refactor.rs index 63d1454c303..cb81c2b0021 100644 --- a/crates/noirc_evaluator/src/ssa_refactor.rs +++ b/crates/noirc_evaluator/src/ssa_refactor.rs @@ -29,7 +29,7 @@ pub fn optimize_into_acir(program: Program) { .print("Initial SSA:") .inline_functions() .print("After Inlining:") - .unroll() + .unroll_loops() .print("After Simplifying the CFG:"); } diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs index 28db5b6d550..408a27c6a2a 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs @@ -121,7 +121,12 @@ impl InlineContext { /// Inlines a function into the current function and returns the translated return values /// of the inlined function. 
- fn inline_function(&mut self, ssa: &Ssa, id: FunctionId, arguments: &[ValueId]) -> Vec { + fn inline_function( + &mut self, + ssa: &Ssa, + id: FunctionId, + arguments: &[ValueId], + ) -> Vec { self.recursion_level += 1; if self.recursion_level > RECURSION_LIMIT { @@ -269,7 +274,9 @@ impl<'function> PerFunctionContext<'function> { seen_blocks.insert(source_block_id); self.inline_block(ssa, source_block_id); - function_return = self.handle_terminator_instruction(source_block_id, &mut block_queue); + + self.handle_terminator_instruction(source_block_id, &mut block_queue) + .map(|ret| function_return = Some(ret)); } if let Some((block, values)) = function_return { @@ -380,6 +387,7 @@ impl<'function> PerFunctionContext<'function> { if self.inlining_main { self.context.builder.terminate_with_return(return_values.clone()); } + let block_id = self.translate_block(block_id, block_queue); Some((block_id, return_values)) } None => unreachable!("Block has no terminator instruction"), @@ -390,7 +398,7 @@ impl<'function> PerFunctionContext<'function> { #[cfg(test)] mod test { use crate::ssa_refactor::{ - ir::{map::Id, types::Type, instruction::BinaryOp}, + ir::{instruction::BinaryOp, map::Id, types::Type}, ssa_builder::FunctionBuilder, }; diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs index 452c161316b..008b10e0ed7 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs @@ -16,9 +16,9 @@ use crate::ssa_refactor::{ }; impl Ssa { - pub(crate) fn unroll(mut self) -> Ssa { + pub(crate) fn unroll_loops(mut self) -> Ssa { for function in self.functions.values_mut() { - Context::new(function).simplify_function_cfg(); + Context::new(function).unroll_loops(); } self } @@ -31,93 +31,33 @@ struct Context<'f> { current_block: BasicBlockId, - /// This CFG is the original CFG before the pass modifies each block + /// The original 
ControlFlowGraph of this function before it was modified + /// by this loop unrolling pass. cfg: ControlFlowGraph, + inlined_loop_blocks: HashSet, } -enum Job { - MergeIf { - block_with_brif: BasicBlockId, - merge_point: BasicBlockId, - - // The last block where the condition of the brif instruction can be assumed to be true. - // This block should always be a direct predecessor of merge_point - final_then_block: BasicBlockId, - - // The last block where the condition of the brif instruction can be assumed to be false. - // This block should always be a direct predecessor of merge_point - final_else_block: BasicBlockId, - }, - UnrollLoop { - pre_loop: BasicBlockId, - loop_start: BasicBlockId, - loop_body: BasicBlockId, - loop_end: BasicBlockId, - }, - /// A transitive block is one that looks like B in A -> B -> {C, D, ..}. - /// That is, it is unconditionally branched to by exactly 1 predecessor. - /// We can merge these blocks and remove the unnecessary br to turn the - /// new cfg into AB -> {C, D, ..} - RemoveTransitiveBlock { - before: BasicBlockId, - transitive: BasicBlockId, - after: BasicBlockId, - } -} - -// fn main f2 { -// b0(v0: u1): -// jmpif v0 then: b2, else: b3 -// b2(): -// jmp b4(Field 6 (v17)) -// b4(v1: Field): -// v4 = eq v1, Field 6 (v3) -// constrain v4 -// jmp b6(Field 0 (v8)) -// b6(v7: Field): -// v9 = lt v7, Field 2 (v6) -// jmpif v9 then: b7, else: b8 -// b7(): -// v13 = call println(v7) -// v15 = add v7, Field 1 (v14) -// jmp b6(v15) -// b8(): -// jmp b5(unit 0 (v10)) -// b5(v11: unit): -// return unit 0 (v16) -// b3(): -// jmp b4(Field 7 (v2)) -// } -// -// -// b0 -> b2 v----| -// => b4 -> b6 -> b7 -// -> b3 -// -> b8 -> b5 -// -// MergeIf(b0, b2, b3, b4) -// UnrollLoop(b4, b6, b7, b8) -// RemoveTransitiveBlock(b8, b5) - impl<'f> Context<'f> { fn new(function: &'f mut Function) -> Self { Self { visited_blocks: HashSet::new(), values: HashMap::new(), - cfg: ControlFlowGraph::with_function(function), current_block: function.entry_block(), 
+ inlined_loop_blocks: HashSet::new(), + cfg: ControlFlowGraph::with_function(function), function, } } - fn simplify_function_cfg(&mut self) { + fn unroll_loops(&mut self) { let block = &self.function.dfg[self.current_block]; self.visited_blocks.insert(self.current_block); + println!("Visited {}", self.current_block); match block.terminator() { // TODO Remove the clone Some(TerminatorInstruction::Jmp { destination, arguments }) => { - self.handle_jmp(*destination, &arguments.clone()); + self.handle_jmp(*destination, &arguments.clone(), false); } Some(TerminatorInstruction::JmpIf { condition, @@ -131,13 +71,63 @@ impl<'f> Context<'f> { } } + /// + /// entry -> a \ + /// |-> c + /// -> b / + /// + /// V----| + /// entry -> a + /// -> b + /// + fn handle_jmp( &mut self, destination: BasicBlockId, arguments: &[ValueId], + conditional_jmp: bool, ) { - self.inline_instructions_from_block(&arguments, destination); - self.simplify_function_cfg(); + let non_looping_predecessor_count = self.count_non_looping_predecessors(destination); + + if !conditional_jmp && non_looping_predecessor_count <= 1 { + // Inline the block + println!("Directly inlining {destination}"); + self.inline_instructions_from_block(&arguments, destination); + } else { + println!("Switching to {destination}"); + self.current_block = destination; + } + self.unroll_loops(); + } + + fn count_non_looping_predecessors(&mut self, block: BasicBlockId) -> usize { + let predecessors = self.cfg.predecessors(block); + + predecessors.filter(|pred| !self.reachable_from(*pred, *pred, &mut HashSet::new())).count() + } + + fn reachable_from( + &self, + current_block: BasicBlockId, + target: BasicBlockId, + visited: &mut HashSet, + ) -> bool { + if visited.contains(&current_block) { + return false; + } + + visited.insert(current_block); + + for successor in self.cfg.successors(current_block) { + if successor == target { + return true; + } + if self.reachable_from(successor, target, visited) { + return true; + } + } + + false }
fn handle_jmpif( @@ -149,23 +139,27 @@ impl<'f> Context<'f> { match self.get_constant(condition) { Some(constant) => { let next_block = if constant.is_zero() { else_block } else { then_block }; - self.handle_jmp(next_block, &[]); + self.inlined_loop_blocks.insert(self.current_block); + println!("Constant jmpif to {next_block}"); + self.handle_jmp(next_block, &[], false); } None => { // We only allow dynamic branching if we're not going in a loop - assert!( - !self.visited_blocks.contains(&then_block), - "Dynamic loops are unsupported - block {then_block} was already visited" - ); - assert!( - !self.visited_blocks.contains(&else_block), - "Dynamic loops are unsupported - block {else_block} was already visited" - ); + let verify = |block| { + let looped = self.visited_blocks.contains(block); + assert!(!looped, "Dynamic loops are unsupported - {block} was already visited"); + }; + + verify(&then_block); + verify(&else_block); + + println!("Condition = {condition}"); + println!("Non-constant jmpif to {then_block} or {else_block}"); self.current_block = then_block; - self.handle_jmp(then_block, &[]); + self.handle_jmp(then_block, &[], true); self.current_block = else_block; - self.handle_jmp(else_block, &[]); + self.handle_jmp(else_block, &[], true); } } } From b94bfff35bf64d250e497abcc9cf96f4606cc8b3 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Tue, 16 May 2023 16:19:21 -0500 Subject: [PATCH 15/25] Rework entire pass from scratch --- crates/noirc_evaluator/src/ssa_refactor.rs | 2 +- .../src/ssa_refactor/acir_gen/mod.rs | 2 +- .../src/ssa_refactor/ir/basic_block.rs | 25 +- .../src/ssa_refactor/ir/cfg.rs | 4 - .../src/ssa_refactor/ir/dfg.rs | 30 +- .../src/ssa_refactor/ir/function.rs | 18 + .../src/ssa_refactor/ir/instruction.rs | 14 + .../src/ssa_refactor/opt/inlining.rs | 28 +- .../src/ssa_refactor/opt/simplify_cfg.rs | 2 +- .../src/ssa_refactor/opt/unrolling.rs | 437 ++++++++++++------ 10 files changed, 400 insertions(+), 162 deletions(-) diff --git 
a/crates/noirc_evaluator/src/ssa_refactor.rs b/crates/noirc_evaluator/src/ssa_refactor.rs index cb81c2b0021..f1e167fdd80 100644 --- a/crates/noirc_evaluator/src/ssa_refactor.rs +++ b/crates/noirc_evaluator/src/ssa_refactor.rs @@ -30,7 +30,7 @@ pub fn optimize_into_acir(program: Program) { .inline_functions() .print("After Inlining:") .unroll_loops() - .print("After Simplifying the CFG:"); + .print("After Unrolling:"); } /// Compiles the Program into ACIR and applies optimizations to the arithmetic gates diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs index a0959db5db8..ddda689a0da 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs @@ -6,7 +6,7 @@ use super::ssa_gen::Ssa; struct Context {} /// The output of the Acir-gen pass -pub struct Acir {} +pub(crate) struct Acir {} impl Ssa { pub(crate) fn into_acir(self) -> Acir { diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs index 30526bc296e..ad9ab914125 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs @@ -29,10 +29,10 @@ pub(crate) struct BasicBlock { pub(crate) type BasicBlockId = Id; impl BasicBlock { - /// Create a new BasicBlock with the given instructions. + /// Create a new BasicBlock with the given parameters. /// Parameters can also be added later via BasicBlock::add_parameter - pub(crate) fn new(instructions: Vec) -> Self { - Self { parameters: Vec::new(), instructions, terminator: None } + pub(crate) fn new() -> Self { + Self { parameters: Vec::new(), instructions: Vec::new(), terminator: None } } /// Returns the parameters of this block @@ -47,6 +47,12 @@ impl BasicBlock { self.parameters.push(parameter); } + /// Replace this block's current parameters with that of the given Vec. 
+ /// This does not perform any checks that any previous parameters were unused. + pub(crate) fn set_parameters(&mut self, parameters: Vec) { + self.parameters = parameters; + } + /// Insert an instruction at the end of this block pub(crate) fn insert_instruction(&mut self, instruction: InstructionId) { self.instructions.push(instruction); @@ -78,6 +84,19 @@ impl BasicBlock { self.terminator.as_ref() } + /// Returns the terminator of this block, panics if there is None. + /// + /// Once this block has finished construction, this is expected to always be Some. + pub(crate) fn unwrap_terminator(&self) -> &TerminatorInstruction { + self.terminator().expect("Expected block to have terminator instruction") + } + + pub(crate) fn mutate_terminator_blocks(&mut self, f: impl FnMut(BasicBlockId) -> BasicBlockId) { + if let Some(terminator) = self.terminator.as_mut() { + terminator.mutate_blocks(f); + } + } + /// Iterate over all the successors of the currently block, as determined by /// the blocks jumped to in the terminator instruction. If there is no terminator /// instruction yet, this will iterate 0 times. 
diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/cfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/cfg.rs index b2d16b29bfd..f219c874fa3 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/cfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/cfg.rs @@ -84,10 +84,6 @@ impl ControlFlowGraph { ); predecessor_node.successors.insert(to); let successor_node = self.data.entry(to).or_default(); - assert!( - successor_node.predecessors.len() < 2, - "ICE: A cfg node cannot have more than two predecessors" - ); successor_node.predecessors.insert(from); } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs index 9da388f4562..98c97bbb5f2 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs @@ -15,6 +15,7 @@ use super::{ }; use acvm::FieldElement; +use iter_extended::vecmap; /// The DataFlowGraph contains most of the actual data in a function including /// its blocks, instructions, and values. This struct is largely responsible for @@ -59,6 +60,8 @@ pub(crate) struct DataFlowGraph { signatures: DenseMap, /// All blocks in a function + /// + /// This map is sparse to allow removing unreachable blocks during optimizations blocks: DenseMap, } @@ -67,7 +70,27 @@ impl DataFlowGraph { /// After being created, the block is unreachable in the current function /// until another block is made to jump to it. pub(crate) fn make_block(&mut self) -> BasicBlockId { - self.blocks.insert(BasicBlock::new(Vec::new())) + self.blocks.insert(BasicBlock::new()) + } + + /// Create a new block with the same parameter count and parameter + /// types from the given block. + /// This is a somewhat niche operation used in loop unrolling but is included + /// here as doing it outside the DataFlowGraph would require cloning the parameters. 
+ pub(crate) fn make_block_with_parameters_from_block( + &mut self, + block: BasicBlockId, + ) -> BasicBlockId { + let new_block = self.make_block(); + let parameters = self.blocks[block].parameters(); + + let parameters = vecmap(parameters.iter().enumerate(), |(position, param)| { + let typ = self.values[*param].get_type(); + self.values.insert(Value::Param { block: new_block, position, typ }) + }); + + self.blocks[new_block].set_parameters(parameters); + new_block } /// Get an iterator over references to each basic block within the dfg, paired with the basic @@ -80,6 +103,11 @@ impl DataFlowGraph { self.blocks.iter() } + // Remove all blocks in this DFG that do not satisfy the given predicate + // pub(crate) fn retain_blocks(&mut self, mut predicate: impl FnMut(BasicBlockId) -> bool) { + // self.blocks.retain(|id, _| predicate(*id)) + // } + /// Returns the parameters of the given block pub(crate) fn block_parameters(&self, block: BasicBlockId) -> &[ValueId] { self.blocks[block].parameters() diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs index f37448462b7..8bd20050cfc 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs @@ -61,6 +61,24 @@ impl Function { pub(crate) fn parameters(&self) -> &[ValueId] { self.dfg.block_parameters(self.entry_block) } + + // Remove any unreachable blocks from this function. + // To do this, this method must traverse the cfg of this function in program order. 
+ // pub(crate) fn remove_unreachable_blocks(&mut self) { + // let mut reached = HashSet::new(); + // let mut stack: Vec = vec![self.entry_block]; + + // while let Some(block) = stack.pop() { + // reached.insert(block); + // for block in self.dfg[block].successors() { + // if !reached.contains(&block) { + // stack.push(block); + // } + // } + // } + + // self.dfg.retain_blocks(|block| reached.contains(&block)); + // } } /// FunctionId is a reference for a function diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs index 1fc799adc75..fb8473be3d2 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs @@ -273,6 +273,20 @@ impl TerminatorInstruction { } } } + + pub(crate) fn mutate_blocks(&mut self, mut f: impl FnMut(BasicBlockId) -> BasicBlockId) { + use TerminatorInstruction::*; + match self { + JmpIf { then_destination, else_destination, .. } => { + *then_destination = f(*then_destination); + *else_destination = f(*else_destination); + } + Jmp { destination, .. } => { + *destination = f(*destination); + } + Return { .. } => (), + } + } } /// A binary instruction in the IR. 
diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs index 408a27c6a2a..8f94b791496 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs @@ -268,15 +268,16 @@ impl<'function> PerFunctionContext<'function> { let mut function_return = None; - while let Some(source_block_id) = block_queue.pop() { - let translated_block_id = self.translate_block(source_block_id, &mut block_queue); - self.context.builder.switch_to_block(translated_block_id); + while let Some(source_block) = block_queue.pop() { + let translated_block = self.translate_block(source_block, &mut block_queue); + self.context.builder.switch_to_block(translated_block); - seen_blocks.insert(source_block_id); - self.inline_block(ssa, source_block_id); + seen_blocks.insert(source_block); + self.inline_block(ssa, source_block); - self.handle_terminator_instruction(source_block_id, &mut block_queue) - .map(|ret| function_return = Some(ret)); + if let Some(ret) = self.handle_terminator_instruction(source_block, &mut block_queue) { + function_return = Some(ret); + } } if let Some((block, values)) = function_return { @@ -364,25 +365,21 @@ impl<'function> PerFunctionContext<'function> { block_id: BasicBlockId, block_queue: &mut Vec, ) -> Option<(BasicBlockId, Vec)> { - match self.source_function.dfg[block_id].terminator() { - Some(TerminatorInstruction::Jmp { destination, arguments }) => { + match self.source_function.dfg[block_id].unwrap_terminator() { + TerminatorInstruction::Jmp { destination, arguments } => { let destination = self.translate_block(*destination, block_queue); let arguments2 = vecmap(arguments, |arg| self.translate_value(*arg)); self.context.builder.terminate_with_jmp(destination, arguments2); None } - Some(TerminatorInstruction::JmpIf { - condition, - then_destination, - else_destination, - }) => { + TerminatorInstruction::JmpIf { condition, 
then_destination, else_destination } => { let condition = self.translate_value(*condition); let then_block = self.translate_block(*then_destination, block_queue); let else_block = self.translate_block(*else_destination, block_queue); self.context.builder.terminate_with_jmpif(condition, then_block, else_block); None } - Some(TerminatorInstruction::Return { return_values }) => { + TerminatorInstruction::Return { return_values } => { let return_values = vecmap(return_values, |value| self.translate_value(*value)); if self.inlining_main { self.context.builder.terminate_with_return(return_values.clone()); @@ -390,7 +387,6 @@ impl<'function> PerFunctionContext<'function> { let block_id = self.translate_block(block_id, block_queue); Some((block_id, return_values)) } - None => unreachable!("Block has no terminator instruction"), } } } diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs index f52c90fe27e..5cd0759955e 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs @@ -3,7 +3,7 @@ use crate::ssa_refactor::Ssa; impl Ssa { /// Simplifies the function's control flow graph by removing blocks pub(crate) fn simplify_cfg(mut self) -> Ssa { - for function in self.functions.values_mut() { + for _function in self.functions.values_mut() { // Context::new(function).simplify_function_cfg(); } self diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs index 008b10e0ed7..935994f8a14 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs @@ -1,6 +1,5 @@ use std::collections::{HashMap, HashSet}; -use acvm::FieldElement; use iter_extended::vecmap; use crate::ssa_refactor::{ @@ -8,8 +7,10 @@ use crate::ssa_refactor::{ basic_block::BasicBlockId, cfg::ControlFlowGraph, 
dfg::InsertInstructionResult, + dom::DominatorTree, function::Function, instruction::{InstructionId, TerminatorInstruction}, + post_order::PostOrder, value::ValueId, }, ssa_gen::Ssa, @@ -18,149 +19,278 @@ use crate::ssa_refactor::{ impl Ssa { pub(crate) fn unroll_loops(mut self) -> Ssa { for function in self.functions.values_mut() { - Context::new(function).unroll_loops(); + unroll_loops_in_function(function); } self } } -struct Context<'f> { - visited_blocks: HashSet, - values: HashMap, - function: &'f mut Function, +fn unroll_loops_in_function(function: &mut Function) { + // Arbitrary maximum of 10k loops unrolled in a program to prevent looping forever + // if a bug causes us to continually unroll the same loop. + let max_loops_unrolled = 10_000; - current_block: BasicBlockId, + for _ in 0..max_loops_unrolled { + // Recompute the cfg & dom_tree after each loop in case we unrolled into another loop. + // TODO: Optimize: lazily recompute this only if the next loops' blocks have already been visited. + let cfg = ControlFlowGraph::with_function(function); + let post_order = PostOrder::with_function(function); + let dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order); - /// The original ControlFlowGraph of this function before it was modified - /// by this loop unrolling pass. - cfg: ControlFlowGraph, - inlined_loop_blocks: HashSet, + if let Some(loop_) = find_next_loop(function, &cfg, &dom_tree) { + unroll_loop(function, &cfg, loop_).ok(); + } else { + return; + } + } + + panic!("Did not finish unrolling all loops after the maximum of {max_loops_unrolled} loops unrolled") } -impl<'f> Context<'f> { - fn new(function: &'f mut Function) -> Self { - Self { - visited_blocks: HashSet::new(), - values: HashMap::new(), - current_block: function.entry_block(), - inlined_loop_blocks: HashSet::new(), - cfg: ControlFlowGraph::with_function(function), - function, +/// Find a loop in the program by finding a node that dominates any predecessor node. 
+/// The edge where this happens will be the back-edge of the loop. +fn find_next_loop( + function: &Function, + cfg: &ControlFlowGraph, + dom_tree: &DominatorTree, +) -> Option { + let mut loops = vec![]; + + for (block, _) in function.dfg.basic_blocks_iter() { + // These reachable checks wouldn't be needed if we only iterated over reachable blocks + if dom_tree.is_reachable(block) { + for predecessor in cfg.predecessors(block) { + if dom_tree.is_reachable(predecessor) && dom_tree.dominates(block, predecessor) { + // predecessor -> block is the back-edge of a loop + loops.push(find_blocks_in_loop(block, predecessor, cfg)); + } + } } } - fn unroll_loops(&mut self) { - let block = &self.function.dfg[self.current_block]; - self.visited_blocks.insert(self.current_block); - println!("Visited {}", self.current_block); + // Sort loops by block size so that we unroll the smaller, nested loops first. + loops.sort_by_key(|loop_| loop_.blocks.len()); + loops.pop() +} - match block.terminator() { - // TODO Remove the clone - Some(TerminatorInstruction::Jmp { destination, arguments }) => { - self.handle_jmp(*destination, &arguments.clone(), false); - } - Some(TerminatorInstruction::JmpIf { - condition, - then_destination, - else_destination, - }) => { - self.handle_jmpif(*condition, *then_destination, *else_destination); - } - Some(TerminatorInstruction::Return { return_values: _ }) => (), - None => unreachable!("Block has no terminator"), +/// Return each block that is in a loop starting in the given header block. +/// Expects back_edge_start -> header to be the back edge of the loop. 
+fn find_blocks_in_loop( + header: BasicBlockId, + back_edge_start: BasicBlockId, + cfg: &ControlFlowGraph, +) -> Loop { + let mut blocks = HashSet::new(); + blocks.insert(header); + + let mut insert = |block, stack: &mut Vec| { + if !blocks.contains(&block) { + blocks.insert(block); + stack.push(block); + } + }; + + let mut stack = vec![]; + insert(back_edge_start, &mut stack); + + while let Some(block) = stack.pop() { + for predecessor in cfg.predecessors(block) { + insert(predecessor, &mut stack); } } - /// - /// entry -> a \ - /// |-> c - /// -> b / - /// - /// V----| - /// entry -> a - /// -> b - /// - - fn handle_jmp( - &mut self, - destination: BasicBlockId, - arguments: &[ValueId], - conditional_jmp: bool, - ) { - let non_looping_predecessor_count = self.count_non_looping_predecessors(destination); + Loop { header, back_edge_start, blocks } +} - if !conditional_jmp && non_looping_predecessor_count <= 1 { - // Inline the block - println!("Directly inlining {destination}"); - self.inline_instructions_from_block(&arguments, destination); - } else { - println!("Switching to {destination}"); - self.current_block = destination; +fn unroll_loop(function: &mut Function, cfg: &ControlFlowGraph, loop_: Loop) -> Result<(), ()> { + let mut unroll_into = get_pre_header(cfg, &loop_); + let mut jump_value = get_induction_variable(function, unroll_into)?; + + while let Some(context) = unroll_loop_header(function, &loop_, unroll_into, jump_value) { + let (last_block, last_value) = context.unroll_loop_iteration(); + unroll_into = last_block; + jump_value = last_value; + } + + Ok(()) +} + +/// The loop pre-header is the block that comes before the loop begins. Generally a header block +/// is expected to have 2 predecessors: the pre-header and the final block of the loop which jumps +/// back to the beginning. 
+fn get_pre_header(cfg: &ControlFlowGraph, loop_: &Loop) -> BasicBlockId { + let mut pre_header = cfg + .predecessors(loop_.header) + .filter(|predecessor| *predecessor != loop_.back_edge_start) + .collect::>(); + + assert_eq!(pre_header.len(), 1); + pre_header.remove(0) +} + +/// Return the induction value of the current iteration of the loop, from the given block's jmp arguments. +/// +/// Expects the current block to terminate in `jmp h(N)` where h is the loop header and N is +/// a Field value. +fn get_induction_variable(function: &Function, block: BasicBlockId) -> Result { + match function.dfg[block].terminator() { + Some(TerminatorInstruction::Jmp { arguments, .. }) => { + assert_eq!(arguments.len(), 1); + let value = arguments[0]; + if function.dfg.get_numeric_constant(value).is_some() { + Ok(value) + } else { + Err(()) + } } - self.unroll_loops(); + _ => Err(()), } +} - fn count_non_looping_predecessors(&mut self, block: BasicBlockId) -> usize { - let predecessors = self.cfg.predecessors(block); +/// Unroll a single iteration of the loop. Returns true if we should perform another iteration. 
+fn unroll_loop_header<'a>( + function: &'a mut Function, + loop_: &'a Loop, + unroll_into: BasicBlockId, + induction_value: ValueId, +) -> Option> { + let mut context = LoopIteration::new(function, loop_, unroll_into, loop_.header); + let source_block = &context.function.dfg[context.source_block]; + assert_eq!(source_block.parameters().len(), 1, "Expected only 1 argument in loop header",); + + let first_param = source_block.parameters()[0]; + println!( + "Remembering {} <- {} ({:?})", + first_param, + induction_value, + context.function.dfg.get_numeric_constant(induction_value) + ); + context.values.insert(first_param, induction_value); + + context.inline_instructions_from_block(); + + match context.function.dfg[unroll_into].unwrap_terminator() { + TerminatorInstruction::JmpIf { condition, then_destination, else_destination } => { + let condition = context.get_value(*condition); + + match context.function.dfg.get_numeric_constant(condition) { + Some(constant) => { + let next_block = if constant.is_zero() { *else_destination } else { *then_destination }; + + // context.insert_block = next_block; + context.source_block = context.get_original_block(next_block); + + context.function.dfg.set_block_terminator(context.insert_block, TerminatorInstruction::Jmp { + destination: next_block, + arguments: Vec::new(), + }); + + // If the next block to jump to is outside of the loop, return None + loop_.blocks.contains(&context.source_block).then_some(context) + }, + None => { + // Non-constant loop. We have to reset the then and else destination back to + // the original blocks here since we won't be unrolling into the new blocks. 
+ context.function.dfg.set_block_terminator(context.insert_block, TerminatorInstruction::JmpIf { + condition, + then_destination: context.get_original_block(*then_destination), + else_destination: context.get_original_block(*else_destination), + }); + + None + }, + } - predecessors.filter(|pred| !self.reachable_from(*pred, *pred, &mut HashSet::new())).count() + } + other => panic!("Expected loop header to terminate in a JmpIf to the loop body, but found {other:?} instead"), } +} + +struct LoopIteration<'f> { + function: &'f mut Function, + loop_: &'f Loop, + values: HashMap<ValueId, ValueId>, + blocks: HashMap<BasicBlockId, BasicBlockId>, + original_blocks: HashMap<BasicBlockId, BasicBlockId>, + visited_blocks: HashSet<BasicBlockId>, - fn reachable_from( - &self, - current_block: BasicBlockId, - target: BasicBlockId, - visited: &mut HashSet<BasicBlockId>, - ) -> bool { - if visited.contains(&current_block) { - return false; + insert_block: BasicBlockId, + source_block: BasicBlockId, + induction_value: Option<(BasicBlockId, ValueId)>, +} + +impl<'f> LoopIteration<'f> { + fn new( + function: &'f mut Function, + loop_: &'f Loop, + insert_block: BasicBlockId, + source_block: BasicBlockId, + ) -> Self { + Self { + function, + loop_, + insert_block, + source_block, + values: HashMap::new(), + blocks: HashMap::new(), + original_blocks: HashMap::new(), + visited_blocks: HashSet::new(), + induction_value: None, } + } - visited.insert(current_block); + /// Unroll a single iteration of the loop. 
+ fn unroll_loop_iteration(mut self) -> (BasicBlockId, ValueId) { + let mut next_blocks = self.unroll_loop_block(); + next_blocks.retain(|block| self.loop_.blocks.contains(&self.get_original_block(*block))); - for successor in self.cfg.successors(current_block) { - if successor == target { - return true; - } - if self.reachable_from(successor, target, visited) { - return true; + while let Some(block) = next_blocks.pop() { + self.insert_block = block; + self.source_block = self.get_original_block(block); + + if !self.visited_blocks.contains(&self.source_block) { + let mut blocks = self.unroll_loop_block(); + blocks.retain(|block| self.loop_.blocks.contains(&self.get_original_block(*block))); + next_blocks.append(&mut blocks); } } - false + self.induction_value + .expect("Expected to find the induction variable by end of loop iteration") } - fn handle_jmpif( - &mut self, - condition: ValueId, - then_block: BasicBlockId, - else_block: BasicBlockId, - ) { - match self.get_constant(condition) { - Some(constant) => { - let next_block = if constant.is_zero() { else_block } else { then_block }; - self.inlined_loop_blocks.insert(self.current_block); - println!("Constant jmpif to {next_block}"); - self.handle_jmp(next_block, &[], false); + /// Unroll a single block in the current iteration of the loop + fn unroll_loop_block(&mut self) -> Vec { + self.inline_instructions_from_block(); + self.visited_blocks.insert(self.source_block); + + match self.function.dfg[self.insert_block].unwrap_terminator() { + TerminatorInstruction::JmpIf { condition, then_destination, else_destination } => { + let condition = self.get_value(*condition); + + match self.function.dfg.get_numeric_constant(condition) { + Some(constant) => { + let destination = + if constant.is_zero() { *else_destination } else { *then_destination }; + + let jmp = TerminatorInstruction::Jmp { destination, arguments: Vec::new() }; + self.function.dfg.set_block_terminator(self.insert_block, jmp); + + vec![destination] + } + 
None => { + vec![*then_destination, *else_destination] + } + } } - None => { - // We only allow dynamic branching if we're not going in a loop - let verify = |block| { - let looped = self.visited_blocks.contains(block); - assert!(!looped, "Dynamic loops are unsupported - {block} was already visited"); - }; - - verify(&then_block); - verify(&else_block); - - println!("Condition = {condition}"); - println!("Non-constant jmpif to {then_block} or {else_block}"); - - self.current_block = then_block; - self.handle_jmp(then_block, &[], true); - self.current_block = else_block; - self.handle_jmp(else_block, &[], true); + TerminatorInstruction::Jmp { destination, arguments } => { + if self.get_original_block(*destination) == self.loop_.header { + assert_eq!(arguments.len(), 1); + self.induction_value = Some((self.insert_block, arguments[0])); + } + vec![*destination] } + TerminatorInstruction::Return { .. } => vec![], } } @@ -168,28 +298,48 @@ impl<'f> Context<'f> { self.values.get(&value).copied().unwrap_or(value) } - fn get_constant(&self, value: ValueId) -> Option { - let value = self.get_value(value); - self.function.dfg.get_numeric_constant(value) + fn get_or_insert_block(&mut self, block: BasicBlockId) -> BasicBlockId { + if let Some(new_block) = self.blocks.get(&block) { + return *new_block; + } + + // If the block is in the loop we create a fresh block for each iteration + if self.loop_.blocks.contains(&block) { + let new_block = self.function.dfg.make_block_with_parameters_from_block(block); + + let old_parameters = self.function.dfg.block_parameters(block); + let new_parameters = self.function.dfg.block_parameters(new_block); + + for (param, new_param) in old_parameters.iter().zip(new_parameters) { + self.values.insert(*param, *new_param); + } + + self.blocks.insert(block, new_block); + self.original_blocks.insert(new_block, block); + new_block + } else { + block + } } - fn inline_instructions_from_block( - &mut self, - jmp_args: &[ValueId], - source_block_id: 
BasicBlockId, - ) { - let dest_block = self.current_block; - let source_block = &self.function.dfg[source_block_id]; - assert_eq!( - source_block.parameters().len(), - jmp_args.len(), - "Parameter len != arg len when inlining block {source_block_id} into {dest_block}" - ); + fn get_original_block(&self, block: BasicBlockId) -> BasicBlockId { + self.original_blocks.get(&block).copied().unwrap_or(block) + } + + fn inline_instructions_from_block(&mut self) { + let source_block = &self.function.dfg[self.source_block]; + // assert_eq!( + // source_block.parameters().len(), + // jmp_args.len(), + // "Parameter len != arg len when inlining block {} into {}", + // self.source_block, + // self.insert_block, + // ); // Map each parameter to its new value - for (param, arg) in source_block.parameters().iter().zip(jmp_args) { - self.values.insert(*param, *arg); - } + // for (param, arg) in source_block.parameters().iter().zip(jmp_args) { + // self.values.insert(*param, *arg); + // } let instructions = source_block.instructions().to_vec(); @@ -199,11 +349,17 @@ impl<'f> Context<'f> { self.push_instruction(instruction); } - let terminator = self.function.dfg[source_block_id].terminator() - .expect("Expected each block during the simplify_cfg optimization to have a terminator instruction") + let mut terminator = self.function.dfg[self.source_block] + .terminator() + .expect( + "Expected each block during the loop unrolling to have a terminator instruction", + ) .map_values(|id| self.get_value(id)); - self.function.dfg.set_block_terminator(dest_block, terminator); + terminator.mutate_blocks(|block| self.get_or_insert_block(block)); + self.function.dfg.set_block_terminator(self.insert_block, terminator); + + println!("Unrolled block: \n{}", self.function); } fn push_instruction(&mut self, id: InstructionId) { @@ -216,7 +372,7 @@ impl<'f> Context<'f> { let new_results = self.function.dfg.insert_instruction_and_results( instruction, - self.current_block, + self.insert_block, 
ctrl_typevars, ); Self::insert_new_instruction_results(&mut self.values, &results, new_results); @@ -233,12 +389,10 @@ impl<'f> Context<'f> { match new_results { InsertInstructionResult::SimplifiedTo(new_result) => { - println!("result {} -> {}", old_results[0], new_result); values.insert(old_results[0], new_result); } InsertInstructionResult::Results(new_results) => { for (old_result, new_result) in old_results.iter().zip(new_results) { - println!("result {} -> {}", old_result, new_result); values.insert(*old_result, *new_result); } } @@ -246,3 +400,16 @@ impl<'f> Context<'f> { } } } + +struct Loop { + /// The header block of a loop is the block which dominates all the + /// other blocks in the loop. + header: BasicBlockId, + + /// The start of the back_edge n -> d is the block n at the end of + /// the loop that jumps back to the header block d which restarts the loop. + back_edge_start: BasicBlockId, + + /// All the blocks contained within the loop, including `header` and `back_edge_start`. + blocks: HashSet, +} From 685c9d609cd9c91878282f6ae5a4d7b87708e1f3 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Wed, 17 May 2023 11:32:07 -0500 Subject: [PATCH 16/25] Finish loop unrolling --- crates/noirc_evaluator/src/ssa_refactor.rs | 5 +- .../src/ssa_refactor/ir/basic_block.rs | 9 +- .../src/ssa_refactor/ir/dfg.rs | 8 + .../src/ssa_refactor/ir/instruction.rs | 2 + .../src/ssa_refactor/opt/unrolling.rs | 207 +++++++++--------- 5 files changed, 126 insertions(+), 105 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor.rs b/crates/noirc_evaluator/src/ssa_refactor.rs index f1e167fdd80..8330931a4b4 100644 --- a/crates/noirc_evaluator/src/ssa_refactor.rs +++ b/crates/noirc_evaluator/src/ssa_refactor.rs @@ -7,7 +7,7 @@ //! 
This module heavily borrows from Cranelift #![allow(dead_code)] -use crate::errors::RuntimeError; +use crate::errors::{RuntimeError, RuntimeErrorKind}; use acvm::{acir::circuit::Circuit, compiler::transformers::IsOpcodeSupported, Language}; use noirc_abi::Abi; @@ -44,7 +44,8 @@ pub fn experimental_create_circuit( _show_output: bool, ) -> Result<(Circuit, Abi), RuntimeError> { optimize_into_acir(_program); - std::process::exit(0); + let error_kind = RuntimeErrorKind::Spanless("Acir-gen is unimplemented".into()); + Err(RuntimeError::new(error_kind, None)) } impl Ssa { diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs index ad9ab914125..333a1211ed5 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs @@ -91,10 +91,11 @@ impl BasicBlock { self.terminator().expect("Expected block to have terminator instruction") } - pub(crate) fn mutate_terminator_blocks(&mut self, f: impl FnMut(BasicBlockId) -> BasicBlockId) { - if let Some(terminator) = self.terminator.as_mut() { - terminator.mutate_blocks(f); - } + /// Returns a mutable reference to the terminator of this block. + /// + /// Once this block has finished construction, this is expected to always be Some. 
+ pub(crate) fn unwrap_terminator_mut(&mut self) -> &mut TerminatorInstruction { + self.terminator.as_mut().expect("Expected block to have terminator instruction") } /// Iterate over all the successors of the currently block, as determined by diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs index 98c97bbb5f2..b4471da2a26 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs @@ -312,6 +312,14 @@ impl DataFlowGraph { ) { self.blocks[block].set_terminator(terminator); } + + /// Returns a mutable reference to the terminator instruction of the given basic block + pub(crate) fn get_block_terminator_mut( + &mut self, + block: BasicBlockId, + ) -> &mut TerminatorInstruction { + self.blocks[block].unwrap_terminator_mut() + } } impl std::ops::Index for DataFlowGraph { diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs index fb8473be3d2..0993a2ab748 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs @@ -254,6 +254,7 @@ pub(crate) enum TerminatorInstruction { } impl TerminatorInstruction { + /// Map each ValueId in this terminator to a new value. pub(crate) fn map_values( &self, mut f: impl FnMut(ValueId) -> ValueId, @@ -274,6 +275,7 @@ impl TerminatorInstruction { } } + /// Mutate each BlockId to a new BlockId specified by the given mapping function. 
pub(crate) fn mutate_blocks(&mut self, mut f: impl FnMut(BasicBlockId) -> BasicBlockId) { use TerminatorInstruction::*; match self { diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs index 935994f8a14..e763b94a4e7 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs @@ -16,7 +16,13 @@ use crate::ssa_refactor::{ ssa_gen::Ssa, }; +/// Arbitrary maximum of 10k loops unrolled in a program to prevent looping forever +/// if a bug causes us to continually unroll the same loop. +const MAX_LOOPS_UNROLLED: u32 = 10_000; + impl Ssa { + /// Unroll all loops in each SSA function. + /// Panics if any loop cannot be unrolled. pub(crate) fn unroll_loops(mut self) -> Ssa { for function in self.functions.values_mut() { unroll_loops_in_function(function); @@ -25,12 +31,11 @@ impl Ssa { } } +/// Unroll all loops within a given function. +/// This will panic if the function has more than MAX_LOOPS_UNROLLED loops to unroll +/// or if the function has loops that cannot be unrolled because it has non-constant indices. fn unroll_loops_in_function(function: &mut Function) { - // Arbitrary maximum of 10k loops unrolled in a program to prevent looping forever - // if a bug causes us to continually unroll the same loop. - let max_loops_unrolled = 10_000; - - for _ in 0..max_loops_unrolled { + for _ in 0..MAX_LOOPS_UNROLLED { // Recompute the cfg & dom_tree after each loop in case we unrolled into another loop. // TODO: Optimize: lazily recompute this only if the next loops' blocks have already been visited. 
let cfg = ControlFlowGraph::with_function(function); @@ -44,11 +49,28 @@ fn unroll_loops_in_function(function: &mut Function) { } } - panic!("Did not finish unrolling all loops after the maximum of {max_loops_unrolled} loops unrolled") + panic!("Did not finish unrolling all loops after the maximum of {MAX_LOOPS_UNROLLED} loops unrolled") +} + +struct Loop { + /// The header block of a loop is the block which dominates all the + /// other blocks in the loop. + header: BasicBlockId, + + /// The start of the back_edge n -> d is the block n at the end of + /// the loop that jumps back to the header block d which restarts the loop. + back_edge_start: BasicBlockId, + + /// All the blocks contained within the loop, including `header` and `back_edge_start`. + blocks: HashSet, } /// Find a loop in the program by finding a node that dominates any predecessor node. /// The edge where this happens will be the back-edge of the loop. +/// +/// We could change this to return all loops in the function instead, but we'd have to +/// make sure to automatically refresh the list if any blocks within one loop were modified +/// as a result of inlining another. fn find_next_loop( function: &Function, cfg: &ControlFlowGraph, @@ -68,8 +90,9 @@ fn find_next_loop( } } - // Sort loops by block size so that we unroll the smaller, nested loops first. - loops.sort_by_key(|loop_| loop_.blocks.len()); + // Sort loops by block size so that we unroll the smaller, nested loops first as an + // optimization. + loops.sort_by(|loop_a, loop_b| loop_b.blocks.len().cmp(&loop_a.blocks.len())); loops.pop() } @@ -90,6 +113,8 @@ fn find_blocks_in_loop( } }; + // Starting from the back edge of the loop, each predecessor of this block until + // the header is within the loop. 
let mut stack = vec![]; insert(back_edge_start, &mut stack); @@ -102,6 +127,7 @@ fn find_blocks_in_loop( Loop { header, back_edge_start, blocks } } +/// Unroll a single loop in the function fn unroll_loop(function: &mut Function, cfg: &ControlFlowGraph, loop_: Loop) -> Result<(), ()> { let mut unroll_into = get_pre_header(cfg, &loop_); let mut jump_value = get_induction_variable(function, unroll_into)?; @@ -147,7 +173,9 @@ fn get_induction_variable(function: &Function, block: BasicBlockId) -> Result( function: &'a mut Function, loop_: &'a Loop, @@ -159,63 +187,47 @@ fn unroll_loop_header<'a>( assert_eq!(source_block.parameters().len(), 1, "Expected only 1 argument in loop header",); let first_param = source_block.parameters()[0]; - println!( - "Remembering {} <- {} ({:?})", - first_param, - induction_value, - context.function.dfg.get_numeric_constant(induction_value) - ); context.values.insert(first_param, induction_value); - context.inline_instructions_from_block(); match context.function.dfg[unroll_into].unwrap_terminator() { TerminatorInstruction::JmpIf { condition, then_destination, else_destination } => { - let condition = context.get_value(*condition); - - match context.function.dfg.get_numeric_constant(condition) { - Some(constant) => { - let next_block = if constant.is_zero() { *else_destination } else { *then_destination }; - - // context.insert_block = next_block; - context.source_block = context.get_original_block(next_block); - - context.function.dfg.set_block_terminator(context.insert_block, TerminatorInstruction::Jmp { - destination: next_block, - arguments: Vec::new(), - }); - - // If the next block to jump to is outside of the loop, return None - loop_.blocks.contains(&context.source_block).then_some(context) - }, - None => { - // Non-constant loop. We have to reset the then and else destination back to - // the original blocks here since we won't be unrolling into the new blocks. 
- context.function.dfg.set_block_terminator(context.insert_block, TerminatorInstruction::JmpIf { - condition, - then_destination: context.get_original_block(*then_destination), - else_destination: context.get_original_block(*else_destination), - }); - - None - }, - } + let next_blocks = context.handle_jmpif(*condition, *then_destination, *else_destination); + + // If there is only 1 next block the jmpif evaluated to a single known block. + // This is the expected case and lets us know if we should loop again or not. + if next_blocks.len() == 1 { + loop_.blocks.contains(&context.source_block).then_some(context) + } else { + // Non-constant loop. We have to reset the then and else destination back to + // the original blocks here since we won't be unrolling into the new blocks. + context.function.dfg.get_block_terminator_mut(context.insert_block) + .mutate_blocks(|block| context.original_blocks[&block]); + None + } } other => panic!("Expected loop header to terminate in a JmpIf to the loop body, but found {other:?} instead"), } } +/// The context object for each loop iteration. +/// Notably each loop iteration maps each loop block to a fresh, unrolled block. struct LoopIteration<'f> { function: &'f mut Function, loop_: &'f Loop, values: HashMap, blocks: HashMap, + + /// Maps unrolled block ids back to the original source block ids original_blocks: HashMap, visited_blocks: HashSet, insert_block: BasicBlockId, source_block: BasicBlockId, + + /// The induction value (and the block it was found in) is the new value for + /// the variable traditionally called `i` on each iteration of the loop. induction_value: Option<(BasicBlockId, ValueId)>, } @@ -240,9 +252,13 @@ impl<'f> LoopIteration<'f> { } /// Unroll a single iteration of the loop. + /// + /// Note that after unrolling a single iteration, the loop is _not_ in a valid state. + /// It is expected the terminator instructions are set up to branch into an empty block + /// for further unrolling. 
When the loop is finished this will need to be mutated to + /// jump to the end of the loop instead. fn unroll_loop_iteration(mut self) -> (BasicBlockId, ValueId) { let mut next_blocks = self.unroll_loop_block(); - next_blocks.retain(|block| self.loop_.blocks.contains(&self.get_original_block(*block))); while let Some(block) = next_blocks.pop() { self.insert_block = block; @@ -250,7 +266,6 @@ impl<'f> LoopIteration<'f> { if !self.visited_blocks.contains(&self.source_block) { let mut blocks = self.unroll_loop_block(); - blocks.retain(|block| self.loop_.blocks.contains(&self.get_original_block(*block))); next_blocks.append(&mut blocks); } } @@ -261,27 +276,22 @@ impl<'f> LoopIteration<'f> { /// Unroll a single block in the current iteration of the loop fn unroll_loop_block(&mut self) -> Vec { + let mut next_blocks = self.unroll_loop_block_helper(); + next_blocks.retain(|block| { + let b = self.get_original_block(*block); + self.loop_.blocks.contains(&b) + }); + next_blocks + } + + /// Unroll a single block in the current iteration of the loop + fn unroll_loop_block_helper(&mut self) -> Vec { self.inline_instructions_from_block(); self.visited_blocks.insert(self.source_block); match self.function.dfg[self.insert_block].unwrap_terminator() { TerminatorInstruction::JmpIf { condition, then_destination, else_destination } => { - let condition = self.get_value(*condition); - - match self.function.dfg.get_numeric_constant(condition) { - Some(constant) => { - let destination = - if constant.is_zero() { *else_destination } else { *then_destination }; - - let jmp = TerminatorInstruction::Jmp { destination, arguments: Vec::new() }; - self.function.dfg.set_block_terminator(self.insert_block, jmp); - - vec![destination] - } - None => { - vec![*then_destination, *else_destination] - } - } + self.handle_jmpif(*condition, *then_destination, *else_destination) } TerminatorInstruction::Jmp { destination, arguments } => { if self.get_original_block(*destination) == self.loop_.header { 
@@ -294,10 +304,37 @@ impl<'f> LoopIteration<'f> { } } + /// Find the next branch(es) to take from a jmpif terminator and return them. + /// If only one block is returned, it means the jmpif condition evaluated to a known + /// constant and we can safely take only the given branch. + fn handle_jmpif( + &mut self, + condition: ValueId, + then_destination: BasicBlockId, + else_destination: BasicBlockId, + ) -> Vec { + let condition = self.get_value(condition); + + match self.function.dfg.get_numeric_constant(condition) { + Some(constant) => { + let destination = + if constant.is_zero() { else_destination } else { then_destination }; + + self.source_block = self.get_original_block(destination); + let jmp = TerminatorInstruction::Jmp { destination, arguments: Vec::new() }; + self.function.dfg.set_block_terminator(self.insert_block, jmp); + vec![destination] + } + None => vec![then_destination, else_destination], + } + } + fn get_value(&self, value: ValueId) -> ValueId { self.values.get(&value).copied().unwrap_or(value) } + /// Translate a block id to a block id in the unrolled loop. If the given + /// block id is not within the loop, it is returned as-is. 
fn get_or_insert_block(&mut self, block: BasicBlockId) -> BasicBlockId { if let Some(new_block) = self.blocks.get(&block) { return *new_block; @@ -311,7 +348,8 @@ impl<'f> LoopIteration<'f> { let new_parameters = self.function.dfg.block_parameters(new_block); for (param, new_param) in old_parameters.iter().zip(new_parameters) { - self.values.insert(*param, *new_param); + // Don't overwrite any existing entries to avoid overwriting the induction variable + self.values.entry(*param).or_insert(*new_param); } self.blocks.insert(block, new_block); @@ -328,38 +366,21 @@ impl<'f> LoopIteration<'f> { fn inline_instructions_from_block(&mut self) { let source_block = &self.function.dfg[self.source_block]; - // assert_eq!( - // source_block.parameters().len(), - // jmp_args.len(), - // "Parameter len != arg len when inlining block {} into {}", - // self.source_block, - // self.insert_block, - // ); - - // Map each parameter to its new value - // for (param, arg) in source_block.parameters().iter().zip(jmp_args) { - // self.values.insert(*param, *arg); - // } - let instructions = source_block.instructions().to_vec(); - // We cannot directly append each instruction since we need to substitute the - // block parameter values. + // We cannot directly append each instruction since we need to substitute any + // instances of the induction variable or any values that were changed as a result + // of the new induction variable value. 
for instruction in instructions { self.push_instruction(instruction); } let mut terminator = self.function.dfg[self.source_block] - .terminator() - .expect( - "Expected each block during the loop unrolling to have a terminator instruction", - ) - .map_values(|id| self.get_value(id)); + .unwrap_terminator() + .map_values(|value| self.get_value(value)); terminator.mutate_blocks(|block| self.get_or_insert_block(block)); self.function.dfg.set_block_terminator(self.insert_block, terminator); - - println!("Unrolled block: \n{}", self.function); } fn push_instruction(&mut self, id: InstructionId) { @@ -375,6 +396,7 @@ impl<'f> LoopIteration<'f> { self.insert_block, ctrl_typevars, ); + Self::insert_new_instruction_results(&mut self.values, &results, new_results); } @@ -400,16 +422,3 @@ impl<'f> LoopIteration<'f> { } } } - -struct Loop { - /// The header block of a loop is the block which dominates all the - /// other blocks in the loop. - header: BasicBlockId, - - /// The start of the back_edge n -> d is the block n at the end of - /// the loop that jumps back to the header block d which restarts the loop. - back_edge_start: BasicBlockId, - - /// All the blocks contained within the loop, including `header` and `back_edge_start`. 
- blocks: HashSet, -} From b242fac73d0d4647b07f5e1208c649354740f28e Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Wed, 17 May 2023 11:36:42 -0500 Subject: [PATCH 17/25] Add doc comment --- crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs | 1 - .../src/ssa_refactor/opt/simplify_cfg.rs | 11 ----------- .../noirc_evaluator/src/ssa_refactor/opt/unrolling.rs | 10 ++++++++++ 3 files changed, 10 insertions(+), 12 deletions(-) delete mode 100644 crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs index 997472bde84..50ee74cf609 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs @@ -4,5 +4,4 @@ //! simpler form until the IR only has a single function remaining with 1 block within it. //! Generally, these passes are also expected to minimize the final amount of instructions. mod inlining; -mod simplify_cfg; mod unrolling; diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs deleted file mode 100644 index 5cd0759955e..00000000000 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs +++ /dev/null @@ -1,11 +0,0 @@ -use crate::ssa_refactor::Ssa; - -impl Ssa { - /// Simplifies the function's control flow graph by removing blocks - pub(crate) fn simplify_cfg(mut self) -> Ssa { - for _function in self.functions.values_mut() { - // Context::new(function).simplify_function_cfg(); - } - self - } -} diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs index e763b94a4e7..819ee018fcd 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs @@ -1,3 +1,13 @@ +//! This file contains the loop unrolling pass for the new SSA IR. +//! +//! 
This pass is divided into three steps: +//! 1. Find a loop in the program (`find_next_loop`) +//! 2. Unroll that loop into its "pre-header" block (`unroll_loop`) +//! 3. Repeat until no more loops are found +//! +//! Note that unrolling loops will fail if there are loops with non-constant +//! indices. This pass also often creates superfluous jmp instructions in the +//! program that will need to be removed by a later simplify cfg pass. use std::collections::{HashMap, HashSet}; use iter_extended::vecmap; From a53c6431373a4b19983ecaeb23d6f8cfdf65a58b Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Wed, 17 May 2023 11:50:44 -0500 Subject: [PATCH 18/25] Fix bad merge --- crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs index cca5387882c..bd23b9ea17b 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs @@ -143,7 +143,7 @@ impl DataFlowGraph { SimplifyResult::Remove => InstructionRemoved, SimplifyResult::None => { let id = self.make_instruction(instruction, ctrl_typevars); - self.blocks[block].insert_instruction(instruction); + self.blocks[block].insert_instruction(id); InsertInstructionResult::Results(self.instruction_results(id)) } } From 4e0a16345db4c69653ffbb1541c48d86fb8551a2 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Wed, 17 May 2023 12:21:07 -0500 Subject: [PATCH 19/25] Add test --- .../src/ssa_refactor/opt/unrolling.rs | 116 ++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs index 819ee018fcd..18b709b3c83 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs @@ -432,3 +432,119 @@ impl<'f> LoopIteration<'f> { } } } + 
+#[cfg(test)] +mod tests { + use crate::ssa_refactor::{ + ir::{dom::DominatorTree, instruction::BinaryOp, map::Id, types::Type}, + ssa_builder::FunctionBuilder, + ssa_gen::Ssa, + }; + + #[test] + fn unroll_nested_loops() { + // fn main() { + // for i in 0..3 { + // for j in 0..4 { + // assert(i + j > 10); + // } + // } + // } + // + // fn main f0 { + // b0(): + // jmp b1(Field 0) + // b1(v0: Field): // header of outer loop + // v1 = lt v0, Field 3 + // jmpif v1, then: b2, else: b3 + // b2(): + // jmp b4(Field 0) + // b4(v2: Field): // header of inner loop + // v3 = lt v2, Field 4 + // jmpif v3, then: b5, else: b6 + // b5(): + // v4 = add v0, v2 + // v5 = lt Field 10, v4 + // constrain v5 + // v6 = add v2, Field 1 + // jmp b4(v6) + // b6(): // end of inner loop + // v7 = add v0, Field 1 + // jmp b1(v7) + // b3(): // end of outer loop + // return Field 0 + // } + let main_id = Id::test_new(0); + + // Compiling main + let mut builder = FunctionBuilder::new("main".into(), main_id); + + let b1 = builder.insert_block(); + let b2 = builder.insert_block(); + let b3 = builder.insert_block(); + let b4 = builder.insert_block(); + let b5 = builder.insert_block(); + let b6 = builder.insert_block(); + + let v0 = builder.add_block_parameter(b1, Type::field()); + let v2 = builder.add_block_parameter(b4, Type::field()); + + let zero = builder.field_constant(0u128); + let one = builder.field_constant(1u128); + let three = builder.field_constant(3u128); + let four = builder.field_constant(4u128); + let ten = builder.field_constant(10u128); + + builder.terminate_with_jmp(b1, vec![zero]); + + // b1 + builder.switch_to_block(b1); + let v1 = builder.insert_binary(v0, BinaryOp::Lt, three); + builder.terminate_with_jmpif(v1, b2, b3); + + // b2 + builder.switch_to_block(b2); + builder.terminate_with_jmp(b4, vec![zero]); + + // b3 + builder.switch_to_block(b3); + builder.terminate_with_return(vec![zero]); + + // b4 + builder.switch_to_block(b4); + let v3 = builder.insert_binary(v2, 
BinaryOp::Lt, four); + builder.terminate_with_jmpif(v3, b5, b6); + + // b5 + builder.switch_to_block(b5); + let v4 = builder.insert_binary(v0, BinaryOp::Add, v2); + let v5 = builder.insert_binary(ten, BinaryOp::Lt, v4); + builder.insert_constrain(v5); + let v6 = builder.insert_binary(v2, BinaryOp::Add, one); + builder.terminate_with_jmp(b4, vec![v6]); + + // b6 + builder.switch_to_block(b6); + let v7 = builder.insert_binary(v0, BinaryOp::Add, one); + builder.terminate_with_jmp(b1, vec![v7]); + + // basic_blocks_iter iterates over unreachable blocks as well, so we must filter those out. + let count_reachable_blocks = |ssa: &Ssa| { + let function = ssa.main(); + let dom_tree = DominatorTree::with_function(function); + function + .dfg + .basic_blocks_iter() + .filter(|(block, _)| dom_tree.is_reachable(*block)) + .count() + }; + + let ssa = builder.finish(); + assert_eq!(count_reachable_blocks(&ssa), 7); + + // The final block count is not 1 because the block creates some unnecessary jmps. + // If a simplify cfg pass is ran afterward, the expected block count will be 1. 
+ let ssa = ssa.unroll_loops(); + assert_eq!(count_reachable_blocks(&ssa), 5); + } +} From 5faa69ad8f2913af48f66db70b686b47bd2f2c7e Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Wed, 17 May 2023 12:31:19 -0500 Subject: [PATCH 20/25] Remove outdated parts of PR --- .../noirc_evaluator/src/ssa_refactor/ir/cfg.rs | 4 ++++ .../noirc_evaluator/src/ssa_refactor/ir/dfg.rs | 16 ---------------- .../src/ssa_refactor/ir/function.rs | 18 ------------------ .../src/ssa_refactor/ir/printer.rs | 4 ++-- 4 files changed, 6 insertions(+), 36 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/cfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/cfg.rs index f219c874fa3..b2d16b29bfd 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/cfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/cfg.rs @@ -84,6 +84,10 @@ impl ControlFlowGraph { ); predecessor_node.successors.insert(to); let successor_node = self.data.entry(to).or_default(); + assert!( + successor_node.predecessors.len() < 2, + "ICE: A cfg node cannot have more than two predecessors" + ); successor_node.predecessors.insert(from); } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs index bd23b9ea17b..c4eb4831400 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs @@ -60,8 +60,6 @@ pub(crate) struct DataFlowGraph { signatures: DenseMap, /// All blocks in a function - /// - /// This map is sparse to allow removing unreachable blocks during optimizations blocks: DenseMap, } @@ -103,11 +101,6 @@ impl DataFlowGraph { self.blocks.iter() } - // Remove all blocks in this DFG that do not satisfy the given predicate - // pub(crate) fn retain_blocks(&mut self, mut predicate: impl FnMut(BasicBlockId) -> bool) { - // self.blocks.retain(|id, _| predicate(*id)) - // } - /// Returns the parameters of the given block pub(crate) fn block_parameters(&self, block: BasicBlockId) -> &[ValueId] 
{ self.blocks[block].parameters() @@ -155,15 +148,6 @@ impl DataFlowGraph { self.values.insert(value) } - /// Replaces the value specified by the given ValueId with a new Value. - /// - /// This is the preferred method to call for optimizations simplifying - /// values since other instructions referring to the same ValueId need - /// not be modified to refer to a new ValueId. - pub(crate) fn set_value(&mut self, value_id: ValueId, new_value: Value) { - self.values[value_id] = new_value; - } - /// Creates a new constant value, or returns the Id to an existing one if /// one already exists. pub(crate) fn make_constant(&mut self, value: FieldElement, typ: Type) -> ValueId { diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs index 8bd20050cfc..f37448462b7 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs @@ -61,24 +61,6 @@ impl Function { pub(crate) fn parameters(&self) -> &[ValueId] { self.dfg.block_parameters(self.entry_block) } - - // Remove any unreachable blocks from this function. - // To do this, this method must traverse the cfg of this function in program order. 
- // pub(crate) fn remove_unreachable_blocks(&mut self) { - // let mut reached = HashSet::new(); - // let mut stack: Vec = vec![self.entry_block]; - - // while let Some(block) = stack.pop() { - // reached.insert(block); - // for block in self.dfg[block].successors() { - // if !reached.contains(&block) { - // stack.push(block); - // } - // } - // } - - // self.dfg.retain_blocks(|block| reached.contains(&block)); - // } } /// FunctionId is a reference for a function diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs index 403a2087c91..3993a862618 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs @@ -59,12 +59,12 @@ pub(crate) fn display_block( /// Specialize displaying value ids so that if they refer to a numeric /// constant or a function we print those directly. -pub(crate) fn value(function: &Function, id: ValueId) -> String { +fn value(function: &Function, id: ValueId) -> String { use super::value::Value; match &function.dfg[id] { Value::NumericConstant { constant, typ } => { let value = function.dfg[*constant].value(); - format!("{typ} {value} ({id})") + format!("{typ} {value}") } Value::Function(id) => id.to_string(), Value::Intrinsic(intrinsic) => intrinsic.to_string(), From c270aae030c1c374f60fc86833d340f4c6fc5fa0 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Thu, 18 May 2023 13:12:40 -0500 Subject: [PATCH 21/25] Correctly handle loops with non-const indices --- .../src/ssa_refactor/ir/basic_block.rs | 12 ++ .../src/ssa_refactor/opt/unrolling.rs | 135 +++++++++++------- 2 files changed, 92 insertions(+), 55 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs index 333a1211ed5..d04cb30b6cb 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs +++ 
b/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs @@ -124,4 +124,16 @@ impl BasicBlock { }); self.instructions.remove(index); } + + /// Take ownership of this block's terminator, replacing it with an empty return terminator + /// so that no clone is needed. + /// + /// It is expected that this function is used as an optimization on blocks that are no longer + /// reachable or will have their terminator overwritten afterwards. Using this on a reachable + /// block without setting the terminator afterward will result in the empty return terminator + /// being kept, which is likely unwanted. + pub(crate) fn take_terminator(&mut self) -> TerminatorInstruction { + let terminator = self.terminator.as_mut().expect("Expected block to have a terminator"); + std::mem::replace(terminator, TerminatorInstruction::Return { return_values: Vec::new() }) + } } diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs index 18b709b3c83..ee907a635a9 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs @@ -1,12 +1,16 @@ //! This file contains the loop unrolling pass for the new SSA IR. //! -//! This pass is divided into three steps: -//! 1. Find a loop in the program (`find_next_loop`) -//! 2. Unroll that loop into its "pre-header" block (`unroll_loop`) -//! 3. Repeat until no more loops are found +//! This pass is divided into a few steps: +//! 1. Find all loops in the program (`find_all_loops`) +//! 2. For each loop: +//! a. If the loop is in our list of loops that previously failed to unroll, skip it. +//! b. If we have previously modified any of the blocks in the loop, +//! restart from step 1 to refresh the context. +//! c. If not, try to unroll the loop. If successful, remember the modified +//! blocks. If not, remember that the loop failed to unroll and leave it +//! unmodified. //! -//! 
Note that unrolling loops will fail if there are loops with non-constant -//! indices. This pass also often creates superfluous jmp instructions in the +//! Note that this pass also often creates superfluous jmp instructions in the //! program that will need to be removed by a later simplify cfg pass. use std::collections::{HashMap, HashSet}; @@ -26,42 +30,17 @@ use crate::ssa_refactor::{ ssa_gen::Ssa, }; -/// Arbitrary maximum of 10k loops unrolled in a program to prevent looping forever -/// if a bug causes us to continually unroll the same loop. -const MAX_LOOPS_UNROLLED: u32 = 10_000; - impl Ssa { /// Unroll all loops in each SSA function. /// Panics if any loop cannot be unrolled. pub(crate) fn unroll_loops(mut self) -> Ssa { for function in self.functions.values_mut() { - unroll_loops_in_function(function); + find_all_loops(function).unroll_each_loop(function); } self } } -/// Unroll all loops within a given function. -/// This will panic if the function has more than MAX_LOOPS_UNROLLED loops to unroll -/// or if the function has loops that cannot be unrolled because it has non-constant indices. -fn unroll_loops_in_function(function: &mut Function) { - for _ in 0..MAX_LOOPS_UNROLLED { - // Recompute the cfg & dom_tree after each loop in case we unrolled into another loop. - // TODO: Optimize: lazily recompute this only if the next loops' blocks have already been visited. - let cfg = ControlFlowGraph::with_function(function); - let post_order = PostOrder::with_function(function); - let dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order); - - if let Some(loop_) = find_next_loop(function, &cfg, &dom_tree) { - unroll_loop(function, &cfg, loop_).ok(); - } else { - return; - } - } - - panic!("Did not finish unrolling all loops after the maximum of {MAX_LOOPS_UNROLLED} loops unrolled") -} - struct Loop { /// The header block of a loop is the block which dominates all the /// other blocks in the loop. 
@@ -75,17 +54,24 @@ struct Loop { blocks: HashSet, } +struct Loops { + /// The loops that failed to be unrolled so that we do not try to unroll them again. + /// Each loop is identified by its header block id. + failed_to_unroll: HashSet, + + yet_to_unroll: Vec, + modified_blocks: HashSet, + cfg: ControlFlowGraph, + dom_tree: DominatorTree, +} + /// Find a loop in the program by finding a node that dominates any predecessor node. /// The edge where this happens will be the back-edge of the loop. -/// -/// We could change this to return all loops in the function instead, but we'd have to -/// make sure to automatically refresh the list if any blocks within one loop were modified -/// as a result of inlining another. -fn find_next_loop( - function: &Function, - cfg: &ControlFlowGraph, - dom_tree: &DominatorTree, -) -> Option { +fn find_all_loops(function: &Function) -> Loops { + let cfg = ControlFlowGraph::with_function(function); + let post_order = PostOrder::with_function(function); + let dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order); + let mut loops = vec![]; for (block, _) in function.dfg.basic_blocks_iter() { @@ -94,16 +80,47 @@ fn find_next_loop( for predecessor in cfg.predecessors(block) { if dom_tree.is_reachable(predecessor) && dom_tree.dominates(block, predecessor) { // predecessor -> block is the back-edge of a loop - loops.push(find_blocks_in_loop(block, predecessor, cfg)); + loops.push(find_blocks_in_loop(block, predecessor, &cfg)); } } } } - // Sort loops by block size so that we unroll the smaller, nested loops first as an - // optimization. + // Sort loops by block size so that we unroll the smaller, nested loops first as an optimization. loops.sort_by(|loop_a, loop_b| loop_b.blocks.len().cmp(&loop_a.blocks.len())); - loops.pop() + + Loops { + failed_to_unroll: HashSet::new(), + yet_to_unroll: loops, + modified_blocks: HashSet::new(), + cfg, + dom_tree, + } +} + +impl Loops { + /// Unroll all loops within a given function. 
+ /// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified. + fn unroll_each_loop(mut self, function: &mut Function) { + while let Some(next_loop) = self.yet_to_unroll.pop() { + // If we've previously modified a block in this loop we need to refresh the context. + // This happens any time we have nested loops. + if next_loop.blocks.iter().any(|block| self.modified_blocks.contains(block)) { + let mut new_context = find_all_loops(function); + new_context.failed_to_unroll = std::mem::take(&mut self.failed_to_unroll); + return new_context.unroll_each_loop(function); + } + + // Don't try to unroll the loop again if it is known to fail + if !self.failed_to_unroll.contains(&next_loop.header) { + if unroll_loop(function, &self.cfg, &next_loop).is_ok() { + self.modified_blocks.extend(next_loop.blocks); + } else { + self.failed_to_unroll.insert(next_loop.header); + } + } + } + } } /// Return each block that is in a loop starting in the given header block. @@ -138,11 +155,11 @@ fn find_blocks_in_loop( } /// Unroll a single loop in the function -fn unroll_loop(function: &mut Function, cfg: &ControlFlowGraph, loop_: Loop) -> Result<(), ()> { - let mut unroll_into = get_pre_header(cfg, &loop_); +fn unroll_loop(function: &mut Function, cfg: &ControlFlowGraph, loop_: &Loop) -> Result<(), ()> { + let mut unroll_into = get_pre_header(cfg, loop_); let mut jump_value = get_induction_variable(function, unroll_into)?; - while let Some(context) = unroll_loop_header(function, &loop_, unroll_into, jump_value) { + while let Some(context) = unroll_loop_header(function, loop_, unroll_into, jump_value) { let (last_block, last_value) = context.unroll_loop_iteration(); unroll_into = last_block; jump_value = last_value; @@ -192,28 +209,36 @@ fn unroll_loop_header<'a>( unroll_into: BasicBlockId, induction_value: ValueId, ) -> Option> { - let mut context = LoopIteration::new(function, loop_, unroll_into, loop_.header); + // We insert into a fresh block first 
and move instructions into the unroll_into block later + // only once we verify the jmpif instruction has a constant condition. If it does not, we can + // just discard this fresh block and leave the loop unmodified. + let fresh_block = function.dfg.make_block(); + + let mut context = LoopIteration::new(function, loop_, fresh_block, loop_.header); let source_block = &context.function.dfg[context.source_block]; assert_eq!(source_block.parameters().len(), 1, "Expected only 1 argument in loop header",); + // Insert the current value of the loop induction variable into our context. let first_param = source_block.parameters()[0]; context.values.insert(first_param, induction_value); context.inline_instructions_from_block(); - match context.function.dfg[unroll_into].unwrap_terminator() { + match context.function.dfg[fresh_block].unwrap_terminator() { TerminatorInstruction::JmpIf { condition, then_destination, else_destination } => { let next_blocks = context.handle_jmpif(*condition, *then_destination, *else_destination); // If there is only 1 next block the jmpif evaluated to a single known block. // This is the expected case and lets us know if we should loop again or not. if next_blocks.len() == 1 { + let mut instructions = std::mem::take(context.function.dfg[fresh_block].instructions_mut()); + let terminator = context.function.dfg[fresh_block].take_terminator(); + context.function.dfg[unroll_into].instructions_mut().append(&mut instructions); + context.function.dfg.set_block_terminator(unroll_into, terminator); + loop_.blocks.contains(&context.source_block).then_some(context) } else { - // Non-constant loop. We have to reset the then and else destination back to - // the original blocks here since we won't be unrolling into the new blocks. 
- context.function.dfg.get_block_terminator_mut(context.insert_block) - .mutate_blocks(|block| context.original_blocks[&block]); - + // If this case is reached the loop either uses non-constant indices or we need + // another pass, such as mem2reg to resolve them to constants. None } } From 1f47a3c7fb1f9aa7e1a4c9ca2b03152f48e9a1a2 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Mon, 22 May 2023 09:04:50 -0700 Subject: [PATCH 22/25] Address PR comments --- .../src/ssa_refactor/ir/instruction.rs | 25 +++++++++++-------- .../src/ssa_refactor/opt/unrolling.rs | 10 +++++--- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs index 0993a2ab748..7ca23c6f8a9 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs @@ -359,28 +359,26 @@ impl Binary { } BinaryOp::Mod => { if rhs_is_one { - return SimplifyResult::SimplifiedTo(self.lhs); + let zero = dfg.make_constant(FieldElement::zero(), operand_type); + return SimplifyResult::SimplifiedTo(zero); } } BinaryOp::Eq => { if self.lhs == self.rhs { - return SimplifyResult::SimplifiedTo( - dfg.make_constant(FieldElement::one(), Type::bool()), - ); + let one = dfg.make_constant(FieldElement::one(), Type::bool()); + return SimplifyResult::SimplifiedTo(one); } } BinaryOp::Lt => { if self.lhs == self.rhs { - return SimplifyResult::SimplifiedTo( - dfg.make_constant(FieldElement::zero(), Type::bool()), - ); + let zero = dfg.make_constant(FieldElement::zero(), Type::bool()); + return SimplifyResult::SimplifiedTo(zero); } } BinaryOp::And => { if lhs_is_zero || rhs_is_zero { - return SimplifyResult::SimplifiedTo( - dfg.make_constant(FieldElement::zero(), operand_type), - ); + let zero = dfg.make_constant(FieldElement::zero(), operand_type); + return SimplifyResult::SimplifiedTo(zero); } } BinaryOp::Or => { @@ -391,7 +389,12 @@ impl 
Binary { return SimplifyResult::SimplifiedTo(self.lhs); } } - BinaryOp::Xor => (), + BinaryOp::Xor => { + if self.lhs == self.rhs { + let zero = dfg.make_constant(FieldElement::zero(), Type::bool()); + return SimplifyResult::SimplifiedTo(zero); + } + } BinaryOp::Shl => { if rhs_is_zero { return SimplifyResult::SimplifiedTo(self.lhs); diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs index ee907a635a9..bdd65e30749 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs @@ -32,7 +32,7 @@ use crate::ssa_refactor::{ impl Ssa { /// Unroll all loops in each SSA function. - /// Panics if any loop cannot be unrolled. + /// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state. pub(crate) fn unroll_loops(mut self) -> Ssa { for function in self.functions.values_mut() { find_all_loops(function).unroll_each_loop(function); @@ -188,7 +188,11 @@ fn get_pre_header(cfg: &ControlFlowGraph, loop_: &Loop) -> BasicBlockId { fn get_induction_variable(function: &Function, block: BasicBlockId) -> Result { match function.dfg[block].terminator() { Some(TerminatorInstruction::Jmp { arguments, .. }) => { - assert_eq!(arguments.len(), 1); + // This assumption will no longer be valid if e.g. mutable variables are represented as + // block parameters. If that becomes the case we'll need to figure out which variable + // is generally constant and increasing to guess which parameter is the induction + // variable. 
+ assert_eq!(arguments.len(), 1, "It is expected that a loop's induction variable is the only block parameter of the loop header"); let value = arguments[0]; if function.dfg.get_numeric_constant(value).is_some() { Ok(value) @@ -242,7 +246,7 @@ fn unroll_loop_header<'a>( None } } - other => panic!("Expected loop header to terminate in a JmpIf to the loop body, but found {other:?} instead"), + other => unreachable!("Expected loop header to terminate in a JmpIf to the loop body, but found {other:?} instead"), } } From bf4fa228e9c96f9f5d3cd9ee2a539bc62f8d58f3 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Mon, 22 May 2023 13:01:35 -0700 Subject: [PATCH 23/25] Fix inlining bug and add a test for loops which fail to unroll --- .../src/ssa_refactor/ir/dfg.rs | 22 +-- .../src/ssa_refactor/opt/unrolling.rs | 130 +++++++++++++++--- 2 files changed, 123 insertions(+), 29 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs index 7d076752399..12d36c787ca 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs @@ -285,14 +285,6 @@ impl DataFlowGraph { self.blocks[block].set_terminator(terminator); } - /// Sets the terminator instruction for the given basic block - pub(crate) fn get_block_terminator_mut( - &mut self, - block: BasicBlockId, - ) -> &mut TerminatorInstruction { - self.blocks[block].unwrap_terminator_mut() - } - /// Replaces the value specified by the given ValueId with a new Value. /// /// This is the preferred method to call for optimizations simplifying @@ -301,6 +293,20 @@ impl DataFlowGraph { pub(crate) fn set_value(&mut self, value_id: ValueId, new_value: Value) { self.values[value_id] = new_value; } + + /// Moves the entirety of the given block's contents into the destination block. + /// The source block afterward will be left in a valid but emptied state. 
The + /// destination block will also have its terminator overwritten with that of the + /// source block. + pub(crate) fn inline_block(&mut self, source: BasicBlockId, destination: BasicBlockId) { + let source = &mut self.blocks[source]; + let mut instructions = std::mem::take(source.instructions_mut()); + let terminator = source.take_terminator(); + + let destination = &mut self.blocks[destination]; + destination.instructions_mut().append(&mut instructions); + destination.set_terminator(terminator); + } } impl std::ops::Index for DataFlowGraph { diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs index bdd65e30749..71df2db7472 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs @@ -107,7 +107,7 @@ impl Loops { // This happens any time we have nested loops. if next_loop.blocks.iter().any(|block| self.modified_blocks.contains(block)) { let mut new_context = find_all_loops(function); - new_context.failed_to_unroll = std::mem::take(&mut self.failed_to_unroll); + new_context.failed_to_unroll = self.failed_to_unroll; return new_context.unroll_each_loop(function); } @@ -154,7 +154,8 @@ fn find_blocks_in_loop( Loop { header, back_edge_start, blocks } } -/// Unroll a single loop in the function +/// Unroll a single loop in the function. +/// Returns Err(()) if it failed to unroll and Ok(()) otherwise. 
fn unroll_loop(function: &mut Function, cfg: &ControlFlowGraph, loop_: &Loop) -> Result<(), ()> { let mut unroll_into = get_pre_header(cfg, loop_); let mut jump_value = get_induction_variable(function, unroll_into)?; @@ -220,7 +221,7 @@ fn unroll_loop_header<'a>( let mut context = LoopIteration::new(function, loop_, fresh_block, loop_.header); let source_block = &context.function.dfg[context.source_block]; - assert_eq!(source_block.parameters().len(), 1, "Expected only 1 argument in loop header",); + assert_eq!(source_block.parameters().len(), 1, "Expected only 1 argument in loop header"); // Insert the current value of the loop induction variable into our context. let first_param = source_block.parameters()[0]; @@ -234,10 +235,11 @@ fn unroll_loop_header<'a>( // If there is only 1 next block the jmpif evaluated to a single known block. // This is the expected case and lets us know if we should loop again or not. if next_blocks.len() == 1 { - let mut instructions = std::mem::take(context.function.dfg[fresh_block].instructions_mut()); - let terminator = context.function.dfg[fresh_block].take_terminator(); - context.function.dfg[unroll_into].instructions_mut().append(&mut instructions); - context.function.dfg.set_block_terminator(unroll_into, terminator); + context.function.dfg.inline_block(fresh_block, unroll_into); + + // The fresh block is gone now so we're committing to insert into the original + // unroll_into block from now on. + context.insert_block = unroll_into; loop_.blocks.contains(&context.source_block).then_some(context) } else { @@ -255,7 +257,14 @@ fn unroll_loop_header<'a>( struct LoopIteration<'f> { function: &'f mut Function, loop_: &'f Loop, + + /// Maps pre-unrolled ValueIds to unrolled ValueIds. + /// These will often be the exact same as before, unless the ValueId was + /// dependent on the loop induction variable which is changing on each iteration. 
values: HashMap, + + /// Maps pre-unrolled block ids from within the loop to new block ids of each loop + /// block for each loop iteration. blocks: HashMap, /// Maps unrolled block ids back to the original source block ids @@ -267,6 +276,8 @@ struct LoopIteration<'f> { /// The induction value (and the block it was found in) is the new value for /// the variable traditionally called `i` on each iteration of the loop. + /// This is None until we visit the block which jumps back to the start of the + /// loop, at which point we record its value and the block it was found in. induction_value: Option<(BasicBlockId, ValueId)>, } @@ -360,6 +371,7 @@ impl<'f> LoopIteration<'f> { if constant.is_zero() { else_destination } else { then_destination }; self.source_block = self.get_original_block(destination); + let jmp = TerminatorInstruction::Jmp { destination, arguments: Vec::new() }; self.function.dfg.set_block_terminator(self.insert_block, jmp); vec![destination] @@ -368,6 +380,11 @@ impl<'f> LoopIteration<'f> { } } + /// Map a ValueId in the original pre-unrolled ssa to its new id in the unrolled SSA. + /// This is often the same ValueId as most values don't change while unrolling. The main + /// exception is instructions referencing the induction variable (or the variable itself) + /// which may have been simplified to another form. Block parameters or values outside the + /// loop shouldn't change at all and won't be present inside self.values. fn get_value(&self, value: ValueId) -> ValueId { self.values.get(&value).copied().unwrap_or(value) } @@ -470,6 +487,13 @@ mod tests { ssa_gen::Ssa, }; + // basic_blocks_iter iterates over unreachable blocks as well, so we must filter those out. 
+ fn count_reachable_blocks_in_main(ssa: &Ssa) -> usize { + let function = ssa.main(); + let dom_tree = DominatorTree::with_function(function); + function.dfg.basic_blocks_iter().filter(|(block, _)| dom_tree.is_reachable(*block)).count() + } + #[test] fn unroll_nested_loops() { // fn main() { @@ -557,23 +581,87 @@ mod tests { let v7 = builder.insert_binary(v0, BinaryOp::Add, one); builder.terminate_with_jmp(b1, vec![v7]); - // basic_blocks_iter iterates over unreachable blocks as well, so we must filter those out. - let count_reachable_blocks = |ssa: &Ssa| { - let function = ssa.main(); - let dom_tree = DominatorTree::with_function(function); - function - .dfg - .basic_blocks_iter() - .filter(|(block, _)| dom_tree.is_reachable(*block)) - .count() - }; - let ssa = builder.finish(); - assert_eq!(count_reachable_blocks(&ssa), 7); + assert_eq!(count_reachable_blocks_in_main(&ssa), 7); - // The final block count is not 1 because the block creates some unnecessary jmps. + // Expected output: + // + // fn main f0 { + // b0(): + // constrain Field 0 + // constrain Field 0 + // constrain Field 0 + // constrain Field 0 + // jmp b23() + // b23(): + // constrain Field 0 + // constrain Field 0 + // constrain Field 0 + // constrain Field 0 + // jmp b27() + // b27(): + // constrain Field 0 + // constrain Field 0 + // constrain Field 0 + // constrain Field 0 + // jmp b31() + // b31(): + // jmp b3() + // b3(): + // return Field 0 + // } + // The final block count is not 1 because unrolling creates some unnecessary jmps. // If a simplify cfg pass is ran afterward, the expected block count will be 1. 
let ssa = ssa.unroll_loops(); - assert_eq!(count_reachable_blocks(&ssa), 5); + assert_eq!(count_reachable_blocks_in_main(&ssa), 5); + } + + // Test that the pass can still be run on loops which fail to unroll properly + #[test] + fn fail_to_unroll_loop() { + // fn main f0 { + // b0(v0: Field): + // jmp b1(v0) + // b1(v1: Field): + // v2 = lt v1, 5 + // jmpif v2, then: b2, else: b3 + // b2(): + // v3 = add v1, Field 1 + // jmp b1(v3) + // b3(): + // return Field 0 + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("main".into(), main_id); + + let b1 = builder.insert_block(); + let b2 = builder.insert_block(); + let b3 = builder.insert_block(); + + let v0 = builder.add_parameter(Type::field()); + let v1 = builder.add_block_parameter(b1, Type::field()); + + builder.terminate_with_jmp(b1, vec![v0]); + + builder.switch_to_block(b1); + let five = builder.field_constant(5u128); + let v2 = builder.insert_binary(v1, BinaryOp::Lt, five); + builder.terminate_with_jmpif(v2, b2, b3); + + builder.switch_to_block(b2); + let one = builder.field_constant(1u128); + let v3 = builder.insert_binary(v1, BinaryOp::Add, one); + builder.terminate_with_jmp(b1, vec![v3]); + + builder.switch_to_block(b3); + let zero = builder.field_constant(0u128); + builder.terminate_with_return(vec![zero]); + + let ssa = builder.finish(); + assert_eq!(count_reachable_blocks_in_main(&ssa), 4); + + // Expected ssa is unchanged + let ssa = ssa.unroll_loops(); + assert_eq!(count_reachable_blocks_in_main(&ssa), 4); } } From c67728261954b7e02fab01a8e7971d6b41594d16 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Mon, 22 May 2023 13:07:29 -0700 Subject: [PATCH 24/25] Update simplify_cfg to use new inline_block method --- .../src/ssa_refactor/opt/simplify_cfg.rs | 23 ++++++------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs index 
7c91b5f0fe5..faf3d21a68b 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs @@ -102,25 +102,16 @@ fn remove_block_parameters( fn try_inline_into_predecessor( function: &mut Function, cfg: &mut ControlFlowGraph, - block_id: BasicBlockId, - predecessor_id: BasicBlockId, + block: BasicBlockId, + predecessor: BasicBlockId, ) -> bool { - let mut successors = cfg.successors(predecessor_id); - if successors.len() == 1 && successors.next() == Some(block_id) { + let mut successors = cfg.successors(predecessor); + if successors.len() == 1 && successors.next() == Some(block) { drop(successors); + function.dfg.inline_block(block, predecessor); - // First remove all the instructions and terminator from the block we're removing - let block = &mut function.dfg[block_id]; - let mut instructions = std::mem::take(block.instructions_mut()); - let terminator = block.take_terminator(); - - // Then append each to the predecessor - let predecessor = &mut function.dfg[predecessor_id]; - predecessor.instructions_mut().append(&mut instructions); - - predecessor.set_terminator(terminator); - cfg.recompute_block(function, block_id); - cfg.recompute_block(function, predecessor_id); + cfg.recompute_block(function, block); + cfg.recompute_block(function, predecessor); true } else { false From 83c4b5fdc4fdcd0edbcec92d91c7f976050b7e03 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Mon, 22 May 2023 13:15:32 -0700 Subject: [PATCH 25/25] Remove now-unneeded test helper function --- .../src/ssa_refactor/opt/unrolling.rs | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs index 71df2db7472..dba64dde6b4 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs @@ -482,18 +482,10 @@ impl<'f> 
LoopIteration<'f> { #[cfg(test)] mod tests { use crate::ssa_refactor::{ - ir::{dom::DominatorTree, instruction::BinaryOp, map::Id, types::Type}, + ir::{instruction::BinaryOp, map::Id, types::Type}, ssa_builder::FunctionBuilder, - ssa_gen::Ssa, }; - // basic_blocks_iter iterates over unreachable blocks as well, so we must filter those out. - fn count_reachable_blocks_in_main(ssa: &Ssa) -> usize { - let function = ssa.main(); - let dom_tree = DominatorTree::with_function(function); - function.dfg.basic_blocks_iter().filter(|(block, _)| dom_tree.is_reachable(*block)).count() - } - #[test] fn unroll_nested_loops() { // fn main() { @@ -582,7 +574,7 @@ mod tests { builder.terminate_with_jmp(b1, vec![v7]); let ssa = builder.finish(); - assert_eq!(count_reachable_blocks_in_main(&ssa), 7); + assert_eq!(ssa.main().reachable_blocks().len(), 7); // Expected output: // @@ -613,7 +605,7 @@ mod tests { // The final block count is not 1 because unrolling creates some unnecessary jmps. // If a simplify cfg pass is ran afterward, the expected block count will be 1. let ssa = ssa.unroll_loops(); - assert_eq!(count_reachable_blocks_in_main(&ssa), 5); + assert_eq!(ssa.main().reachable_blocks().len(), 5); } // Test that the pass can still be run on loops which fail to unroll properly @@ -658,10 +650,10 @@ mod tests { builder.terminate_with_return(vec![zero]); let ssa = builder.finish(); - assert_eq!(count_reachable_blocks_in_main(&ssa), 4); + assert_eq!(ssa.main().reachable_blocks().len(), 4); // Expected ssa is unchanged let ssa = ssa.unroll_loops(); - assert_eq!(count_reachable_blocks_in_main(&ssa), 4); + assert_eq!(ssa.main().reachable_blocks().len(), 4); } }