diff --git a/core/state/journal.go b/core/state/journal.go index 6131e4c0a825..ad4a654fc6a2 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -20,7 +20,6 @@ import ( "maps" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/tracing" "github.com/holiman/uint256" ) @@ -202,7 +201,7 @@ func (ch selfDestructChange) revert(s *StateDB) { obj := s.getStateObject(*ch.account) if obj != nil { obj.selfDestructed = ch.prev - obj.setBalanceLogged(ch.prevbalance, tracing.BalanceChangeRevert) + obj.setBalance(ch.prevbalance) } } @@ -234,7 +233,7 @@ func (ch touchChange) copy() journalEntry { } func (ch balanceChange) revert(s *StateDB) { - s.getStateObject(*ch.account).setBalanceLogged(ch.prev, tracing.BalanceChangeRevert) + s.getStateObject(*ch.account).setBalance(ch.prev) } func (ch balanceChange) dirtied() *common.Address { @@ -249,7 +248,7 @@ func (ch balanceChange) copy() journalEntry { } func (ch nonceChange) revert(s *StateDB) { - s.getStateObject(*ch.account).setNonceLogged(ch.prev) + s.getStateObject(*ch.account).setNonce(ch.prev) } func (ch nonceChange) dirtied() *common.Address { @@ -264,7 +263,7 @@ func (ch nonceChange) copy() journalEntry { } func (ch codeChange) revert(s *StateDB) { - s.getStateObject(*ch.account).setCodeLogged(common.BytesToHash(ch.prevhash), ch.prevcode) + s.getStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode) } func (ch codeChange) dirtied() *common.Address { @@ -280,7 +279,7 @@ func (ch codeChange) copy() journalEntry { } func (ch storageChange) revert(s *StateDB) { - s.getStateObject(*ch.account).setStateLogged(ch.key, ch.prevvalue, ch.origvalue) + s.getStateObject(*ch.account).setState(ch.key, ch.prevvalue, ch.origvalue) } func (ch storageChange) dirtied() *common.Address { diff --git a/core/state/state_object.go b/core/state/state_object.go index fc0f6ddbf482..880b715b4b37 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -257,11 +257,6 @@ func (s *stateObject) SetState(key, value common.Hash) { prevvalue: prev, origvalue: origin, }) - s.setStateLogged(key, value, origin) -} - -func (s *stateObject) setStateLogged(key, value, origin common.Hash) { - prev, _ := s.getState(key) if s.db.logger != nil && s.db.logger.OnStorageChange != nil { s.db.logger.OnStorageChange(s.address, key, prev, value) } @@ -519,10 +514,6 @@ func (s *stateObject) SetBalance(amount *uint256.Int, reason tracing.BalanceChan account: &s.address, prev: new(uint256.Int).Set(s.data.Balance), }) - s.setBalanceLogged(amount, reason) -} - -func (s *stateObject) setBalanceLogged(amount *uint256.Int, reason tracing.BalanceChangeReason) { if s.db.logger != nil && s.db.logger.OnBalanceChange != nil { s.db.logger.OnBalanceChange(s.address, s.Balance().ToBig(), amount.ToBig(), reason) } @@ -604,11 +595,6 @@ func (s *stateObject) SetCode(codeHash common.Hash, code []byte) { prevhash: s.CodeHash(), prevcode: prevcode, }) - s.setCodeLogged(codeHash, code) -} - -func (s *stateObject) setCodeLogged(codeHash common.Hash, code []byte) { - prevcode := s.Code() if s.db.logger != nil && s.db.logger.OnCodeChange != nil { s.db.logger.OnCodeChange(s.address, common.BytesToHash(s.CodeHash()), prevcode, codeHash, code) } @@ -626,10 +612,6 @@ func (s *stateObject) SetNonce(nonce uint64) { account: &s.address, prev: s.data.Nonce, }) - s.setNonceLogged(nonce) -} - -func (s *stateObject) setNonceLogged(nonce uint64) { if s.db.logger != nil && s.db.logger.OnNonceChange != nil { s.db.logger.OnNonceChange(s.address, s.data.Nonce, nonce) } diff --git a/core/tracing/journal.go b/core/tracing/journal.go new file mode 100644 index 000000000000..75d83b47a0f6 --- /dev/null +++ b/core/tracing/journal.go @@ -0,0 +1,233 @@ +// Copyright 2024 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package tracing + +import ( + "fmt" + "math/big" + "sort" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" +) + +type revision struct { + id int + journalIndex int +} + +// journal is a state change journal to be wrapped around a tracer. +// It will emit the state change hooks with reverse values when a call reverts. +type journal struct { + entries []entry + hooks *Hooks + + validRevisions []revision + nextRevisionId int + curRevisionId int +} + +type entry interface { + revert(tracer *Hooks) +} + +// WrapWithJournal wraps the given tracer with a journaling layer. +func WrapWithJournal(hooks *Hooks) (*Hooks, error) { + if hooks == nil { + return nil, fmt.Errorf("wrapping nil tracer") + } + // No state change to journal. + if hooks.OnBalanceChange == nil && hooks.OnNonceChange == nil && hooks.OnCodeChange == nil && hooks.OnStorageChange == nil { + return hooks, nil + } + var ( + j = &journal{entries: make([]entry, 0), hooks: hooks} + wrapped = &Hooks{ + OnTxEnd: j.OnTxEnd, + OnEnter: j.OnEnter, + OnExit: j.OnExit, + } + ) + if hooks.OnBalanceChange != nil { + wrapped.OnBalanceChange = j.OnBalanceChange + } + if hooks.OnNonceChange != nil { + wrapped.OnNonceChange = j.OnNonceChange + } + if hooks.OnCodeChange != nil { + wrapped.OnCodeChange = j.OnCodeChange + } + if hooks.OnStorageChange != nil { + wrapped.OnStorageChange = j.OnStorageChange + } + return wrapped, nil +} + +// reset clears the journal, after this operation the journal can be used anew. +// It is semantically similar to calling 'NewJournal', but the underlying slices +// can be reused. +func (j *journal) reset() { + j.entries = j.entries[:0] + j.validRevisions = j.validRevisions[:0] + j.nextRevisionId = 0 +} + +// snapshot returns an identifier for the current revision of the state. +func (j *journal) snapshot() int { + id := j.nextRevisionId + j.nextRevisionId++ + j.validRevisions = append(j.validRevisions, revision{id, j.length()}) + return id +} + +// revertToSnapshot reverts all state changes made since the given revision. +func (j *journal) revertToSnapshot(revid int, hooks *Hooks) { + // Find the snapshot in the stack of valid snapshots. + idx := sort.Search(len(j.validRevisions), func(i int) bool { + return j.validRevisions[i].id >= revid + }) + if idx == len(j.validRevisions) || j.validRevisions[idx].id != revid { + panic(fmt.Errorf("revision id %v cannot be reverted", revid)) + } + snapshot := j.validRevisions[idx].journalIndex + + // Replay the journal to undo changes and remove invalidated snapshots + j.revert(hooks, snapshot) + j.validRevisions = j.validRevisions[:idx] +} + +// revert undoes a batch of journaled modifications. +func (j *journal) revert(hooks *Hooks, snapshot int) { + for i := len(j.entries) - 1; i >= snapshot; i-- { + // Undo the changes made by the operation + j.entries[i].revert(hooks) + } + j.entries = j.entries[:snapshot] +} + +// length returns the current number of entries in the journal. +func (j *journal) length() int { + return len(j.entries) +} + +func (j *journal) OnTxEnd(receipt *types.Receipt, err error) { + j.reset() +} + +func (j *journal) OnEnter(depth int, typ byte, from common.Address, to common.Address, input []byte, gas uint64, value *big.Int) { + j.curRevisionId = j.snapshot() + if j.hooks != nil && j.hooks.OnEnter != nil { + j.hooks.OnEnter(depth, typ, from, to, input, gas, value) + } +} + +func (j *journal) OnExit(depth int, output []byte, gasUsed uint64, err error, reverted bool) { + if reverted { + j.revertToSnapshot(j.curRevisionId, j.hooks) + } + j.curRevisionId-- + if j.hooks != nil && j.hooks.OnExit != nil { + j.hooks.OnExit(depth, output, gasUsed, err, reverted) + } +} + +func (j *journal) OnBalanceChange(addr common.Address, prev, new *big.Int, reason BalanceChangeReason) { + j.entries = append(j.entries, balanceChange{addr: addr, prev: prev, new: new}) + if j.hooks != nil && j.hooks.OnBalanceChange != nil { + j.hooks.OnBalanceChange(addr, prev, new, reason) + } +} + +func (j *journal) OnNonceChange(addr common.Address, prev, new uint64) { + j.entries = append(j.entries, nonceChange{addr: addr, prev: prev, new: new}) + if j.hooks != nil && j.hooks.OnNonceChange != nil { + j.hooks.OnNonceChange(addr, prev, new) + } +} + +func (j *journal) OnCodeChange(addr common.Address, prevCodeHash common.Hash, prevCode []byte, codeHash common.Hash, code []byte) { + j.entries = append(j.entries, codeChange{ + addr: addr, + prevCodeHash: prevCodeHash, + prevCode: prevCode, + newCodeHash: codeHash, + newCode: code, + }) + if j.hooks != nil && j.hooks.OnCodeChange != nil { + j.hooks.OnCodeChange(addr, codeHash, code, prevCodeHash, prevCode) + } +} + +func (j *journal) OnStorageChange(addr common.Address, slot common.Hash, prev, new common.Hash) { + j.entries = append(j.entries, storageChange{addr: addr, slot: slot, prev: prev, new: new}) + if j.hooks != nil && j.hooks.OnStorageChange != nil { + j.hooks.OnStorageChange(addr, slot, new, prev) + } +} + +type ( + balanceChange struct { + addr common.Address + prev *big.Int + new *big.Int + } + + nonceChange struct { + addr common.Address + prev uint64 + new uint64 + } + + codeChange struct { + addr common.Address + prevCodeHash common.Hash + prevCode []byte + newCodeHash common.Hash + newCode []byte + } + + storageChange struct { + addr common.Address + slot common.Hash + prev common.Hash + new common.Hash + } +) + +func (b balanceChange) revert(hooks *Hooks) { + if hooks.OnBalanceChange != nil { + hooks.OnBalanceChange(b.addr, b.new, b.prev, BalanceChangeRevert) + } +} + +func (n nonceChange) revert(hooks *Hooks) { + if hooks.OnNonceChange != nil { + hooks.OnNonceChange(n.addr, n.new, n.prev) + } +} + +func (c codeChange) revert(hooks *Hooks) { + if hooks.OnCodeChange != nil { + hooks.OnCodeChange(c.addr, c.newCodeHash, c.newCode, c.prevCodeHash, c.prevCode) + } +} + +func (s storageChange) revert(hooks *Hooks) { + if hooks.OnStorageChange != nil { + hooks.OnStorageChange(s.addr, s.slot, s.new, s.prev) + } +} diff --git a/core/tracing/journal_test.go b/core/tracing/journal_test.go new file mode 100644 index 000000000000..f9b613dbc6d9 --- /dev/null +++ b/core/tracing/journal_test.go @@ -0,0 +1,68 @@ +package tracing + +import ( + "errors" + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +type testTracer struct { + bal *big.Int + nonce uint64 +} + +func (t *testTracer) OnBalanceChange(addr common.Address, prev *big.Int, new *big.Int, reason BalanceChangeReason) { + t.bal = new +} + +func (t *testTracer) OnNonceChange(addr common.Address, prev uint64, new uint64) { + t.nonce = new +} + +func TestJournalIntegration(t *testing.T) { + tr := &testTracer{} + wr, err := WrapWithJournal(&Hooks{OnBalanceChange: tr.OnBalanceChange, OnNonceChange: tr.OnNonceChange}) + if err != nil { + t.Fatalf("failed to wrap test tracer: %v", err) + } + addr := common.HexToAddress("0x1234") + wr.OnEnter(0, 0, addr, addr, nil, 1000, big.NewInt(0)) + wr.OnBalanceChange(addr, nil, big.NewInt(100), BalanceChangeUnspecified) + wr.OnEnter(1, 0, addr, addr, nil, 1000, big.NewInt(0)) + wr.OnNonceChange(addr, 0, 1) + wr.OnBalanceChange(addr, big.NewInt(100), big.NewInt(200), BalanceChangeUnspecified) + wr.OnBalanceChange(addr, big.NewInt(200), big.NewInt(250), BalanceChangeUnspecified) + wr.OnExit(0, nil, 100, errors.New("revert"), true) + wr.OnExit(0, nil, 150, nil, false) + if tr.bal.Cmp(big.NewInt(100)) != 0 { + t.Fatalf("unexpected balance: %v", tr.bal) + } + if tr.nonce != 0 { + t.Fatalf("unexpected nonce: %v", tr.nonce) + } +} + +func TestJournalTopRevert(t *testing.T) { + tr := &testTracer{} + wr, err := WrapWithJournal(&Hooks{OnBalanceChange: tr.OnBalanceChange, OnNonceChange: tr.OnNonceChange}) + if err != nil { + t.Fatalf("failed to wrap test tracer: %v", err) + } + addr := common.HexToAddress("0x1234") + wr.OnEnter(0, 0, addr, addr, nil, 1000, big.NewInt(0)) + wr.OnBalanceChange(addr, big.NewInt(0), big.NewInt(100), BalanceChangeUnspecified) + wr.OnEnter(1, 0, addr, addr, nil, 1000, big.NewInt(0)) + wr.OnNonceChange(addr, 0, 1) + wr.OnBalanceChange(addr, big.NewInt(100), big.NewInt(200), BalanceChangeUnspecified) + wr.OnBalanceChange(addr, big.NewInt(200), big.NewInt(250), BalanceChangeUnspecified) + wr.OnExit(0, nil, 100, errors.New("revert"), true) + wr.OnExit(0, nil, 150, errors.New("revert"), true) + if tr.bal.Cmp(big.NewInt(0)) != 0 { + t.Fatalf("unexpected balance: %v", tr.bal) + } + if tr.nonce != 0 { + t.Fatalf("unexpected nonce: %v", tr.nonce) + } +}