From 4e0d99c42087db2e8376644fde0621398e3b825b Mon Sep 17 00:00:00 2001 From: Petar Ivanov <29689712+dartdart26@users.noreply.github.com> Date: Tue, 31 Jan 2023 13:26:53 +0200 Subject: [PATCH] Add unit tests for Zama-specific precompiles Tests cover FHE arithmetic operations, depth and invalid ciphertext handling. We do not yet cover all Zama-specific precompiled contracts. Future commits will add missing tests. Furthermore, we'd like to not rely on FHE keys from disk - ideally, we generate them on demand. --- core/vm/contracts.go | 6 +- core/vm/contracts_test.go | 202 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 193 insertions(+), 15 deletions(-) diff --git a/core/vm/contracts.go b/core/vm/contracts.go index 6a1793a8949f..72099500bf3d 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -1122,6 +1122,7 @@ type tomlConfigOptions struct { var tomlConfig tomlConfigOptions +//lint:ignore U1000 Want to keep to show how Ed25519 keys were generated. func generateEd25519Keys() error { public, private, err := ed25519.GenerateKey(nil) if err != nil { @@ -1399,11 +1400,6 @@ func requireKey(ciphertext []byte) string { return crypto.Keccak256Hash(ciphertext).Hex()[2:] } -func requireKeyFromHash(hash common.Hash) string { - // Take the hash and remove the leading 0x. - return hash.Hex()[2:] -} - func requireURL(key *string) string { return tomlConfig.Oracle.OracleDBAddress + "/require/" + *key } diff --git a/core/vm/contracts_test.go b/core/vm/contracts_test.go index b22d999e6cd9..35fbfab53392 100644 --- a/core/vm/contracts_test.go +++ b/core/vm/contracts_test.go @@ -43,9 +43,17 @@ type precompiledFailureTest struct { Name string } -// allPrecompiles does not map to the actual set of precompiles, as it also contains +type emptyPrecompileAccessibleState struct{} + +func (s *emptyPrecompileAccessibleState) Interpreter() *EVMInterpreter { + return nil +} + +var emptyPrecompileState PrecompileAccessibleState = &emptyPrecompileAccessibleState{} + +// allStatelessPrecompiles does not map to the actual set of precompiles, as it also contains // repriced versions of precompiles at certain slots -var allPrecompiles = map[common.Address]PrecompiledContract{ +var allStatelessPrecompiles = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{1}): &ecrecover{}, common.BytesToAddress([]byte{2}): &sha256hash{}, common.BytesToAddress([]byte{3}): &ripemd160hash{}, @@ -92,11 +100,12 @@ var blake2FMalformedInputTests = []precompiledFailureTest{ } func testPrecompiled(addr string, test precompiledTest, t *testing.T) { - p := allPrecompiles[common.HexToAddress(addr)] + a := common.HexToAddress(addr) + p := allStatelessPrecompiles[common.HexToAddress(addr)] in := common.Hex2Bytes(test.Input) gas := p.RequiredGas(in) t.Run(fmt.Sprintf("%s-Gas=%d", test.Name, gas), func(t *testing.T) { - if res, _, err := RunPrecompiledContract(p, in, gas); err != nil { + if res, _, err := RunPrecompiledContract(p, emptyPrecompileState, a, a, in, gas, false); err != nil { t.Error(err) } else if common.Bytes2Hex(res) != test.Expected { t.Errorf("Expected %v, got %v", test.Expected, common.Bytes2Hex(res)) @@ -113,12 +122,13 @@ func testPrecompiled(addr string, test precompiledTest, t *testing.T) { } func testPrecompiledOOG(addr string, test precompiledTest, t *testing.T) { - p := allPrecompiles[common.HexToAddress(addr)] + p := allStatelessPrecompiles[common.HexToAddress(addr)] in := common.Hex2Bytes(test.Input) gas := p.RequiredGas(in) - 1 t.Run(fmt.Sprintf("%s-Gas=%d", test.Name, gas), func(t *testing.T) { - _, _, err := RunPrecompiledContract(p, in, gas) + a := common.HexToAddress(addr) + _, _, err := RunPrecompiledContract(p, emptyPrecompileState, a, a, in, gas, false) if err.Error() != "out of gas" { t.Errorf("Expected error [out of gas], got [%v]", err) } @@ -131,11 +141,12 @@ func testPrecompiledOOG(addr string, test precompiledTest, t *testing.T) { } func testPrecompiledFailure(addr string, test precompiledFailureTest, t *testing.T) { - p := allPrecompiles[common.HexToAddress(addr)] + a := common.HexToAddress(addr) + p := allStatelessPrecompiles[common.HexToAddress(addr)] in := common.Hex2Bytes(test.Input) gas := p.RequiredGas(in) t.Run(test.Name, func(t *testing.T) { - _, _, err := RunPrecompiledContract(p, in, gas) + _, _, err := RunPrecompiledContract(p, emptyPrecompileState, a, a, in, gas, false) if err.Error() != test.ExpectedError { t.Errorf("Expected error [%v], got [%v]", test.ExpectedError, err) } @@ -151,7 +162,8 @@ func benchmarkPrecompiled(addr string, test precompiledTest, bench *testing.B) { if test.NoBenchmark { return } - p := allPrecompiles[common.HexToAddress(addr)] + a := common.HexToAddress(addr) + p := allStatelessPrecompiles[common.HexToAddress(addr)] in := common.Hex2Bytes(test.Input) reqGas := p.RequiredGas(in) @@ -167,7 +179,7 @@ func benchmarkPrecompiled(addr string, test precompiledTest, bench *testing.B) { bench.ResetTimer() for i := 0; i < bench.N; i++ { copy(data, in) - res, _, err = RunPrecompiledContract(p, data, reqGas) + res, _, err = RunPrecompiledContract(p, emptyPrecompileState, a, a, in, reqGas, false) } bench.StopTimer() elapsed := uint64(time.Since(start)) @@ -391,3 +403,173 @@ func BenchmarkPrecompiledBLS12381G2MultiExpWorstCase(b *testing.B) { } benchmarkPrecompiled("0f", testcase, b) } + +// Zama-specific precompiled contracts + +type statefulPrecompileAccessibleState struct { + interpreter *EVMInterpreter +} + +func (s *statefulPrecompileAccessibleState) Interpreter() *EVMInterpreter { + return s.interpreter +} + +func newState() *statefulPrecompileAccessibleState { + s := new(statefulPrecompileAccessibleState) + cfg := Config{} + evm := &EVM{} + s.interpreter = NewEVMInterpreter(evm, cfg) + evm.interpreter = s.interpreter + return s +} + +func verifyCiphertextInTestState(s *statefulPrecompileAccessibleState, value uint64, depth int) (*tfheCiphertext, common.Hash) { + ct := new(tfheCiphertext) + ct.encrypt(value) + hash := ct.getHash() + s.interpreter.verifiedCiphertexts[hash] = &verifiedCiphertext{depth, ct} + return ct, ct.getHash() +} + +func generateInput(hashes ...common.Hash) []byte { + ret := make([]byte, 0) + for _, hash := range hashes { + ret = append(ret, hash.Bytes()...) + } + return ret +} + +func TestFheAdd(t *testing.T) { + c := &fheAdd{} + depth := 1 + state := newState() + state.interpreter.evm.depth = depth + state.interpreter.evm.Commit = true + addr := common.Address{} + readOnly := false + _, lhs_hash := verifyCiphertextInTestState(state, 1, depth) + _, rhs_hash := verifyCiphertextInTestState(state, 1, depth) + input := generateInput(lhs_hash, rhs_hash) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res, exists := state.interpreter.verifiedCiphertexts[common.BytesToHash(out)] + if !exists { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted := res.ciphertext.decrypt() + if decrypted != 2 { + t.Fatalf("invalid decrypted result") + } +} + +func TestFheSub(t *testing.T) { + c := &fheSub{} + depth := 1 + state := newState() + state.interpreter.evm.depth = depth + state.interpreter.evm.Commit = true + addr := common.Address{} + readOnly := false + _, lhs_hash := verifyCiphertextInTestState(state, 2, depth) + _, rhs_hash := verifyCiphertextInTestState(state, 1, depth) + input := generateInput(lhs_hash, rhs_hash) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res, exists := state.interpreter.verifiedCiphertexts[common.BytesToHash(out)] + if !exists { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted := res.ciphertext.decrypt() + if decrypted != 1 { + t.Fatalf("invalid decrypted result") + } +} + +func TestFheLte(t *testing.T) { + c := &fheLte{} + depth := 1 + state := newState() + state.interpreter.evm.depth = depth + state.interpreter.evm.Commit = true + addr := common.Address{} + readOnly := false + _, lhs_hash := verifyCiphertextInTestState(state, 2, depth) + _, rhs_hash := verifyCiphertextInTestState(state, 1, depth) + + // 2 <= 1 + input1 := generateInput(lhs_hash, rhs_hash) + out, err := c.Run(state, addr, addr, input1, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res, exists := state.interpreter.verifiedCiphertexts[common.BytesToHash(out)] + if !exists { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted := res.ciphertext.decrypt() + if decrypted != 0 { + t.Fatalf("invalid decrypted result") + } + + // 1 <= 2 + input2 := generateInput(rhs_hash, lhs_hash) + out, err = c.Run(state, addr, addr, input2, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res, exists = state.interpreter.verifiedCiphertexts[common.BytesToHash(out)] + if !exists { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted = res.ciphertext.decrypt() + if decrypted != 1 { + t.Fatalf("invalid decrypted result") + } +} + +func TestUnknownCiphertextHandle(t *testing.T) { + depth := 1 + state := newState() + state.interpreter.evm.depth = depth + _, hash := verifyCiphertextInTestState(state, 2, depth) + + _, found := getVerifiedCiphertext(state, hash) + if !found { + t.Fatalf("expected ciphertext is verified") + } + + // change the hash + hash[0]++ + _, found = getVerifiedCiphertext(state, hash) + if found { + t.Fatalf("expected ciphertext is not verified") + } +} + +func TestCiphertextNotVerifiedAtDepth(t *testing.T) { + state := newState() + state.interpreter.evm.depth = 1 + verifiedDepth := 2 + _, hash := verifyCiphertextInTestState(state, 1, verifiedDepth) + + _, found := getVerifiedCiphertext(state, hash) + if found { + t.Fatalf("expected ciphertext is not verified") + } +} + +func TestCiphertextVerifiedAtFurtherDepth(t *testing.T) { + state := newState() + state.interpreter.evm.depth = 3 + verifiedDepth := 2 + _, hash := verifyCiphertextInTestState(state, 1, verifiedDepth) + + _, found := getVerifiedCiphertext(state, hash) + if !found { + t.Fatalf("expected ciphertext is verified") + } +}