diff --git a/go.mod b/go.mod index 16acc2a..e191db3 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sync v0.8.0 // indirect golang.org/x/text v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index e3dd1f5..b84eb4e 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= diff --git a/internal/state/completion.go b/internal/state/completion.go index a04ada6..b8bdd65 100644 --- a/internal/state/completion.go +++ b/internal/state/completion.go @@ -1,6 +1,7 @@ package state import ( + "context" "log/slog" "github.com/restatedev/sdk-go/internal/log" @@ -30,6 +31,11 @@ func (m *Machine) ackable(entryIndex uint32) wire.AckableMessage { } func (m *Machine) Write(message wire.Message) { + if m.ctx.Err() != nil { + // the main context being cancelled means the client is no longer interested in our response + // and so creating new entries is pointless and we should shut down the state machine. 
+ panic(m.newClientGoneAway(context.Cause(m.ctx))) + } if message, ok := message.(wire.CompleteableMessage); ok && !message.Completed() { m.pendingMutex.Lock() m.pendingCompletions[m.entryIndex] = message diff --git a/internal/state/state.go b/internal/state/state.go index 706fedc..7306f89 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -267,7 +267,7 @@ type Machine struct { suspend func(error) handler restate.Handler - protocol *wire.Protocol + protocol wire.Protocol // state key string @@ -432,10 +432,17 @@ The journal entry at position %d was: m.log.LogAttrs(m.ctx, slog.LevelError, "Error sending failure message", log.Error(typ.err)) } + return + case *clientGoneAway: + m.log.LogAttrs(m.ctx, slog.LevelWarn, "Cancelling invocation as the incoming request context was cancelled", log.Error(typ.err)) return case *wire.SuspensionPanic: if m.ctx.Err() != nil { - m.log.WarnContext(m.ctx, "Cancelling invocation as the incoming request was cancelled") + // special case; awaiting a pre-existing sleep or awakeable doesn't create a new entry + // so doesn't hit the clientGoneAway code path, but instead it just looks like a suspension. + // so here we should differentiate between the causes; if the main context is cancelled, + // this isn't a suspension. + m.log.LogAttrs(m.ctx, slog.LevelWarn, "Cancelling invocation as the incoming request context was cancelled", log.Error(typ.Err)) return } if stderrors.Is(typ.Err, io.EOF) { @@ -637,6 +644,12 @@ func replayOrNew[M wire.Message, O any]( } defer m.entryMutex.Unlock() + if m.ctx.Err() != nil { + // the main context being cancelled means the client is no longer interested in our response + // and so creating new entries is pointless and we should shut down the state machine. 
+ panic(m.newClientGoneAway(context.Cause(m.ctx))) + } + if m.failure != nil { // maybe the user will try to catch our panics, but we will just keep producing them panic(m.failure) @@ -676,3 +689,13 @@ func (m *Machine) newConcurrentContextUse(entry wire.Type) *concurrentContextUse m.failure = c return c } + +type clientGoneAway struct { + err error +} + +func (m *Machine) newClientGoneAway(err error) *clientGoneAway { + c := &clientGoneAway{err} + m.failure = c + return c +} diff --git a/internal/state/state_test.go b/internal/state/state_test.go new file mode 100644 index 0000000..b5536a2 --- /dev/null +++ b/internal/state/state_test.go @@ -0,0 +1,357 @@ +package state + +import ( + "context" + "fmt" + "io" + "log/slog" + "testing" + "time" + + restate "github.com/restatedev/sdk-go" + protocol "github.com/restatedev/sdk-go/generated/dev/restate/service" + "github.com/restatedev/sdk-go/internal/errors" + "github.com/restatedev/sdk-go/internal/wire" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +type testParams struct { + input chan<- wire.Message + output <-chan wire.Message + wait func() error + cancel func() +} + +var clientDisconnectError = fmt.Errorf("client disconnected") + +func testHandler(handler restate.Handler) testParams { + machine := NewMachine(handler, nil, nil) + inputC := make(chan wire.Message) + outputC := make(chan wire.Message) + ctx, cancel := context.WithCancelCause(context.Background()) + machine.protocol = mockProtocol{input: inputC, output: outputC} + + eg := errgroup.Group{} + eg.Go(func() error { + return machine.Start(ctx, false, slog.Default().Handler()) + }) + + return testParams{inputC, outputC, eg.Wait, func() { cancel(clientDisconnectError) }} +} + +// closed request body should lead to suspension the next time we need a completion or ack +func TestRequestClosed(t *testing.T) { + var ctxErr error + var seenPanic any + var tp testParams + tp = testHandler(restate.NewServiceHandler(func(ctx 
restate.Context, _ restate.Void) (restate.Void, error) { + close(tp.input) + + // writing out journal entries still works - this shouldn't panic + after := ctx.After(time.Minute) + + ctxErr = ctx.Err() + + func() { + defer func() { + seenPanic = recover() + if seenPanic != nil { + panic(seenPanic) + } + }() + + // this should panic as it needs a completion + after.Done() + }() + + return restate.Void{}, nil + })) + + tp.input <- &wire.StartMessage{ + StartMessage: protocol.StartMessage{ + Id: []byte("abc"), + DebugId: "abc", + KnownEntries: 1, + }, + } + tp.input <- &wire.InputEntryMessage{InputEntryMessage: protocol.InputEntryMessage{}} + + _ = <-tp.output // sleep + _ = <-tp.output // suspension + + require.NoError(t, tp.wait()) + require.NoError(t, ctxErr, "invocation context was cancelled") + require.IsType(t, &wire.SuspensionPanic{}, seenPanic, "awaiting the sleep didn't create suspension panic") +} + +// closed http2 context (ie, client went away) should cancel the context provided and will lead to a panic on the +// next operation (write or await on previous entry) +func TestResponseClosed(t *testing.T) { + type test struct { + name string + beforeCancel func(ctx restate.Context) any + producedEntries int + afterCancel func(ctx restate.Context, setupState any) + expectedPanic any + } + + tests := []test{ + { + name: "awakeable should lead to client gone away panic", + afterCancel: func(ctx restate.Context, _ any) { + ctx.Awakeable() + }, + expectedPanic: &clientGoneAway{}, + }, + { + name: "starting run should lead to client gone away panic", + afterCancel: func(ctx restate.Context, _ any) { + ctx.Run(func(ctx restate.RunContext) (any, error) { + panic("run should not be executed") + }, restate.Void{}) + }, + expectedPanic: &clientGoneAway{}, + }, + { + name: "awaiting sleep should lead to suspension panic", + beforeCancel: func(ctx restate.Context) any { + return ctx.After(time.Minute) + }, + afterCancel: func(ctx restate.Context, setupState any) { +
setupState.(restate.After).Done() + }, + producedEntries: 1, + expectedPanic: &wire.SuspensionPanic{}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var tp testParams + var ctxErr error + var seenPanic any + var state any + tp = testHandler(restate.NewServiceHandler(func(ctx restate.Context, _ restate.Void) (restate.Void, error) { + if test.beforeCancel != nil { + state = test.beforeCancel(ctx) + } + + tp.cancel() + + ctxErr = ctx.Err() + + func() { + defer func() { + seenPanic = recover() + if seenPanic != nil { + panic(seenPanic) + } + }() + + test.afterCancel(ctx, state) + }() + + return restate.Void{}, nil + })) + tp.input <- &wire.StartMessage{ + StartMessage: protocol.StartMessage{ + Id: []byte("abc"), + DebugId: "abc", + KnownEntries: 1, + }, + } + tp.input <- &wire.InputEntryMessage{InputEntryMessage: protocol.InputEntryMessage{}} + for i := 0; i < test.producedEntries; i++ { + <-tp.output + } + + require.NoError(t, tp.wait()) + require.Equal(t, context.Canceled, ctxErr, "invocation context wasnt cancelled") + require.IsType(t, test.expectedPanic, seenPanic, "unexpected panic") + }) + } +} + +// disconnect mid-run should cancel the run context and panic with a write error +func TestInFlightRunDisconnect(t *testing.T) { + var beforeCancelErr, afterCancelErr error + var seenPanic any + var tp testParams + tp = testHandler(restate.NewServiceHandler(func(ctx restate.Context, _ restate.Void) (restate.Void, error) { + func() { + defer func() { + seenPanic = recover() + if seenPanic != nil { + panic(seenPanic) + } + }() + + _ = ctx.Run(func(ctx restate.RunContext) (any, error) { + beforeCancelErr = ctx.Err() + tp.cancel() + afterCancelErr = ctx.Err() + + return nil, nil + }, restate.Void{}) + }() + + return restate.Void{}, nil + })) + + tp.input <- &wire.StartMessage{ + StartMessage: protocol.StartMessage{ + Id: []byte("abc"), + DebugId: "abc", + KnownEntries: 1, + }, + } + tp.input <- 
&wire.InputEntryMessage{InputEntryMessage: protocol.InputEntryMessage{}} + + require.NoError(t, tp.wait()) + require.Nil(t, beforeCancelErr, "run context should not be cancelled early") + require.Equal(t, context.Canceled, afterCancelErr, "run context should be cancelled") + require.IsType(t, &clientGoneAway{}, seenPanic, "after the run should lead to a client gone away panic") +} + +// suspension mid-run should commit the run result to the runtime, but then panic with suspension when +// trying to get the ack. +func TestInFlightRunSuspension(t *testing.T) { + var beforeCancelErr, afterCancelErr error + var seenPanic any + var tp testParams + tp = testHandler(restate.NewServiceHandler(func(ctx restate.Context, _ restate.Void) (restate.Void, error) { + func() { + defer func() { + seenPanic = recover() + if seenPanic != nil { + panic(seenPanic) + } + }() + + _ = ctx.Run(func(ctx restate.RunContext) (any, error) { + beforeCancelErr = ctx.Err() + close(tp.input) + afterCancelErr = ctx.Err() + + return nil, nil + }, restate.Void{}) + }() + + return restate.Void{}, nil + })) + + tp.input <- &wire.StartMessage{ + StartMessage: protocol.StartMessage{ + Id: []byte("abc"), + DebugId: "abc", + KnownEntries: 1, + }, + } + tp.input <- &wire.InputEntryMessage{InputEntryMessage: protocol.InputEntryMessage{}} + + <-tp.output // run + <-tp.output // output + + require.NoError(t, tp.wait()) + require.Nil(t, beforeCancelErr, "run context should not be cancelled before request closed") + require.Nil(t, afterCancelErr, "run context should not be cancelled after request closed") + require.IsType(t, &wire.SuspensionPanic{}, seenPanic, "after the run should lead to a suspension panic") +} + +func TestInvocationCanceled(t *testing.T) { + type test struct { + name string + fn func(ctx restate.Context) error + } + + tests := []test{ + { + name: "awakeable should return canceled error", + fn: func(ctx restate.Context) error { + awakeable := ctx.Awakeable() + return 
awakeable.Result(restate.Void{}) + }, + }, + { + name: "sleep should return canceled error", + fn: func(ctx restate.Context) error { + after := ctx.After(time.Minute) + return after.Done() + }, + }, + { + name: "call should return cancelled error", + fn: func(ctx restate.Context) error { + fut, err := ctx.Service("foo", "bar").RequestFuture(restate.Void{}) + if err != nil { + return err + } + return fut.Response(restate.Void{}) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var seenErr error + tp := testHandler(restate.NewServiceHandler(func(ctx restate.Context, _ restate.Void) (restate.Void, error) { + seenErr = test.fn(ctx) + return restate.Void{}, seenErr + })) + tp.input <- &wire.StartMessage{ + StartMessage: protocol.StartMessage{ + Id: []byte("abc"), + DebugId: "abc", + KnownEntries: 1, + }, + } + tp.input <- &wire.InputEntryMessage{InputEntryMessage: protocol.InputEntryMessage{}} + entry := <-tp.output // awakeable, sleep, or call entry + require.Implements(t, (*wire.CompleteableMessage)(nil), entry) + + // complete it with a cancellation + tp.input <- &wire.CompletionMessage{CompletionMessage: protocol.CompletionMessage{ + EntryIndex: 1, + Result: &protocol.CompletionMessage_Failure{ + Failure: &protocol.Failure{ + Code: 409, + Message: "canceled", + }, + }, + }} + + <-tp.output // output + <-tp.output // end + + require.NoError(t, tp.wait()) + require.Equal(t, &errors.CodeError{ + Code: 409, + Inner: &errors.TerminalError{ + Inner: fmt.Errorf("canceled"), + }, + }, seenErr) + }) + } +} + +type mockProtocol struct { + input <-chan wire.Message + output chan<- wire.Message +} + +var _ wire.Protocol = mockProtocol{} + +func (m mockProtocol) Read() (wire.Message, wire.Type, error) { + msg, ok := <-m.input + if !ok { + return nil, 0, io.EOF + } + + return msg, wire.MessageType(msg), nil +} + +func (m mockProtocol) Write(_ wire.Type, message wire.Message) error { + m.output <- message + return nil +} diff --git 
a/internal/wire/wire.go b/internal/wire/wire.go index b8c6ae3..0a7c5b8 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -104,6 +104,8 @@ func MessageType(message Message) Type { switch message.(type) { case *StartMessage: return StartMessageType + case *CompletionMessage: + return CompletionMessageType case *SuspensionMessage: return SuspensionMessageType case *InputEntryMessage: @@ -168,25 +170,32 @@ func (r *Reader) Next() <-chan ReaderMessage { return r.ch } -// Protocol implements the wire protocol to abstract receiving +type Protocol interface { + Read() (Message, Type, error) + Write(typ Type, message Message) error +} + +// protoc implements the wire protocol to abstract receiving // and sending messages // Note that Protocol is not concurrent safe and it's up to the user // to make sure it's used correctly -type Protocol struct { +type protoc struct { stream io.ReadWriter } -func NewProtocol(stream io.ReadWriter) *Protocol { - return &Protocol{stream} +var _ Protocol = (*protoc)(nil) + +func NewProtocol(stream io.ReadWriter) *protoc { + return &protoc{stream} } // ReadHeader from stream -func (s *Protocol) header() (header Header, err error) { +func (s *protoc) header() (header Header, err error) { err = binary.Read(s.stream, binary.BigEndian, &header) return } -func (s *Protocol) Read() (Message, Type, error) { +func (s *protoc) Read() (Message, Type, error) { header, err := s.header() if err != nil { return nil, 0, fmt.Errorf("failed to read message header: %w", err) @@ -211,7 +220,7 @@ func (s *Protocol) Read() (Message, Type, error) { return msg, header.TypeCode, nil } -func (s *Protocol) Write(typ Type, message Message) error { +func (s *protoc) Write(typ Type, message Message) error { var flag Flag if message, ok := message.(CompleteableMessage); ok && message.Completed() {