Skip to content

Commit

Permalink
refactor!: remove terminal_state, allow running a VM state instead
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Dec 4, 2023
1 parent 08bbc41 commit fbd58f1
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 109 deletions.
75 changes: 29 additions & 46 deletions triton-vm/src/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,8 @@ impl<'a> Arbitrary<'a> for InstructionLabel {

#[cfg(test)]
mod tests {
use assert2::assert;
use assert2::let_assert;
use std::collections::HashMap;

use itertools::Itertools;
Expand All @@ -669,6 +671,7 @@ mod tests {
use crate::triton_asm;
use crate::triton_program;
use crate::vm::tests::test_program_for_call_recurse_return;
use crate::vm::VMState;
use crate::NonDeterminism;
use crate::Program;

Expand Down Expand Up @@ -733,7 +736,7 @@ mod tests {
for instruction in Instruction::iter() {
let expected_opcode = instruction.computed_opcode();
let opcode = instruction.opcode();
assert_eq!(expected_opcode, opcode, "{instruction}");
assert!(expected_opcode == opcode, "{instruction}");
}
}

Expand All @@ -754,10 +757,7 @@ mod tests {
let all_opcodes = Instruction::iter().map(|instruction| instruction.opcode());
let highest_opcode = all_opcodes.max().unwrap();
let num_required_bits_for_highest_opcode = highest_opcode.ilog2() + 1;
assert_eq!(
InstructionBit::COUNT,
num_required_bits_for_highest_opcode as usize
);
assert!(InstructionBit::COUNT == num_required_bits_for_highest_opcode as usize);
}

#[test]
Expand All @@ -771,7 +771,7 @@ mod tests {
Pop(N2),
];

assert_eq!(expected, instructions);
assert!(expected == instructions);
}

#[test]
Expand Down Expand Up @@ -800,7 +800,7 @@ mod tests {
#[test]
fn instruction_to_opcode_to_instruction_is_consistent() {
for instr in ALL_INSTRUCTIONS {
assert_eq!(instr, instr.opcode().try_into().unwrap());
assert!(instr == instr.opcode().try_into().unwrap());
}
}

Expand Down Expand Up @@ -852,8 +852,8 @@ mod tests {
fn instruction_size_is_consistent_with_having_arguments() {
for instruction in Instruction::iter() {
match instruction.has_arg() {
true => assert_eq!(2, instruction.size()),
false => assert_eq!(1, instruction.size()),
true => assert!(2 == instruction.size()),
false => assert!(1 == instruction.size()),
}
}
}
Expand All @@ -864,10 +864,7 @@ mod tests {
for instruction in Instruction::iter() {
let opcode = instruction.opcode();
println!("Testing instruction {instruction} with opcode {opcode}.");
assert_eq!(
instruction.has_arg(),
opcode & argument_indicator_bit_mask != 0
);
assert!(instruction.has_arg() == (opcode & argument_indicator_bit_mask != 0));
}
}

Expand All @@ -878,9 +875,8 @@ mod tests {
let instruction = instruction.replace_default_argument_if_illegal();
let opcode = instruction.opcode();
println!("Testing instruction {instruction} with opcode {opcode}.");
assert_eq!(
instruction.shrinks_op_stack(),
opcode & shrink_stack_indicator_bit_mask != 0
assert!(
instruction.shrinks_op_stack() == (opcode & shrink_stack_indicator_bit_mask != 0)
);
}
}
Expand All @@ -891,10 +887,7 @@ mod tests {
for instruction in Instruction::iter() {
let opcode = instruction.opcode();
println!("Testing instruction {instruction} with opcode {opcode}.");
assert_eq!(
instruction.is_u32_instruction(),
opcode & u32_indicator_bit_mask != 0
);
assert!(instruction.is_u32_instruction() == (opcode & u32_indicator_bit_mask != 0));
}
}

Expand All @@ -904,7 +897,7 @@ mod tests {
println!("Testing instruction bit {instruction_bit}.");
let bit_index = usize::from(instruction_bit);
let recovered_instruction_bit = InstructionBit::try_from(bit_index).unwrap();
assert_eq!(instruction_bit, recovered_instruction_bit);
assert!(instruction_bit == recovered_instruction_bit);
}
}

Expand Down Expand Up @@ -932,9 +925,8 @@ mod tests {

let stack_size_difference = (stack_size_after_test_instruction as i32)
- (stack_size_before_test_instruction as i32);
assert_eq!(
test_instruction.op_stack_size_influence(),
stack_size_difference,
assert!(
test_instruction.op_stack_size_influence() == stack_size_difference,
"{test_instruction}"
);
}
Expand Down Expand Up @@ -977,42 +969,33 @@ mod tests {
let non_determinism: NonDeterminism<_> = vec![BFIELD_ZERO].into();
let non_determinism = non_determinism.with_digests(mock_digests);

let terminal_state = program
.terminal_state(public_input, non_determinism)
.unwrap();
terminal_state.op_stack.stack.len()
let mut vm_state = VMState::new(&program, public_input, non_determinism);
let_assert!(Ok(()) = vm_state.run());
vm_state.op_stack.stack.len()
}

#[test]
fn labelled_instructions_act_on_op_stack_as_indicated() {
for test_instruction in all_instructions_without_args() {
let labelled_instruction =
test_instruction.map_call_address(|_| "dummy_label".to_string());
for instruction in all_instructions_without_args() {
let labelled_instruction = instruction.map_call_address(|_| "dummy_label".to_string());
let labelled_instruction = LabelledInstruction::Instruction(labelled_instruction);

assert_eq!(
test_instruction.op_stack_size_influence(),
labelled_instruction.op_stack_size_influence()
);
assert_eq!(
test_instruction.grows_op_stack(),
labelled_instruction.grows_op_stack()
);
assert_eq!(
test_instruction.changes_op_stack_size(),
labelled_instruction.changes_op_stack_size()
assert!(
instruction.op_stack_size_influence()
== labelled_instruction.op_stack_size_influence()
);
assert_eq!(
test_instruction.shrinks_op_stack(),
labelled_instruction.shrinks_op_stack()
assert!(instruction.grows_op_stack() == labelled_instruction.grows_op_stack());
assert!(
instruction.changes_op_stack_size() == labelled_instruction.changes_op_stack_size()
);
assert!(instruction.shrinks_op_stack() == labelled_instruction.shrinks_op_stack());
}
}

#[test]
fn labels_indicate_no_change_to_op_stack() {
let labelled_instruction = LabelledInstruction::Label("dummy_label".to_string());
assert_eq!(0, labelled_instruction.op_stack_size_influence());
assert!(0 == labelled_instruction.op_stack_size_influence());
assert!(!labelled_instruction.grows_op_stack());
assert!(!labelled_instruction.changes_op_stack_size());
assert!(!labelled_instruction.shrinks_op_stack());
Expand Down
77 changes: 25 additions & 52 deletions triton-vm/src/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,42 +395,20 @@ impl Program {
/// If an error is encountered, the returned [`VMError`] contains the [`VMState`] at the point
/// of execution failure.
///
/// See also [`trace_execution`][trace_execution], [`terminal_state`][terminal_state], and
/// [`profile`][profile].
/// See also [`trace_execution`][trace_execution] and [`profile`][profile].
///
/// [trace_execution]: Self::trace_execution
/// [terminal_state]: Self::terminal_state
/// [profile]: Self::profile
pub fn run(
&self,
public_input: PublicInput,
non_determinism: NonDeterminism<BFieldElement>,
) -> Result<Vec<BFieldElement>> {
let terminal_state = self.terminal_state(public_input, non_determinism)?;
Ok(terminal_state.public_output)
}

/// Similar to [`run`][run], but returns the entire [`VMState`] instead of just the public
/// output.
///
/// See also [`trace_execution`][trace_execution] and [`profile`][profile].
///
/// [run]: Self::run
/// [trace_execution]: Self::trace_execution
/// [profile]: Self::profile
pub fn terminal_state(
&self,
public_input: PublicInput,
non_determinism: NonDeterminism<BFieldElement>,
) -> Result<VMState> {
let mut state = VMState::new(self, public_input, non_determinism);
while !state.halting {
let maybe_error = state.step();
if let Err(err) = maybe_error {
return Err(VMError::new(err, state));
}
if let Err(err) = state.run() {
return Err(VMError::new(err, state));
}
Ok(state)
Ok(state.public_output)
}

/// Trace the execution of a [`Program`]. That is, [`run`][run] the [`Program`] and additionally
Expand All @@ -439,11 +417,9 @@ impl Program {
/// 1. an [`AlgebraicExecutionTrace`], and
/// 1. the output of the program.
///
/// See also [`run`][run], [`terminal_state`][terminal_state], and
/// [`profile`][profile].
/// See also [`run`][run] and [`profile`][profile].
///
/// [run]: Self::run
/// [terminal_state]: Self::terminal_state
/// [profile]: Self::profile
pub fn trace_execution(
&self,
Expand Down Expand Up @@ -474,12 +450,10 @@ impl Program {
/// in each callable block of instructions. This function returns a Result wrapping a program
/// profiler report, which is a Vec of [`ProfileLine`]s.
///
/// See also [`run`][run], [`trace_execution`][trace_execution], and
/// [`terminal_state`][terminal_state].
/// See also [`run`][run] and [`trace_execution`][trace_execution].
///
/// [run]: Self::run
/// [trace_execution]: Self::trace_execution
/// [terminal_state]: Self::terminal_state
pub fn profile(
&self,
public_input: PublicInput,
Expand Down Expand Up @@ -786,6 +760,7 @@ where

#[cfg(test)]
mod tests {
use assert2::assert;
use assert2::let_assert;
use itertools::Itertools;
use proptest::prelude::*;
Expand Down Expand Up @@ -861,7 +836,7 @@ mod tests {
.map(BFieldElement::new);
let expected_digest = Digest::new(expected_digest);

assert_eq!(expected_digest, digest);
assert!(expected_digest == digest);
}

#[test]
Expand All @@ -875,31 +850,31 @@ mod tests {
let tokens = thread_rng().gen::<[BFieldElement; 12]>().to_vec();
let public_input = PublicInput::new(tokens.clone());

assert_eq!(public_input, tokens.clone().into());
assert_eq!(public_input, (&tokens).into());
assert_eq!(public_input, tokens[..].into());
assert_eq!(public_input, (&tokens[..]).into());
assert!(public_input == tokens.clone().into());
assert!(public_input == (&tokens).into());
assert!(public_input == tokens[..].into());
assert!(public_input == (&tokens[..]).into());

let tokens = tokens.into_iter().map(|e| e.value()).collect_vec();
assert_eq!(public_input, tokens.into());
assert!(public_input == tokens.into());

assert_eq!(PublicInput::new(vec![]), [].into());
assert!(PublicInput::new(vec![]) == [].into());
}

#[test]
fn from_various_types_to_non_determinism() {
let tokens = thread_rng().gen::<[BFieldElement; 12]>().to_vec();
let non_determinism = NonDeterminism::new(tokens.clone());

assert_eq!(non_determinism, tokens.clone().into());
assert_eq!(non_determinism, tokens[..].into());
assert_eq!(non_determinism, (&tokens[..]).into());
assert!(non_determinism == tokens.clone().into());
assert!(non_determinism == tokens[..].into());
assert!(non_determinism == (&tokens[..]).into());

let tokens = tokens.into_iter().map(|e| e.value()).collect_vec();
assert_eq!(non_determinism, tokens.into());
assert!(non_determinism == tokens.into());

assert_eq!(NonDeterminism::<u64>::new(vec![]), [].into());
assert_eq!(NonDeterminism::<BFieldElement>::new(vec![]), [].into());
assert!(NonDeterminism::<u64>::new(vec![]) == [].into());
assert!(NonDeterminism::<BFieldElement>::new(vec![]) == [].into());
}

#[test]
Expand All @@ -924,7 +899,7 @@ mod tests {
);
let program_from_code = Program::from_code(&source_code).unwrap();
let program_from_macro = triton_program!({ source_code });
assert_eq!(program_from_code, program_from_macro);
assert!(program_from_code == program_from_macro);
}

#[test]
Expand All @@ -940,12 +915,10 @@ mod tests {
fn test_profile() {
let program = CALCULATE_NEW_MMR_PEAKS_FROM_APPEND_WITH_SAFE_LISTS.clone();
let (profile_output, profile) = program.profile([].into(), [].into()).unwrap();
let terminal_state = program.terminal_state([].into(), [].into()).unwrap();
assert_eq!(profile_output, terminal_state.public_output);
assert_eq!(
profile.last().unwrap().cycle_count(),
terminal_state.cycle_count
);
let mut vm_state = VMState::new(&program, [].into(), [].into());
let_assert!(Ok(()) = vm_state.run());
assert!(profile_output == vm_state.public_output);
assert!(profile.last().unwrap().cycle_count() == vm_state.cycle_count);

println!("Profile of Tasm Program:");
for line in profile {
Expand Down
Loading

0 comments on commit fbd58f1

Please sign in to comment.