diff --git a/lucet-module/src/runtime.rs b/lucet-module/src/runtime.rs index 8b6b536d5..337ae41a6 100644 --- a/lucet-module/src/runtime.rs +++ b/lucet-module/src/runtime.rs @@ -11,16 +11,10 @@ pub struct InstanceRuntimeData { /// `instruction_count_adj` set to some negative value and /// `instruction_count_bound` adjusted upward in compensation. /// `instruction_count_adj` is incremented as execution proceeds; on each - /// increment, the Wasm code checks the carry flag. If the value crosses - /// zero (becomes positive), then we have exceeded the bound and we must - /// yield. At any point, the `adj` value can be adjusted downward by - /// transferring the count to the `bound`. - /// - /// Note that the bound-yield is only triggered if the `adj` value - /// transitions from negative to non-negative; in other words, it is - /// edge-triggered, not level-triggered. So entering code that has been - /// instrumented for instruction counting with `adj` >= 0 will result in no - /// bound ever triggered (until 2^64 instructions execute). + /// increment, the Wasm code checks the sign. If the value is zero or + /// greater, then we have reached the bound and we must yield. At any point, + /// the `adj` value can be adjusted downward by transferring the count to + /// the `bound`. 
pub instruction_count_adj: i64, pub instruction_count_bound: i64, pub stack_limit: u64, diff --git a/lucet-runtime/lucet-runtime-internals/src/future.rs b/lucet-runtime/lucet-runtime-internals/src/future.rs index d905f5724..6bbdb401b 100644 --- a/lucet-runtime/lucet-runtime-internals/src/future.rs +++ b/lucet-runtime/lucet-runtime-internals/src/future.rs @@ -1,5 +1,6 @@ use crate::error::Error; use crate::instance::{InstanceHandle, InternalRunResult, RunResult, State, TerminationDetails}; +use crate::module::FunctionHandle; use crate::val::{UntypedRetVal, Val}; use crate::vmctx::{Vmctx, VmctxInternal}; use std::any::Any; @@ -130,6 +131,45 @@ impl InstanceHandle { entrypoint: &'a str, args: &'a [Val], runtime_bound: Option, + ) -> Result { + let func = self.module.get_export_func(entrypoint)?; + self.run_async_internal(func, args, runtime_bound).await + } + + /// Run the module's [start function][start], if one exists. + /// + /// If there is no start function in the module, this does nothing. + /// + /// All of the other restrictions on the start function, what it may do, and + /// the requirement that it must be invoked first, are described in the + /// documentation for `Instance::run_start()`. This async version of that + /// function satisfies the requirement to run the start function first, as + /// long as the async function fully returns (not just yields). + /// + /// This method is similar to `Instance::run_start()`, except that it bounds + /// runtime between async future yields (invocations of `.poll()` on the + /// underlying generated future) if `runtime_bound` is provided. This + /// behaves the same way as `Instance::run_async()`. + pub async fn run_async_start<'a>( + &'a mut self, + runtime_bound: Option, + ) -> Result<(), Error> { + if let Some(start) = self.module.get_start_func()? 
{ + if !self.is_not_started() { + return Err(Error::StartAlreadyRun); + } + self.run_async_internal(start, &[], runtime_bound).await?; + } + Ok(()) + } + + /// Shared async run-loop implementation for both `run_async()` and + /// `run_async_start()`. + async fn run_async_internal<'a>( + &'a mut self, + func: FunctionHandle, + args: &'a [Val], + runtime_bound: Option, ) -> Result { if self.is_yielded() { return Err(Error::Unsupported( @@ -158,7 +198,6 @@ impl InstanceHandle { ) } else { // This is the first iteration, call the entrypoint: - let func = self.module.get_export_func(entrypoint)?; self.run_func(func, args, true, runtime_bound) }; match run_result? { diff --git a/lucet-runtime/lucet-runtime-internals/src/instance.rs b/lucet-runtime/lucet-runtime-internals/src/instance.rs index f0bbdc6f1..f9523bb63 100644 --- a/lucet-runtime/lucet-runtime-internals/src/instance.rs +++ b/lucet-runtime/lucet-runtime-internals/src/instance.rs @@ -427,12 +427,10 @@ pub(crate) enum InternalRunResult { } impl InternalRunResult { - pub(crate) fn unwrap(self) -> Result { + pub(crate) fn unwrap(self) -> RunResult { match self { - InternalRunResult::Normal(result) => Ok(result), - InternalRunResult::BoundExpired => Err(Error::InvalidArgument( - "should not have had a runtime bound", - )), + InternalRunResult::Normal(result) => result, + InternalRunResult::BoundExpired => panic!("should not have had a runtime bound"), } } } @@ -517,7 +515,7 @@ impl Instance { /// in the future. 
pub fn run(&mut self, entrypoint: &str, args: &[Val]) -> Result { let func = self.module.get_export_func(entrypoint)?; - self.run_func(func, &args, false, None)?.unwrap() + Ok(self.run_func(func, &args, false, None)?.unwrap()) } /// Run a function with arguments in the guest context from the [WebAssembly function @@ -533,7 +531,7 @@ impl Instance { args: &[Val], ) -> Result { let func = self.module.get_func_from_idx(table_idx, func_idx)?; - self.run_func(func, &args, false, None)?.unwrap() + Ok(self.run_func(func, &args, false, None)?.unwrap()) } /// Resume execution of an instance that has yielded without providing a value to the guest. @@ -564,7 +562,7 @@ impl Instance { /// The foreign code safety caveat of [`Instance::run()`](struct.Instance.html#method.run) /// applies. pub fn resume_with_val(&mut self, val: A) -> Result { - self.resume_with_val_impl(val, false, None)?.unwrap() + Ok(self.resume_with_val_impl(val, false, None)?.unwrap()) } pub(crate) fn resume_with_val_impl( @@ -587,7 +585,7 @@ impl Instance { self.resumed_val = Some(Box::new(val) as Box); - self.set_instruction_bound_delta(max_insn_count.unwrap_or(0)); + self.set_instruction_bound_delta(max_insn_count); self.swap_and_return(async_context) } @@ -611,7 +609,7 @@ impl Instance { "can only call resume_bounded() on an instance that hit an instruction bound", )); } - self.set_instruction_bound_delta(max_insn_count); + self.set_instruction_bound_delta(Some(max_insn_count)); self.swap_and_return(true) } @@ -650,10 +648,7 @@ impl Instance { if !self.is_not_started() { return Err(Error::StartAlreadyRun); } - let res = self.run_func(start, &[], false, None)?.unwrap()?; - if res.is_yielded() { - return Err(Error::StartYielded); - } + self.run_func(start, &[], false, None)?; } Ok(()) } @@ -958,11 +953,12 @@ impl Instance { /// value is *crossed*, but not if execution *begins* with the value exceeded. Hence `delta` /// must be greater than zero for this to set up the instance state to trigger a yield. 
#[inline] - pub fn set_instruction_bound_delta(&mut self, delta: u64) { + pub fn set_instruction_bound_delta(&mut self, delta: Option) { let implicits = self.get_instance_implicits_mut(); let sum = implicits.instruction_count_adj + implicits.instruction_count_bound; + let delta = delta.unwrap_or(i64::MAX as u64); let delta = i64::try_from(delta).expect("delta too large"); - implicits.instruction_count_bound = sum + delta; + implicits.instruction_count_bound = sum.wrapping_add(delta); implicits.instruction_count_adj = -delta; } @@ -1140,7 +1136,7 @@ impl Instance { let mut args_with_vmctx = vec![Val::from(self.alloc.slot().heap)]; args_with_vmctx.extend_from_slice(args); - self.set_instruction_bound_delta(inst_count_bound.unwrap_or(0)); + self.set_instruction_bound_delta(inst_count_bound); let self_ptr = self as *mut _; Context::init_with_callback( diff --git a/lucet-runtime/tests/instruction_counting.rs b/lucet-runtime/tests/instruction_counting.rs index d712e182f..4a5e8d267 100644 --- a/lucet-runtime/tests/instruction_counting.rs +++ b/lucet-runtime/tests/instruction_counting.rs @@ -26,7 +26,7 @@ pub fn wasm_test>( Ok(dlmodule) } -pub fn get_instruction_count_test_files() -> Vec { +pub fn get_instruction_count_test_files(want_start_function: bool) -> Vec { std::fs::read_dir("./tests/instruction_counting") .expect("can iterate test files") .map(|ent| { @@ -37,12 +37,14 @@ pub fn get_instruction_count_test_files() -> Vec { ); ent }) + .filter(|ent| want_start_function == ent.path().to_str().unwrap().contains("_start.wat")) .collect() } #[test] pub fn check_instruction_count_off() { - let files: Vec = get_instruction_count_test_files(); + let files: Vec = + get_instruction_count_test_files(/* want_start_function = */ false); assert!( !files.is_empty(), @@ -71,7 +73,8 @@ pub fn check_instruction_count_off() { #[test] pub fn check_instruction_count() { - let files: Vec = get_instruction_count_test_files(); + let files: Vec = + get_instruction_count_test_files(/* 
want_start_function = */ false); assert!( !files.is_empty(), @@ -136,7 +139,16 @@ fn dummy_waker() -> Waker { #[test] pub fn check_instruction_count_with_periodic_yields() { - let files: Vec = get_instruction_count_test_files(); + check_instruction_count_with_periodic_yields_internal(/* want_start_function = */ false); +} + +#[test] +pub fn check_instruction_count_with_periodic_yields_start_func() { + check_instruction_count_with_periodic_yields_internal(/* want_start_function = */ true); +} + +fn check_instruction_count_with_periodic_yields_internal(want_start_function: bool) { + let files: Vec = get_instruction_count_test_files(want_start_function); assert!( !files.is_empty(), @@ -154,15 +166,13 @@ pub fn check_instruction_count_with_periodic_yields() { .new_instance(module) .expect("instance can be created"); - let yields = { + fn future_loop(mut future: std::pin::Pin>) -> u64 { let mut yields = 0; - let mut future = Box::pin(inst.run_async("test_function", &[], Some(1000))); let waker = dummy_waker(); let mut context = Context::from_waker(&waker); loop { match future.as_mut().poll(&mut context) { - Poll::Ready(val) => { - val.expect("instance runs"); + Poll::Ready(_) => { break; } Poll::Pending => { @@ -174,6 +184,14 @@ pub fn check_instruction_count_with_periodic_yields() { } } yields + } + + let yields = if want_start_function { + let future = Box::pin(inst.run_async_start(Some(1000))); + future_loop(future) + } else { + let future = Box::pin(inst.run_async("test_function", &[], Some(1000))); + future_loop(future) }; let instruction_count = inst diff --git a/lucet-runtime/tests/instruction_counting/long_loop_start.wat b/lucet-runtime/tests/instruction_counting/long_loop_start.wat new file mode 100644 index 000000000..347d53c58 --- /dev/null +++ b/lucet-runtime/tests/instruction_counting/long_loop_start.wat @@ -0,0 +1,17 @@ +(module + (start $start) + (func $start (export "_start") (local i32) + loop + local.get 0 + i32.const 1 + i32.add + local.tee 0 + 
i32.const 10000 + i32.ne + br_if 0 + end + ) + (func $instruction_count (export "instruction_count") (result i64) + i64.const 70000 + ) +) diff --git a/lucetc/src/compiler.rs b/lucetc/src/compiler.rs index 7b42c5ed5..4bcd8e5f8 100644 --- a/lucetc/src/compiler.rs +++ b/lucetc/src/compiler.rs @@ -341,6 +341,22 @@ impl<'a> Compiler<'a> { self.decls.get_module_data(self.module_features()) } + fn get_local_count(body: &FunctionBody, name: &str) -> Result { + let error_mapper = |e| Error::FunctionTranslation { + symbol: name.to_string(), + source: Box::new(Error::from(e)), + }; + Ok(body + .get_locals_reader() + .map_err(error_mapper)? + .into_iter() + .map(|result| result.map_err(error_mapper)) + .collect::, Error>>()? + .into_iter() + .map(|(count, _ty)| count) + .sum()) + } + pub fn object_file(self) -> Result { let mut function_manifest_ctx = ClifDataContext::new(); let mut function_manifest_bytes = Cursor::new(Vec::new()); @@ -359,7 +375,15 @@ impl<'a> Compiler<'a> { let func = decls .get_func(unique_func_ix) .expect("decl exists for func body"); - let mut func_info = FuncInfo::new(&decls, &codegen_context, count_instructions); + let arg_count = func.signature.params.len() as u32; + let local_count = Self::get_local_count(&func_body, func.name.symbol())?; + let mut func_info = FuncInfo::new( + &decls, + &codegen_context, + count_instructions, + arg_count, + local_count, + ); let mut clif_context = ClifContext::new(); clif_context.func.name = func.name.as_externalname(); clif_context.func.signature = func.signature.clone(); @@ -536,8 +560,15 @@ impl<'a> Compiler<'a> { .decls .get_func(unique_func_ix) .expect("decl exists for func body"); - let mut func_info = - FuncInfo::new(&self.decls, &self.codegen_context, self.count_instructions); + let arg_count = func.signature.params.len() as u32; + let local_count = Self::get_local_count(&body, func.name.symbol())?; + let mut func_info = FuncInfo::new( + &self.decls, + &self.codegen_context, + self.count_instructions, + 
arg_count, + local_count, + ); let mut clif_context = ClifContext::new(); clif_context.func.name = func.name.as_externalname(); clif_context.func.signature = func.signature.clone(); diff --git a/lucetc/src/error.rs b/lucetc/src/error.rs index 5bff927de..02fe39cd5 100644 --- a/lucetc/src/error.rs +++ b/lucetc/src/error.rs @@ -1,6 +1,7 @@ use crate::types::SignatureError; use crate::validate::Error as ValidationError; use cranelift_module::ModuleError as ClifModuleError; +use cranelift_wasm::wasmparser::BinaryReaderError as ClifWasmReaderError; use cranelift_wasm::WasmError as ClifWasmError; use lucet_module::error::Error as LucetModuleError; use object; @@ -27,6 +28,8 @@ pub enum Error { MissingWasmPreamble, #[error("Wasm validation: {0}")] WasmValidation(#[from] wasmparser::BinaryReaderError), + #[error("Wasm validation: {0}")] + ClifWasmValidation(#[from] ClifWasmReaderError), #[error("Wat input: {0}")] WatInput(#[from] wabt::Error), #[error("Object artifact: {1}. {0:?}")] diff --git a/lucetc/src/function.rs b/lucetc/src/function.rs index 2da784854..18805bf16 100644 --- a/lucetc/src/function.rs +++ b/lucetc/src/function.rs @@ -8,7 +8,7 @@ use cranelift_codegen::cursor::FuncCursor; use cranelift_codegen::entity::EntityRef; use cranelift_codegen::ir::{self, condcodes::IntCC, InstBuilder}; use cranelift_codegen::isa::TargetFrontendConfig; -use cranelift_frontend::FunctionBuilder; +use cranelift_frontend::{FunctionBuilder, Variable}; use cranelift_module::{Linkage, Module as ClifModule, ModuleError as ClifModuleError}; use cranelift_wasm::{ wasmparser::Operator, FuncEnvironment, FuncIndex, FuncTranslationState, GlobalIndex, @@ -22,10 +22,16 @@ pub struct FuncInfo<'a> { module_decls: &'a ModuleDecls<'a>, codegen_context: &'a CodegenContext, count_instructions: bool, - scope_costs: Vec, + scope_costs: Vec, vmctx_value: Option, global_base_value: Option, runtime_funcs: HashMap, + instr_count_var: Variable, +} + +struct ScopeInfo { + cost: u32, + is_loop: bool, } impl<'a> 
FuncInfo<'a> { @@ -33,15 +39,24 @@ impl<'a> FuncInfo<'a> { module_decls: &'a ModuleDecls<'a>, codegen_context: &'a CodegenContext, count_instructions: bool, + arg_count: u32, + local_count: u32, ) -> Self { Self { module_decls, codegen_context, count_instructions, - scope_costs: vec![0], + scope_costs: vec![ScopeInfo { + cost: 0, + is_loop: false, + }], vmctx_value: None, global_base_value: None, runtime_funcs: HashMap::new(), + // variable indices correspond to Wasm bytecode's index space, + // so we designate a new one after all the Wasm locals to hold + // the instruction count. + instr_count_var: Variable::with_u32(arg_count + local_count), } } @@ -89,7 +104,51 @@ impl<'a> FuncInfo<'a> { }) } - fn update_instruction_count_instrumentation( + fn get_instr_count_addr_offset( + &mut self, + builder: &mut FunctionBuilder<'_>, + ) -> (ir::Value, ir::immediates::Offset32) { + let instr_count_offset: ir::immediates::Offset32 = + (-(std::mem::size_of::() as i32) + + offset_of!(InstanceRuntimeData, instruction_count_adj) as i32) + .into(); + let vmctx_gv = self.get_vmctx(builder.func); + let addr = builder.ins().global_value(self.pointer_type(), vmctx_gv); + (addr, instr_count_offset) + } + + fn load_instr_count(&mut self, builder: &mut FunctionBuilder<'_>) { + let (addr, instr_count_offset) = self.get_instr_count_addr_offset(builder); + let trusted_mem = ir::MemFlags::trusted(); + + // Do the equivalent of: + // + // let instruction_count_ptr: &mut i64 = vmctx.instruction_count; + // let instruction_count: i64 = *instruction_count_ptr; + // vars[instr_count] = instruction_count; + let cur_instr_count = + builder + .ins() + .load(ir::types::I64, trusted_mem, addr, instr_count_offset); + builder.def_var(self.instr_count_var, cur_instr_count); + } + + fn save_instr_count(&mut self, builder: &mut FunctionBuilder<'_>) { + let (addr, instr_count_offset) = self.get_instr_count_addr_offset(builder); + let trusted_mem = ir::MemFlags::trusted(); + + // Do the equivalent of: + // 
+ // let instruction_count_ptr: &mut i64 = vmctx.instruction_count; + // let instruction_count = vars[instr_count]; + // *instruction_count_ptr = instruction_count; + let new_instr_count = builder.use_var(self.instr_count_var); + builder + .ins() + .store(trusted_mem, new_instr_count, addr, instr_count_offset); + } + + fn update_instruction_count_instrumentation_pre( &mut self, op: &Operator<'_>, builder: &mut FunctionBuilder<'_>, @@ -121,60 +180,63 @@ impl<'a> FuncInfo<'a> { // anyway (exception dispatch is much more expensive than a single wasm op) // * Return corresponds to exactly one function call, so we can count it by resetting the // stack to 1 at return of a function. + // + // We keep a cache of the counter in the pinned register. We load it in the prologue, save + // it in the epilogue, and save and reload it around calls. (We could alter our ABI to + // preserve the pinned reg across calls within Wasm, and save and reload it only around + // hostcalls, as long as we could load and save it in a trampoline wrapping the initial + // Wasm entry. We haven't yet done this.) /// Flush the currently-accumulated instruction count to the counter in the instance data, /// invoking the yield hostcall if we hit the bound. 
fn flush_counter(environ: &mut FuncInfo<'_>, builder: &mut FunctionBuilder<'_>) { - if environ.scope_costs.last() == Some(&0) { - return; + match environ.scope_costs.last() { + Some(info) if info.cost == 0 => return, + _ => {} } - let instr_count_offset: ir::immediates::Offset32 = - (-(std::mem::size_of::() as i32) - + offset_of!(InstanceRuntimeData, instruction_count_adj) as i32) - .into(); - let vmctx_gv = environ.get_vmctx(builder.func); - let addr = builder.ins().global_value(environ.pointer_type(), vmctx_gv); - let trusted_mem = ir::MemFlags::trusted(); // Now insert a sequence of clif that is, functionally: // - // let instruction_count_ptr: &mut u64 = vmctx.instruction_count; - // let mut instruction_count: u64 = *instruction_count_ptr; + // let mut instruction_count = vars[instr_count]; // instruction_count += ; - // *instruction_count_ptr = instruction_count; + // vars[instr_count] = instruction_count; - let cur_instr_count = - builder - .ins() - .load(ir::types::I64, trusted_mem, addr, instr_count_offset); + let cur_instr_count = builder.use_var(environ.instr_count_var); let update_const = builder.ins().iconst( ir::types::I64, - i64::from(*environ.scope_costs.last().unwrap()), + i64::from(environ.scope_costs.last().unwrap().cost), ); - let (new_instr_count, flags) = builder.ins().iadd_ifcout(cur_instr_count, update_const); - builder - .ins() - .store(trusted_mem, new_instr_count, addr, instr_count_offset); + let new_instr_count = builder.ins().iadd(cur_instr_count, update_const); + builder.def_var(environ.instr_count_var, new_instr_count); + environ.scope_costs.last_mut().unwrap().cost = 0; + }; + fn do_check(environ: &mut FuncInfo<'_>, builder: &mut FunctionBuilder<'_>) { let yield_block = builder.create_block(); let continuation_block = builder.create_block(); + // If `adj` is non-negative (>= 0), branch to yield block. 
+ let zero = builder.ins().iconst(ir::types::I64, 0); + let new_instr_count = builder.use_var(environ.instr_count_var); + let cmp = builder.ins().ifcmp(new_instr_count, zero); builder .ins() - .brif(IntCC::UnsignedLessThan, flags, yield_block, &[]); + .brif(IntCC::SignedGreaterThanOrEqual, cmp, yield_block, &[]); builder.ins().jump(continuation_block, &[]); builder.seal_block(yield_block); builder.switch_to_block(yield_block); + environ.save_instr_count(builder); let yield_hostcall = environ.get_runtime_func(RuntimeFunc::YieldAtBoundExpiration, &mut builder.func); + let vmctx_gv = environ.get_vmctx(builder.func); + let addr = builder.ins().global_value(environ.pointer_type(), vmctx_gv); builder.ins().call(yield_hostcall, &[addr]); + environ.load_instr_count(builder); builder.ins().jump(continuation_block, &[]); builder.seal_block(continuation_block); builder.switch_to_block(continuation_block); - - *environ.scope_costs.last_mut().unwrap() = 0; - }; + } // Only update or flush the counter when the scope is not sealed. // @@ -216,7 +278,9 @@ impl<'a> FuncInfo<'a> { // everything else, just call it one operation. _ => 1, }; - self.scope_costs.last_mut().map(|x| *x += op_cost); + self.scope_costs + .last_mut() + .map(|ref mut info| info.cost += op_cost); // apply flushing behavior if applicable match op { @@ -231,7 +295,22 @@ impl<'a> FuncInfo<'a> { | Operator::Br { .. } | Operator::BrIf { .. } | Operator::BrTable { .. } => { + let do_check_and_save = match op { + Operator::Call { .. } + | Operator::CallIndirect { .. 
} + | Operator::Return => true, + Operator::Br { relative_depth } | Operator::BrIf { relative_depth } => { + // only if loop backedge + self.scope_costs[self.scope_costs.len() - 1 - *relative_depth as usize] + .is_loop + } + _ => false, + }; flush_counter(self, builder); + if do_check_and_save { + self.save_instr_count(builder); + do_check(self, builder); + } } Operator::End => { // We have to be really careful here to avoid violating a cranelift invariant: @@ -269,7 +348,7 @@ impl<'a> FuncInfo<'a> { // we shouldn't (because they're unreachable), or we didn't flush the counter before // starting to also instrument unreachable instructions (and would have tried to // overcount) - assert_eq!(*self.scope_costs.last().unwrap(), 0); + assert_eq!(self.scope_costs.last().unwrap().cost, 0); } // finally, we might have to set up a new counter for a new scope, or fix up counts a bit. @@ -283,12 +362,18 @@ impl<'a> FuncInfo<'a> { // reachable, the "called" function won't return! if reachable { // add 1 to count the return from the called function - self.scope_costs.last_mut().map(|x| *x = 1); + self.scope_costs + .last_mut() + .map(|ref mut info| info.cost += 1); } } Operator::Block { .. } | Operator::Loop { .. } | Operator::If { .. } => { // opening a scope, which starts having executed zero wasm ops - self.scope_costs.push(0); + let is_loop = match op { + Operator::Loop { .. } => true, + _ => false, + }; + self.scope_costs.push(ScopeInfo { cost: 0, is_loop }); } Operator::End => { // close the top scope @@ -298,6 +383,45 @@ impl<'a> FuncInfo<'a> { } Ok(()) } + + fn update_instruction_count_instrumentation_post( + &mut self, + op: &Operator<'_>, + builder: &mut FunctionBuilder<'_>, + reachable: bool, + ) -> WasmResult<()> { + // Handle reloads after calls. + let is_call = match op { + Operator::Call { .. } | Operator::CallIndirect { .. 
} => true, + _ => false, + }; + if reachable && is_call { + self.load_instr_count(builder); + } + Ok(()) + } + + fn update_instruction_count_instrumentation_before_func( + &mut self, + builder: &mut FunctionBuilder<'_>, + ) -> WasmResult<()> { + if self.count_instructions { + builder.declare_var(self.instr_count_var, ir::types::I64); + self.load_instr_count(builder); + } + Ok(()) + } + + fn update_instruction_count_instrumentation_after_func( + &mut self, + builder: &mut FunctionBuilder<'_>, + reachable: bool, + ) -> WasmResult<()> { + if reachable { + self.save_instr_count(builder); + } + Ok(()) + } } /// Get the local trampoline function to do safety checks before calling an imported hostcall. @@ -761,7 +885,41 @@ impl<'a> FuncEnvironment for FuncInfo<'a> { state: &FuncTranslationState, ) -> WasmResult<()> { if self.count_instructions { - self.update_instruction_count_instrumentation(op, builder, state.reachable())?; + self.update_instruction_count_instrumentation_pre(op, builder, state.reachable())?; + } + Ok(()) + } + + fn after_translate_operator( + &mut self, + op: &Operator<'_>, + builder: &mut FunctionBuilder<'_>, + state: &FuncTranslationState, + ) -> WasmResult<()> { + if self.count_instructions { + self.update_instruction_count_instrumentation_post(op, builder, state.reachable())?; + } + Ok(()) + } + + fn before_translate_function( + &mut self, + builder: &mut FunctionBuilder<'_>, + _state: &FuncTranslationState, + ) -> WasmResult<()> { + if self.count_instructions { + self.update_instruction_count_instrumentation_before_func(builder)?; + } + Ok(()) + } + + fn after_translate_function( + &mut self, + builder: &mut FunctionBuilder<'_>, + state: &FuncTranslationState, + ) -> WasmResult<()> { + if self.count_instructions { + self.update_instruction_count_instrumentation_after_func(builder, state.reachable())?; } Ok(()) }