Skip to content

Commit

Permalink
Merge pull request ethereum#35 from zama-ai/petar/precompiled-contrac…
Browse files Browse the repository at this point in the history
…t-tests

Add unit tests for Zama-specific precompiles
  • Loading branch information
dartdart26 authored Jan 31, 2023
2 parents 6dc797f + 4e0d99c commit 83bbc84
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 15 deletions.
6 changes: 1 addition & 5 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,7 @@ type tomlConfigOptions struct {

var tomlConfig tomlConfigOptions

//lint:ignore U1000 Want to keep to show how Ed25519 keys were generated.
func generateEd25519Keys() error {
public, private, err := ed25519.GenerateKey(nil)
if err != nil {
Expand Down Expand Up @@ -1399,11 +1400,6 @@ func requireKey(ciphertext []byte) string {
return crypto.Keccak256Hash(ciphertext).Hex()[2:]
}

func requireKeyFromHash(hash common.Hash) string {
// Take the hash and remove the leading 0x.
return hash.Hex()[2:]
}

func requireURL(key *string) string {
return tomlConfig.Oracle.OracleDBAddress + "/require/" + *key
}
Expand Down
202 changes: 192 additions & 10 deletions core/vm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,17 @@ type precompiledFailureTest struct {
Name string
}

// allPrecompiles does not map to the actual set of precompiles, as it also contains
type emptyPrecompileAccessibleState struct{}

func (s *emptyPrecompileAccessibleState) Interpreter() *EVMInterpreter {
return nil
}

var emptyPrecompileState PrecompileAccessibleState = &emptyPrecompileAccessibleState{}

// allStatelessPrecompiles does not map to the actual set of precompiles, as it also contains
// repriced versions of precompiles at certain slots
var allPrecompiles = map[common.Address]PrecompiledContract{
var allStatelessPrecompiles = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{1}): &ecrecover{},
common.BytesToAddress([]byte{2}): &sha256hash{},
common.BytesToAddress([]byte{3}): &ripemd160hash{},
Expand Down Expand Up @@ -92,11 +100,12 @@ var blake2FMalformedInputTests = []precompiledFailureTest{
}

func testPrecompiled(addr string, test precompiledTest, t *testing.T) {
p := allPrecompiles[common.HexToAddress(addr)]
a := common.HexToAddress(addr)
p := allStatelessPrecompiles[common.HexToAddress(addr)]
in := common.Hex2Bytes(test.Input)
gas := p.RequiredGas(in)
t.Run(fmt.Sprintf("%s-Gas=%d", test.Name, gas), func(t *testing.T) {
if res, _, err := RunPrecompiledContract(p, in, gas); err != nil {
if res, _, err := RunPrecompiledContract(p, emptyPrecompileState, a, a, in, gas, false); err != nil {
t.Error(err)
} else if common.Bytes2Hex(res) != test.Expected {
t.Errorf("Expected %v, got %v", test.Expected, common.Bytes2Hex(res))
Expand All @@ -113,12 +122,13 @@ func testPrecompiled(addr string, test precompiledTest, t *testing.T) {
}

func testPrecompiledOOG(addr string, test precompiledTest, t *testing.T) {
p := allPrecompiles[common.HexToAddress(addr)]
p := allStatelessPrecompiles[common.HexToAddress(addr)]
in := common.Hex2Bytes(test.Input)
gas := p.RequiredGas(in) - 1

t.Run(fmt.Sprintf("%s-Gas=%d", test.Name, gas), func(t *testing.T) {
_, _, err := RunPrecompiledContract(p, in, gas)
a := common.HexToAddress(addr)
_, _, err := RunPrecompiledContract(p, emptyPrecompileState, a, a, in, gas, false)
if err.Error() != "out of gas" {
t.Errorf("Expected error [out of gas], got [%v]", err)
}
Expand All @@ -131,11 +141,12 @@ func testPrecompiledOOG(addr string, test precompiledTest, t *testing.T) {
}

func testPrecompiledFailure(addr string, test precompiledFailureTest, t *testing.T) {
p := allPrecompiles[common.HexToAddress(addr)]
a := common.HexToAddress(addr)
p := allStatelessPrecompiles[common.HexToAddress(addr)]
in := common.Hex2Bytes(test.Input)
gas := p.RequiredGas(in)
t.Run(test.Name, func(t *testing.T) {
_, _, err := RunPrecompiledContract(p, in, gas)
_, _, err := RunPrecompiledContract(p, emptyPrecompileState, a, a, in, gas, false)
if err.Error() != test.ExpectedError {
t.Errorf("Expected error [%v], got [%v]", test.ExpectedError, err)
}
Expand All @@ -151,7 +162,8 @@ func benchmarkPrecompiled(addr string, test precompiledTest, bench *testing.B) {
if test.NoBenchmark {
return
}
p := allPrecompiles[common.HexToAddress(addr)]
a := common.HexToAddress(addr)
p := allStatelessPrecompiles[common.HexToAddress(addr)]
in := common.Hex2Bytes(test.Input)
reqGas := p.RequiredGas(in)

Expand All @@ -167,7 +179,7 @@ func benchmarkPrecompiled(addr string, test precompiledTest, bench *testing.B) {
bench.ResetTimer()
for i := 0; i < bench.N; i++ {
copy(data, in)
res, _, err = RunPrecompiledContract(p, data, reqGas)
res, _, err = RunPrecompiledContract(p, emptyPrecompileState, a, a, in, reqGas, false)
}
bench.StopTimer()
elapsed := uint64(time.Since(start))
Expand Down Expand Up @@ -391,3 +403,173 @@ func BenchmarkPrecompiledBLS12381G2MultiExpWorstCase(b *testing.B) {
}
benchmarkPrecompiled("0f", testcase, b)
}

// Zama-specific precompiled contracts

type statefulPrecompileAccessibleState struct {
interpreter *EVMInterpreter
}

func (s *statefulPrecompileAccessibleState) Interpreter() *EVMInterpreter {
return s.interpreter
}

func newState() *statefulPrecompileAccessibleState {
s := new(statefulPrecompileAccessibleState)
cfg := Config{}
evm := &EVM{}
s.interpreter = NewEVMInterpreter(evm, cfg)
evm.interpreter = s.interpreter
return s
}

func verifyCiphertextInTestState(s *statefulPrecompileAccessibleState, value uint64, depth int) (*tfheCiphertext, common.Hash) {
ct := new(tfheCiphertext)
ct.encrypt(value)
hash := ct.getHash()
s.interpreter.verifiedCiphertexts[hash] = &verifiedCiphertext{depth, ct}
return ct, ct.getHash()
}

func generateInput(hashes ...common.Hash) []byte {
ret := make([]byte, 0)
for _, hash := range hashes {
ret = append(ret, hash.Bytes()...)
}
return ret
}

func TestFheAdd(t *testing.T) {
c := &fheAdd{}
depth := 1
state := newState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
_, lhs_hash := verifyCiphertextInTestState(state, 1, depth)
_, rhs_hash := verifyCiphertextInTestState(state, 1, depth)
input := generateInput(lhs_hash, rhs_hash)
out, err := c.Run(state, addr, addr, input, readOnly)
if err != nil {
t.Fatalf(err.Error())
}
res, exists := state.interpreter.verifiedCiphertexts[common.BytesToHash(out)]
if !exists {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted := res.ciphertext.decrypt()
if decrypted != 2 {
t.Fatalf("invalid decrypted result")
}
}

func TestFheSub(t *testing.T) {
c := &fheSub{}
depth := 1
state := newState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
_, lhs_hash := verifyCiphertextInTestState(state, 2, depth)
_, rhs_hash := verifyCiphertextInTestState(state, 1, depth)
input := generateInput(lhs_hash, rhs_hash)
out, err := c.Run(state, addr, addr, input, readOnly)
if err != nil {
t.Fatalf(err.Error())
}
res, exists := state.interpreter.verifiedCiphertexts[common.BytesToHash(out)]
if !exists {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted := res.ciphertext.decrypt()
if decrypted != 1 {
t.Fatalf("invalid decrypted result")
}
}

func TestFheLte(t *testing.T) {
c := &fheLte{}
depth := 1
state := newState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
_, lhs_hash := verifyCiphertextInTestState(state, 2, depth)
_, rhs_hash := verifyCiphertextInTestState(state, 1, depth)

// 2 <= 1
input1 := generateInput(lhs_hash, rhs_hash)
out, err := c.Run(state, addr, addr, input1, readOnly)
if err != nil {
t.Fatalf(err.Error())
}
res, exists := state.interpreter.verifiedCiphertexts[common.BytesToHash(out)]
if !exists {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted := res.ciphertext.decrypt()
if decrypted != 0 {
t.Fatalf("invalid decrypted result")
}

// 1 <= 2
input2 := generateInput(rhs_hash, lhs_hash)
out, err = c.Run(state, addr, addr, input2, readOnly)
if err != nil {
t.Fatalf(err.Error())
}
res, exists = state.interpreter.verifiedCiphertexts[common.BytesToHash(out)]
if !exists {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted = res.ciphertext.decrypt()
if decrypted != 1 {
t.Fatalf("invalid decrypted result")
}
}

func TestUnknownCiphertextHandle(t *testing.T) {
depth := 1
state := newState()
state.interpreter.evm.depth = depth
_, hash := verifyCiphertextInTestState(state, 2, depth)

_, found := getVerifiedCiphertext(state, hash)
if !found {
t.Fatalf("expected ciphertext is verified")
}

// change the hash
hash[0]++
_, found = getVerifiedCiphertext(state, hash)
if found {
t.Fatalf("expected ciphertext is not verified")
}
}

func TestCiphertextNotVerifiedAtDepth(t *testing.T) {
state := newState()
state.interpreter.evm.depth = 1
verifiedDepth := 2
_, hash := verifyCiphertextInTestState(state, 1, verifiedDepth)

_, found := getVerifiedCiphertext(state, hash)
if found {
t.Fatalf("expected ciphertext is not verified")
}
}

func TestCiphertextVerifiedAtFurtherDepth(t *testing.T) {
state := newState()
state.interpreter.evm.depth = 3
verifiedDepth := 2
_, hash := verifyCiphertextInTestState(state, 1, verifiedDepth)

_, found := getVerifiedCiphertext(state, hash)
if !found {
t.Fatalf("expected ciphertext is verified")
}
}

0 comments on commit 83bbc84

Please sign in to comment.