diff --git a/suave/builder/builder.go b/suave/builder/builder.go index 0d5c30ce9..40f11b32d 100644 --- a/suave/builder/builder.go +++ b/suave/builder/builder.go @@ -1,6 +1,8 @@ package builder import ( + "context" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/state" @@ -10,12 +12,14 @@ import ( ) type builder struct { - config *builderConfig - txns []*types.Transaction - receipts []*types.Receipt - state *state.StateDB - gasPool *core.GasPool - gasUsed *uint64 + config *builderConfig + txns []*types.Transaction + receipts []*types.Receipt + state *state.StateDB + gasPool *core.GasPool + gasUsed *uint64 + ctx context.Context + cancelFunc context.CancelFunc } type builderConfig struct { @@ -28,12 +32,15 @@ type builderConfig struct { func newBuilder(config *builderConfig) *builder { gp := core.GasPool(config.header.GasLimit) var gasUsed uint64 + ctx, cancel := context.WithCancel(context.Background()) return &builder{ - config: config, - state: config.preState.Copy(), - gasPool: &gp, - gasUsed: &gasUsed, + config: config, + state: config.preState.Copy(), + gasPool: &gp, + gasUsed: &gasUsed, + ctx: ctx, + cancelFunc: cancel, } } @@ -75,3 +82,7 @@ func (b *builder) AddTransaction(txn *types.Transaction) (*types.SimulateTransac return result, nil } + +func (b *builder) Terminate() { + b.cancelFunc() +} diff --git a/suave/builder/session_manager.go b/suave/builder/session_manager.go index 91bd356f7..669481a78 100644 --- a/suave/builder/session_manager.go +++ b/suave/builder/session_manager.go @@ -12,6 +12,7 @@ import ( "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/params" "github.com/google/uuid" ) @@ -29,6 +30,9 @@ type blockchain interface { // Config returns the chain config Config() *params.ChainConfig + + // SubscribeChainHeadEvent to subscribe to ChainHeadEvent + SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription } type Config struct { @@ -44,6 +48,11 @@ type SessionManager struct { sessionsLock sync.RWMutex blockchain blockchain config *Config + subscription event.Subscription + chainHeadChan chan core.ChainHeadEvent + exitCh chan struct{} + closed bool + closeMu sync.RWMutex } func NewSessionManager(blockchain blockchain, config *Config) *SessionManager { @@ -68,12 +77,25 @@ func NewSessionManager(blockchain blockchain, config *Config) *SessionManager { sessionTimers: make(map[string]*time.Timer), blockchain: blockchain, config: config, + exitCh: make(chan struct{}), } + + s.chainHeadChan = make(chan core.ChainHeadEvent, 100) + s.subscription = s.blockchain.SubscribeChainHeadEvent(s.chainHeadChan) + go s.listenForChainHeadEvents() + return s } // NewSession creates a new builder session and returns the session id func (s *SessionManager) NewSession(ctx context.Context) (string, error) { + s.closeMu.RLock() + if s.closed { + s.closeMu.RUnlock() + return "", fmt.Errorf("session manager is closed") + } + s.closeMu.RUnlock() + // Wait for session to become available select { case <-s.sem: @@ -161,3 +183,58 @@ func (s *SessionManager) AddTransaction(sessionId string, tx *types.Transaction) } return builder.AddTransaction(tx) } + +func (s *SessionManager) listenForChainHeadEvents() { + for { + select { + case _, ok := <-s.chainHeadChan: + if !ok { + return + } + s.terminateAllSessions() + case <-s.exitCh: + return + } + } +} + +func (s *SessionManager) terminateAllSessions() error { + s.sessionsLock.Lock() + defer s.sessionsLock.Unlock() + + for id, session := range s.sessions { + session.Terminate() + + delete(s.sessions, id) + + if timer, exists := s.sessionTimers[id]; exists { + timer.Stop() + delete(s.sessionTimers, id) + } + + select { + case s.sem <- struct{}{}: + default: + return fmt.Errorf("released more sessions than are open") + } + } + return nil +} + +func (s *SessionManager) Close() { + s.closeMu.Lock() + defer s.closeMu.Unlock() + + if s.closed { + return + } + + close(s.exitCh) + + if s.subscription != nil { + s.subscription.Unsubscribe() + } + + s.terminateAllSessions() + s.closed = true +} diff --git a/suave/builder/session_manager_test.go b/suave/builder/session_manager_test.go index 015699124..8094190b3 100644 --- a/suave/builder/session_manager_test.go +++ b/suave/builder/session_manager_test.go @@ -4,19 +4,30 @@ import ( "context" "crypto/ecdsa" "math/big" + "sync" "testing" "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus" + "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/trie" "github.com/stretchr/testify/require" ) +type MockSubscription struct{} + +func (m *MockSubscription) Unsubscribe() {} +func (m *MockSubscription) Err() <-chan error { + return nil +} + func TestSessionManager_SessionTimeout(t *testing.T) { mngr, _ := newSessionManager(t, &Config{ SessionIdleTimeout: 500 * time.Millisecond, @@ -105,6 +116,95 @@ func TestSessionManager_StartSession(t *testing.T) { require.NotNil(t, receipt) } +func TestSessionManager_TerminateAllSessionsOnNewBlock(t *testing.T) { + mngr, bMock := newSessionManager(t, &Config{}) + + sessionIDs := make([]string, 3) + for i := 0; i < 3; i++ { + id, err := mngr.NewSession(context.TODO()) + require.NoError(t, err) + sessionIDs[i] = id + } + + require.Len(t, mngr.sessions, 3) + + bMock.triggerNewBlock() + + time.Sleep(100 * time.Millisecond) + + require.Empty(t, mngr.sessions) + + for _, id := range sessionIDs { + _, err := mngr.getSession(id) + require.Error(t, err) + } +} + +func TestSessionManager_Close(t *testing.T) { + mngr, _ := newSessionManager(t, &Config{}) + + id, err := mngr.NewSession(context.TODO()) + require.NoError(t, err) + + mngr.Close() + + require.Empty(t, mngr.sessions) + + _, err = mngr.getSession(id) + require.Error(t, err) + + _, err = mngr.NewSession(context.TODO()) + require.Error(t, err) + require.Contains(t, err.Error(), "session manager is closed") + + require.NotPanics(t, func() { mngr.Close() }) +} + +func TestSessionManager_ConcurrentAccess(t *testing.T) { + mngr, _ := newSessionManager(t, &Config{MaxConcurrentSessions: 10}) + + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + id, err := mngr.NewSession(context.TODO()) + if err == nil { + time.Sleep(10 * time.Millisecond) + _, err := mngr.getSession(id) + require.NoError(t, err) + } + }() + } + wg.Wait() + + require.LessOrEqual(t, len(mngr.sessions), 10) +} + +func TestSessionManager_TerminateOngoingTransactions(t *testing.T) { + mngr, bMock := newSessionManager(t, &Config{}) + + id, err := mngr.NewSession(context.TODO()) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + time.Sleep(500 * time.Millisecond) + _, err := mngr.AddTransaction(id, bMock.state.newTransfer(t, common.Address{}, big.NewInt(1))) + require.Error(t, err) + }() + + time.Sleep(100 * time.Millisecond) + + bMock.triggerNewBlock() + + <-done + + _, err = mngr.getSession(id) + require.Error(t, err) +} + func newSessionManager(t *testing.T, cfg *Config) (*SessionManager, *blockchainMock) { if cfg == nil { cfg = &Config{} @@ -113,13 +213,39 @@ func newSessionManager(t *testing.T, cfg *Config) (*SessionManager, *blockchainM state := newMockState(t) bMock := &blockchainMock{ - state: state, + state: state, + chainHeadChan: make(chan core.ChainHeadEvent, 10), + blockNumber: 1, } return NewSessionManager(bMock, cfg), bMock } type blockchainMock struct { - state *mockState + state *mockState + chainHeadChan chan core.ChainHeadEvent + blockNumber uint64 +} + +func (b *blockchainMock) triggerNewBlock() { + b.chainHeadChan <- core.ChainHeadEvent{Block: types.NewBlock(&types.Header{Number: big.NewInt(int64(b.blockNumber))}, nil, nil, nil, trie.NewStackTrie(nil))} + b.blockNumber++ +} + +func (b *blockchainMock) SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription { + return event.NewSubscription(func(quit <-chan struct{}) error { + for { + select { + case ev := <-b.chainHeadChan: + select { + case ch <- ev: + case <-quit: + return nil + } + case <-quit: + return nil + } + } + }) } func (b *blockchainMock) Engine() consensus.Engine {