Compilation to Triton VM #200

Open
greenhat opened this issue May 28, 2023 · 10 comments

@greenhat

greenhat commented May 28, 2023

Hi! I've been working on a compiler framework for zero-knowledge VMs as a side project. It's called OmniZK, and the idea behind it is to help build compilers from high-level languages to various ZK VMs. Its design resembles the MLIR (LLVM) architecture, where IR transformations are implemented generically and reused with different custom IR dialects.

I've implemented a proof-of-concept Wasm frontend (parser, IR dialect, etc.) and Triton VM backend (IR dialect, codegen, etc.) with Wasm -> Triton VM compilation.

Check out how the following Fibonacci example in Rust:

use ozk_stdlib::*;

pub fn fib_seq() {
    let n = pub_input() as u32;
    let mut a: u32 = 0;
    let mut b: u32 = 1;
    for _ in 0..n {
        let c = a + b;
        a = b;
        b = c;
    }
    pub_output(a as u64);
}

is compiled to the following fully executable Triton VM code. Keep in mind that it is still at the proof-of-concept stage: there are no optimizations yet, so the code is quite verbose.

The most challenging part was converting Wasm blocks/loops and branching into functions, jumps, and recursion. I used skiz, recurse, and a global variable to track branch targets.
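
To illustrate the loop-to-recursion idea in plain Rust rather than Triton VM assembly, here is a rough sketch. It assumes the loop is lowered to a labeled function that returns when the exit condition holds and otherwise runs the body once and calls itself again, which is roughly the role skiz and recurse play. The names `LoopState`, `fib_loop_body`, and `fib_via_recursion` are made up for illustration and are not OmniZK's actual output.

```rust
/// Hypothetical state that a Wasm `loop` would otherwise keep in locals.
struct LoopState {
    remaining: u32, // iterations left (the `n` from the Fibonacci example)
    a: u32,
    b: u32,
}

/// One "label" standing in for a Wasm loop: return when the exit condition
/// holds, otherwise run the body once and jump back to the start.
fn fib_loop_body(state: &mut LoopState) {
    if state.remaining == 0 {
        return; // exit branch: leave the loop
    }
    // wrapping_add avoids an overflow panic in debug builds for large n
    let c = state.a.wrapping_add(state.b);
    state.a = state.b;
    state.b = c;
    state.remaining -= 1;
    fib_loop_body(state) // back edge of the loop, expressed as recursion
}

/// Entry point mirroring `fib_seq` from the example, minus the I/O.
fn fib_via_recursion(n: u32) -> u32 {
    let mut state = LoopState { remaining: n, a: 0, b: 1 };
    fib_loop_body(&mut state);
    state.a
}
```

Calling `fib_via_recursion(10)` walks the same ten iterations as the `for` loop in the example, only with the back edge written as a self-call.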

Wasm global and local variables are converted into direct memory access. If you could add local variable support to Triton VM, converted to direct memory access in a predefined memory region under the hood, that would significantly reduce the number of emitted Triton VM instructions since local vars are widely used in Wasm codegen for Rust.

I/O is implemented via extern functions in Rust (Wasm imports) for now, with the idea of using the Wasm component model and WASI in the future.

I plan to implement the rest of Wasm and Triton VM instructions and compile more complex Rust programs.

I'm looking for feedback, and I'd be happy to answer any questions.

@Sword-Smith
Collaborator

Sword-Smith commented May 30, 2023

Very nice work. Did you already see the attempt by @sshine and myself on writing a compiler from Rust to TASM (Triton VM assembly)? It's a bit more of a "brute force" attempt than yours, as we are translating from the Rust AST to our own AST, and then directly to TASM; we don't go through WASM as you do.

This makes me think that your attempt could be a better generic solution than what we started on. Something that will be more extensible. We, for example, only plan to support a rather small subset of Rust, whereas I think your solution could support the entire Rust language, as you use a well-known intermediate language.

I tried running your tests, and 2 out of 15 fail. Maybe because your test conditions are too strict? The way we test our compiler is by running its code on the actual Triton VM and inspecting its output. The wrapper functions for that test functionality were a bit cumbersome to write, but once you have it set up, it's a very convenient way of testing the generated code. Have a look here:
https://github.com/TritonVM/tasm-lang/blob/master/src/tests_and_benchmarks/shared_test.rs#LL246C8-L246C31

Feel free to just copy the code or use it to write your own testers, if you're interested.

You might also be interested in the TASM standard library that we wrote. That should make it fairly easy for you to add u64 support, list support, and some more functionality. If you depend on this library, you can import many code snippets very easily, so you don't have to maintain those yourself.

FYI I had to run this command before I could run your tests:
rustup target add wasm32-unknown-unknown

@jan-ferdinand
Member

Cool stuff, @greenhat!

local variable support to Triton VM

Thanks for the suggestion. There are some vague plans to add a few registers to Triton VM, which could behave like local variables. However, due to how arithmetization of registers works in zkVMs, there would have to be a finite (and likely rather small) number of them, so I'm not sure they could be used for general compilation from WASM. Is there a limit to the number of local variables that can be declared in a WASM function?

Is using the stack infeasible for local variables?

@greenhat
Author

Very nice work. Did you already see the attempt by @sshine and myself on writing a compiler from Rust to TASM (Triton VM assembly)? It's a bit more of a "brute force" attempt than yours, as we are translating from the Rust AST to our own AST, and then directly to TASM; we don't go through WASM as you do.

This makes me think that your attempt could be a better generic solution than what we started on. Something that will be more extensible. We, for example, only plan to support a rather small subset of Rust, whereas I think your solution could support the entire Rust language, as you use a well-known intermediate language.

It depends on how far you can get with it. My first thought is that complex types would be quite challenging.
Initially, I considered taking MIR or even HIR from rustc, but then I decided to go with Wasm because it's simpler, mostly stack-based, and it opens the way to supporting other high-level languages in the future.

I tried running your tests, and 2 out of 15 fail. Maybe because your test conditions are too strict? The way we test our compiler is by running its code on the actual Triton VM and inspecting its output. The wrapper functions for that test functionality were a bit cumbersome to write, but once you have it set up, it's a very convenient way of testing the generated code. Have a look here: https://github.com/TritonVM/tasm-lang/blob/master/src/tests_and_benchmarks/shared_test.rs#LL246C8-L246C31

Oh, I did not expect anyone to try running my tests. :) It was most likely a rustc symbol-mangling issue that I had not found time to fix until today (it's way too easy to re-generate with expect-test).
I fixed it and tested it on CI, so if you pull main again, it should work now.

I employ a similar approach to test generated code, running it on the Triton VM and comparing the output with the one I got from running either Rust or Wasm code natively. See https://github.com/greenhat/wasm2zk/blob/15d6c02ab55db8c34bbaaf127c42cd15780e0c29/crates/codegen-tritonvm/src/codegen/sem_tests/fib.rs#L8-L20
and https://github.com/greenhat/wasm2zk/blob/15d6c02ab55db8c34bbaaf127c42cd15780e0c29/crates/codegen-tritonvm/src/codegen/sem_tests.rs#L62-L63
BTW, I appreciate the tracing feature! It was very handy when I was debugging my codegen.
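
As a generic illustration of this kind of differential testing, the sketch below runs a reference implementation and a candidate on the same inputs and reports the first input where they disagree. It is not the code behind the links above: `first_mismatch` and `fib_native` are hypothetical names, and the candidate closure merely stands in for "compile, run on Triton VM, collect the public output".

```rust
/// Runs `reference` and `candidate` on every input and returns the first
/// input on which their outputs differ, if any.
fn first_mismatch<R, C>(inputs: &[u64], reference: R, candidate: C) -> Option<u64>
where
    R: Fn(u64) -> Vec<u64>,
    C: Fn(u64) -> Vec<u64>,
{
    inputs
        .iter()
        .copied()
        .find(|&input| reference(input) != candidate(input))
}

/// Native Fibonacci used as the reference; returns the output as a
/// one-element vector to mimic a program's public output stream.
fn fib_native(n: u64) -> Vec<u64> {
    let (mut a, mut b) = (0u64, 1u64);
    for _ in 0..n {
        let c = a.wrapping_add(b);
        a = b;
        b = c;
    }
    vec![a]
}
```

In a real test the second closure would wrap whatever runs the generated code on the VM; a `None` result means the two agreed on every input tried.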

You might also be interested in the TASM standard library that we wrote. That should make it fairly easy for you to add u64 support, list support, and some more functionality. If you depend on this library, you can import many code snippets very easily, so you don't have to maintain those yourself.

Great! I'll use it for sure.

FYI I had to run this command before I could run your tests: rustup target add wasm32-unknown-unknown

Thanks! I added it to the README.

@greenhat
Author

Cool stuff, @greenhat!

Thank you!

local variable support to Triton VM

Thanks for the suggestion. There are some vague plans to add a few registers to Triton VM, which could behave like local variables. However, due to how arithmetization of registers works in zkVMs, there would have to be a finite (and likely rather small) number of them, so I'm not sure they could be used for general compilation from WASM. Is there a limit to the number of local variables that can be declared in a WASM function?

I don't think there is a limit other than that it has to fit in a u32. Actually, when I asked for it I was probably thinking of Miden procedures, which have local vars that are easily "mappable" from Wasm. My bad. I'm adding Miden to OmniZK now and got confused. I somehow forgot that in Triton VM functions are more like labels, so there is no notion of local space for functions besides the jump stack, which is not suited to this kind of task. If so, then please never mind.

Is using the stack infeasible for local variables?

I suspect Wasm codegen in LLVM uses locals for almost all the SSA vars. I thought of putting them on the stack, but the amount of tracking needed to access them in the function code was overwhelming, so I opted for mapping locals to memory at the end of the address range, where there is less chance of conflict with native Wasm memory ops.
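
A minimal sketch of that mapping, under loud assumptions: the memory size `MEMORY_WORDS`, the one-word-per-local layout, and the names `local_address` and `Ram` are all invented for illustration, and call frames are ignored entirely. Locals are placed downward from the top of the address space, so they stay far from the low addresses that ordinary Wasm loads and stores tend to use.

```rust
use std::collections::HashMap;

/// Hypothetical number of addressable words; chosen only for illustration.
const MEMORY_WORDS: u64 = 1 << 32;

/// Maps a Wasm local index to an address counted down from the top of memory.
fn local_address(index: u32) -> u64 {
    MEMORY_WORDS - 1 - u64::from(index)
}

/// Sparse stand-in for VM RAM; cells that were never written read as zero.
#[derive(Default)]
struct Ram {
    cells: HashMap<u64, u64>,
}

impl Ram {
    /// Equivalent of Wasm `local.set`, lowered to a memory write.
    fn local_set(&mut self, index: u32, value: u64) {
        self.cells.insert(local_address(index), value);
    }

    /// Equivalent of Wasm `local.get`, lowered to a memory read.
    fn local_get(&self, index: u32) -> u64 {
        self.cells.get(&local_address(index)).copied().unwrap_or(0)
    }
}
```

A real backend would also have to separate the locals of nested or recursive calls, which this sketch does not attempt.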

@aszepieniec
Collaborator

aszepieniec commented Jun 1, 2023

Awesome stuff, hats off ^^

To get a feel for useful programs that need to run on Triton VM, here is a snippet of Rust code that we are struggling to compile to TASM.

    fn verify_raw(public_input: &[BFieldElement], secret_witness: &[BFieldElement]) {
        let removal_records_integrity_witness =
            *RemovalRecordsIntegrityWitness::decode(secret_witness).unwrap();
        let items = removal_records_integrity_witness
            .input_utxos
            .iter()
            .map(Hash::hash)
            .collect_vec();
        let mut digests_of_derived_index_sets = items
            .iter()
            .zip(removal_records_integrity_witness.membership_proofs.iter())
            .map(|(utxo, msmp)| {
                AbsoluteIndexSet::new(&get_swbf_indices::<Hash>(
                    &Hash::hash(utxo),
                    &msmp.sender_randomness,
                    &msmp.receiver_preimage,
                    msmp.auth_path_aocl.leaf_index,
                ))
                .encode()
            })
            .map(|x| Hash::hash_varlen(&x))
            .collect_vec();
        digests_of_derived_index_sets.sort();
        let mut digests_of_claimed_index_sets = removal_records_integrity_witness
            .kernel
            .inputs
            .iter()
            .map(|input| input.absolute_indices.encode())
            .map(|e| Hash::hash_varlen(&e))
            .collect_vec();
        digests_of_claimed_index_sets.sort();
        assert_eq!(digests_of_derived_index_sets, digests_of_claimed_index_sets);
        assert!(items
            .iter()
            .zip(removal_records_integrity_witness.membership_proofs.iter())
            .map(|(item, msmp)| {
                (
                    commit::<Hash>(
                        item,
                        &msmp.sender_randomness,
                        &msmp.receiver_preimage.hash::<Hash>(),
                    ),
                    &msmp.auth_path_aocl,
                )
            })
            .all(|(cc, mp)| {
                mp.verify(
                    &removal_records_integrity_witness
                        .mutator_set_accumulator
                        .kernel
                        .aocl
                        .get_peaks(),
                    &cc.canonical_commitment,
                    removal_records_integrity_witness
                        .mutator_set_accumulator
                        .kernel
                        .aocl
                        .count_leaves(),
                )
                .0
            }));
        assert_eq!(
            removal_records_integrity_witness
                .mutator_set_accumulator
                .hash(),
            removal_records_integrity_witness.kernel.mutator_set_hash
        );
    }

@Sword-Smith
Collaborator

Sword-Smith commented Jun 1, 2023

Adding a bit of context to @aszepieniec's code example

pub fn get_swbf_indices<H: AlgebraicHasher>(
    item: &Digest,
    sender_randomness: &Digest,
    receiver_preimage: &Digest,
    aocl_leaf_index: u64,
) -> [u128; NUM_TRIALS as usize] {
    let batch_index: u128 = aocl_leaf_index as u128 / BATCH_SIZE as u128;
    let batch_offset: u128 = batch_index * CHUNK_SIZE as u128;
    let leaf_index_bfes = aocl_leaf_index.encode();
    let leaf_index_bfes_len = leaf_index_bfes.len();
    let input = [
        item.encode(),
        sender_randomness.encode(),
        receiver_preimage.encode(),
        leaf_index_bfes,
        // Pad with zeros until length is a multiple of RATE; according to spec
        vec![BFieldElement::zero(); DIGEST_LENGTH - leaf_index_bfes_len],
    ]
    .concat();
    assert_eq!(input.len() % DIGEST_LENGTH, 0);
    let mut sponge = <H as SpongeHasher>::init();
    H::absorb_repeatedly(&mut sponge, input.iter());
    H::sample_indices(&mut sponge, WINDOW_SIZE, NUM_TRIALS as usize)
        .into_iter()
        .map(|sample_index| sample_index as u128 + batch_offset)
        .collect_vec()
        .try_into()
        .unwrap()
}


pub struct AbsoluteIndexSet([u128; NUM_TRIALS as usize]);

pub fn commit<H: AlgebraicHasher>(
    item: &Digest,
    sender_randomness: &Digest,
    receiver_digest: &Digest,
) -> AdditionRecord {
    let canonical_commitment =
        H::hash_pair(&H::hash_pair(item, sender_randomness), receiver_digest);

    AdditionRecord::new(canonical_commitment)
}

The mp.verify logic is implemented in tasm-lib. It's just MMR authentication path verification.

I'll probably work on adding an implementation to tasm-lib of hash_varlen of a list living in memory.

@greenhat
Author

greenhat commented Jun 2, 2023

Thank you, @aszepieniec and @Sword-Smith! I was unsure of what to tackle next after the Fibonacci sequence. It seems quite challenging, so I plan to break it into pieces and work my way up.

@aszepieniec
Collaborator

aszepieniec commented Jun 7, 2023

Just so you know: we just pushed (tasm) code snippets implementing higher-order functions map and zip to tasm-lib, in case you're interested. We also have memcpy. And hash_varlen.

@aszepieniec
Collaborator

@greenhat: Out of curiosity: how are you (planning to) represent lists?

Here's how we do it in tasm-lib:

  • Unsafe lists start with the length, followed by all list elements in order. All of this lives in memory; the stack contains only the address of the length.
  • Safe lists start with the length, followed by the capacity, followed by all elements. Likewise, all of this lives in memory and the stack contains only the address of the length.

You might prefer safe lists to unsafe lists because you might want to use the memory located right after the list. If the list grows without a capacity to limit it, you risk overwriting that memory. So there is a lot more bounds checking going on in safe lists than in unsafe lists, but in exchange you get better assurances of non-interfering data structures.
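
The two layouts can be modeled in a few lines of Rust. This is only a sketch of the description above, not tasm-lib's code: it assumes one-word elements, treats memory as a plain slice of words, and reports a full safe list through a `Result`, which is an arbitrary choice for the illustration. The function names are hypothetical.

```rust
/// Unsafe list layout: [length, elem_0, elem_1, ...].
/// There is no capacity, so a push never checks what lies behind the list.
fn unsafe_push(mem: &mut [u64], list: usize, value: u64) {
    let len = mem[list] as usize;
    mem[list + 1 + len] = value; // may silently overwrite neighboring data
    mem[list] += 1;
}

/// Safe list layout: [length, capacity, elem_0, elem_1, ...].
/// A push beyond the capacity is refused instead of touching other memory.
fn safe_push(mem: &mut [u64], list: usize, value: u64) -> Result<(), &'static str> {
    let len = mem[list];
    let capacity = mem[list + 1];
    if len >= capacity {
        return Err("list is full");
    }
    mem[list + 2 + len as usize] = value;
    mem[list] = len + 1;
    Ok(())
}
```

In both cases `list` is the address of the length word, matching the convention that the stack holds only that address.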

In addition to safe and unsafe lists, we also use a static and a dynamic allocator. They simulate a heap, so you can ask them for addresses in RAM that correspond to x-many contiguous words. When you build a new list, you allocate the requisite memory using one of these allocators.
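
One simple way to simulate such a heap is a bump allocator, sketched below. This is an assumption about the general idea, not a description of how tasm-lib's static or dynamic allocator actually works; `BumpAllocator` and its fields are hypothetical names, and the sketch never frees memory.

```rust
/// Minimal bump allocator over word-addressed RAM; it never frees.
struct BumpAllocator {
    next_free: u64, // first address not yet handed out
    limit: u64,     // one past the last usable address
}

impl BumpAllocator {
    /// Manages the address range `start..limit`.
    fn new(start: u64, limit: u64) -> Self {
        Self { next_free: start, limit }
    }

    /// Returns the start address of `words` contiguous words,
    /// or `None` if the managed range is exhausted.
    fn alloc(&mut self, words: u64) -> Option<u64> {
        let end = self.next_free.checked_add(words)?;
        if end > self.limit {
            return None;
        }
        let address = self.next_free;
        self.next_free = end;
        Some(address)
    }
}
```

Building a safe list with capacity `c` on top of this would mean requesting `c + 2` words (length, capacity, elements) and writing the capacity into the second word.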

We are curious because using the same list format would mean that algorithms and data structures built on top of it would be compatible across compilers.

@greenhat
Author

@aszepieniec Thanks for the detailed explanation!

Good question! Wasm does not have lists - https://webassembly.github.io/spec/core/syntax/instructions.html#
I expect them to be compiled into memory access instructions, although I have not tested it yet.
I hope it will be possible to write conversions to pass them to and from tasm-lib.


4 participants