From 2e4c0c634ac04b94c2ba73875fb09f830a428b55 Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Fri, 12 Jul 2024 17:50:30 +0200 Subject: [PATCH 1/2] Move logs to slog While less performant, its in the stdlib which is probably more suitable for an SDK. Users can register a slog handler to intercept logs --- example/checkout.go | 12 ++-- example/main.go | 8 +-- example/ticket_service.go | 5 +- example/user_session.go | 5 +- go.mod | 4 -- go.sum | 16 ----- internal/log/log.go | 82 +++++++++++++++++++++ internal/state/call.go | 4 -- internal/state/completion.go | 31 ++++---- internal/state/state.go | 133 ++++++++++++++++++++--------------- internal/state/sys.go | 14 +++- internal/wire/wire.go | 58 +++++++-------- rcontext/rcontext.go | 28 ++++++++ router.go | 19 +++-- server/restate.go | 57 +++++++++++---- 15 files changed, 311 insertions(+), 165 deletions(-) create mode 100644 internal/log/log.go create mode 100644 rcontext/rcontext.go diff --git a/example/checkout.go b/example/checkout.go index 4ed1cde..2a22c4e 100644 --- a/example/checkout.go +++ b/example/checkout.go @@ -1,13 +1,11 @@ package main import ( - "context" "fmt" "math/rand" "github.com/google/uuid" restate "github.com/restatedev/sdk-go" - "github.com/rs/zerolog/log" ) type PaymentRequest struct { @@ -29,7 +27,7 @@ func (c *checkout) Name() string { const CheckoutServiceName = "Checkout" func (c *checkout) Payment(ctx restate.Context, request PaymentRequest) (response PaymentResponse, err error) { - uuid, err := restate.RunAs(ctx, func(ctx context.Context) (string, error) { + uuid, err := restate.RunAs(ctx, func(ctx restate.RunContext) (string, error) { uuid := uuid.New() return uuid.String(), nil }) @@ -45,13 +43,13 @@ func (c *checkout) Payment(ctx restate.Context, request PaymentRequest) (respons price := len(request.Tickets) * 30 response.Price = price - _, err = restate.RunAs(ctx, func(ctx context.Context) (bool, error) { - log := log.With().Str("uuid", uuid).Int("price", price).Logger() + _, err = restate.RunAs(ctx, func(ctx restate.RunContext) (bool, error) { + log := ctx.Log().With("uuid", uuid, "price", price) if rand.Float64() < 0.5 { - log.Info().Msg("payment succeeded") + log.Info("payment succeeded") return true, nil } else { - log.Error().Msg("payment failed") + log.Error("payment failed") return false, fmt.Errorf("failed to pay") } }) diff --git a/example/main.go b/example/main.go index c8a5e13..12574af 100644 --- a/example/main.go +++ b/example/main.go @@ -2,26 +2,22 @@ package main import ( "context" + "log/slog" "os" restate "github.com/restatedev/sdk-go" "github.com/restatedev/sdk-go/server" - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" ) func main() { - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) - zerolog.SetGlobalLevel(zerolog.InfoLevel) - server := server.NewRestate(). Bind(restate.Object(&userSession{})). Bind(restate.Object(&ticketService{})). Bind(restate.Service(&checkout{})) if err := server.Start(context.Background(), ":9080"); err != nil { - log.Error().Err(err).Msg("application exited unexpectedly") + slog.Error("application exited unexpectedly", "err", err.Error()) os.Exit(1) } } diff --git a/example/ticket_service.go b/example/ticket_service.go index 8280e88..aa98ae8 100644 --- a/example/ticket_service.go +++ b/example/ticket_service.go @@ -4,7 +4,6 @@ import ( "errors" restate "github.com/restatedev/sdk-go" - "github.com/rs/zerolog/log" ) type TicketStatus int @@ -36,7 +35,7 @@ func (t *ticketService) Reserve(ctx restate.ObjectContext, _ restate.Void) (bool func (t *ticketService) Unreserve(ctx restate.ObjectContext, _ restate.Void) (void restate.Void, err error) { ticketId := ctx.Key() - log.Info().Str("ticket", ticketId).Msg("un-reserving ticket") + ctx.Log().Info("un-reserving ticket", "ticket", ticketId) status, err := restate.GetAs[TicketStatus](ctx, "status") if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { return void, err @@ -52,7 +51,7 @@ func (t *ticketService) Unreserve(ctx restate.ObjectContext, _ restate.Void) (vo func (t *ticketService) MarkAsSold(ctx restate.ObjectContext, _ restate.Void) (void restate.Void, err error) { ticketId := ctx.Key() - log.Info().Str("ticket", ticketId).Msg("mark ticket as sold") + ctx.Log().Info("mark ticket as sold", "ticket", ticketId) status, err := restate.GetAs[TicketStatus](ctx, "status") if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { diff --git a/example/user_session.go b/example/user_session.go index d8288d9..54b2a86 100644 --- a/example/user_session.go +++ b/example/user_session.go @@ -6,7 +6,6 @@ import ( "time" restate "github.com/restatedev/sdk-go" - "github.com/rs/zerolog/log" ) const UserSessionServiceName = "UserSession" @@ -81,7 +80,7 @@ func (u *userSession) Checkout(ctx restate.ObjectContext, _ restate.Void) (bool, return false, err } - log.Info().Strs("tickets", tickets).Msg("tickets in basket") + ctx.Log().Info("tickets in basket", "tickets", tickets) if len(tickets) == 0 { return false, nil @@ -95,7 +94,7 @@ func (u *userSession) Checkout(ctx restate.ObjectContext, _ restate.Void) (bool, return false, err } - log.Info().Str("id", response.ID).Int("price", response.Price).Msg("payment details") + ctx.Log().Info("payment details", "id", response.ID, "price", response.Price) for _, ticket := range tickets { call := ctx.ObjectSend(TicketServiceName, ticket, 0).Method("MarkAsSold") diff --git a/go.mod b/go.mod index f9f4eb6..975a6a1 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.22.0 require ( github.com/google/uuid v1.6.0 github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0 - github.com/rs/zerolog v1.32.0 github.com/stretchr/testify v1.9.0 github.com/vmihailenco/msgpack/v5 v5.4.1 golang.org/x/net v0.21.0 @@ -14,11 +13,8 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect - golang.org/x/sys v0.17.0 // indirect golang.org/x/text v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 11597c7..d3a47ec 100644 --- a/go.sum +++ b/go.sum @@ -1,24 +1,13 @@ -github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0 h1:zZg03nifrj6ayWNaDO8tNj57tqrOIKDwiBaLkhtK7Kk= github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0/go.mod h1:bblJa8QcHntareAJYfLJUzLj42sUFBKCBeTDK5LyUrw= -github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= -github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= @@ -27,11 +16,6 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= diff --git a/internal/log/log.go b/internal/log/log.go new file mode 100644 index 0000000..538ff5d --- /dev/null +++ b/internal/log/log.go @@ -0,0 +1,82 @@ +package log + +import ( + "context" + "fmt" + "log/slog" + "reflect" + "sync/atomic" + + "github.com/restatedev/sdk-go/rcontext" +) + +const ( + LevelTrace slog.Level = -8 +) + +type typeValue struct{ inner any } + +func (t typeValue) LogValue() slog.Value { + return slog.StringValue(reflect.TypeOf(t.inner).String()) +} + +func Type(key string, value any) slog.Attr { + return slog.Any(key, typeValue{value}) +} + +type stringerValue[T fmt.Stringer] struct{ inner T } + +func (t stringerValue[T]) LogValue() slog.Value { + return slog.StringValue(t.inner.String()) +} + +func Stringer[T fmt.Stringer](key string, value T) slog.Attr { + return slog.Any(key, slog.AnyValue(stringerValue[T]{value})) +} + +func Error(err error) slog.Attr { + return slog.String("err", err.Error()) +} + +type contextInjectingHandler struct { + logContext *atomic.Pointer[rcontext.LogContext] + dropReplay bool + inner slog.Handler +} + +func NewUserContextHandler(logContext *atomic.Pointer[rcontext.LogContext], dropReplay bool, inner slog.Handler) slog.Handler { + return &contextInjectingHandler{logContext, dropReplay, inner} +} + +func NewRestateContextHandler(inner slog.Handler) slog.Handler { + logContext := atomic.Pointer[rcontext.LogContext]{} + logContext.Store(&rcontext.LogContext{Source: rcontext.LogSourceRestate, IsReplaying: false}) + return &contextInjectingHandler{&logContext, false, inner} +} + +func (d *contextInjectingHandler) Enabled(ctx context.Context, l slog.Level) bool { + lc := d.logContext.Load() + if d.dropReplay && lc.IsReplaying { + return false + } + return d.inner.Enabled(rcontext.WithLogContext(ctx, lc), l) +} + +func (d *contextInjectingHandler) Handle(ctx context.Context, record slog.Record) error { + return d.inner.Handle(rcontext.WithLogContext(ctx, d.logContext.Load()), record) +} + +func (d *contextInjectingHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &contextInjectingHandler{d.logContext, d.dropReplay, d.inner.WithAttrs(attrs)} +} + +func (d *contextInjectingHandler) WithGroup(name string) slog.Handler { + return &contextInjectingHandler{d.logContext, d.dropReplay, d.inner.WithGroup(name)} +} + +var _ slog.Handler = &contextInjectingHandler{} + +type dropReplayHandler struct { + isReplaying func() bool + inner slog.Handler +} diff --git a/internal/state/call.go b/internal/state/call.go index 4c987b6..a0e6281 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -91,8 +91,6 @@ func (m *Machine) doDynCall(service, key, method string, input any) (*wire.CallE } func (m *Machine) doCall(service, key, method string, params []byte) (*wire.CallEntryMessage, uint32) { - m.log.Debug().Str("service", service).Str("method", method).Str("key", key).Msg("executing sync call") - entry, entryIndex := replayOrNew( m, func(entry *wire.CallEntryMessage) *wire.CallEntryMessage { @@ -132,8 +130,6 @@ func (m *Machine) _doCall(service, key, method string, params []byte) *wire.Call } func (m *Machine) sendCall(service, key, method string, body any, delay time.Duration) error { - m.log.Debug().Str("service", service).Str("method", method).Str("key", key).Msg("executing async call") - params, err := json.Marshal(body) if err != nil { return err diff --git a/internal/state/completion.go b/internal/state/completion.go index 111e597..a04ada6 100644 --- a/internal/state/completion.go +++ b/internal/state/completion.go @@ -1,9 +1,9 @@ package state import ( - "errors" - "io" + "log/slog" + "github.com/restatedev/sdk-go/internal/log" "github.com/restatedev/sdk-go/internal/wire" ) @@ -40,7 +40,9 @@ func (m *Machine) Write(message wire.Message) { m.pendingAcks[m.entryIndex] = message m.pendingMutex.Unlock() } - if err := m.protocol.Write(message); err != nil { + typ := wire.MessageType(message) + m.log.LogAttrs(m.ctx, log.LevelTrace, "Sending message to runtime", log.Stringer("type", typ)) + if err := m.protocol.Write(typ, message); err != nil { panic(m.newWriteError(message, err)) } } @@ -59,33 +61,36 @@ func (m *Machine) newWriteError(entry wire.Message, err error) *writeError { func (m *Machine) handleCompletionsAcks() { for { - msg, err := m.protocol.Read() + msg, _, err := m.protocol.Read() if err != nil { - if errors.Is(err, io.EOF) { - m.log.Trace().Err(err).Msg("request body closed; next blocking operation will suspend") + if m.ctx.Err() == nil { + m.log.LogAttrs(m.ctx, log.LevelTrace, "Request body closed; next blocking operation will suspend") + m.suspend(err) } - m.suspend(err) return } switch msg := msg.(type) { case *wire.CompletionMessage: completable := m.completable(msg.EntryIndex) if completable == nil { - m.log.Error().Uint32("index", msg.EntryIndex).Msg("failed to find pending completion at index") + m.log.LogAttrs(m.ctx, slog.LevelError, "Failed to find pending completion at index", slog.Uint64("index", uint64(msg.EntryIndex))) continue } - completable.Complete(&msg.CompletionMessage) - m.log.Debug().Uint32("index", msg.EntryIndex).Msg("processed completion") + if err := completable.Complete(&msg.CompletionMessage); err != nil { + m.log.LogAttrs(m.ctx, slog.LevelError, "Failed to process completion", log.Error(err), slog.Uint64("index", uint64(msg.EntryIndex))) + } else { + m.log.LogAttrs(m.ctx, slog.LevelDebug, "Processed completion", slog.Uint64("index", uint64(msg.EntryIndex))) + } case *wire.EntryAckMessage: ackable := m.ackable(msg.EntryIndex) if ackable == nil { - m.log.Error().Uint32("index", msg.EntryIndex).Msg("failed to find pending ack at index") + m.log.LogAttrs(m.ctx, slog.LevelError, "Failed to find pending ack at index", slog.Uint64("index", uint64(msg.EntryIndex))) continue } ackable.Ack() - m.log.Debug().Uint32("index", msg.EntryIndex).Msg("processed ack") + m.log.LogAttrs(m.ctx, slog.LevelDebug, "Processed ack", slog.Uint64("index", uint64(msg.EntryIndex))) default: - m.log.Error().Type("type", msg).Msg("unexpected non-completion non-ack message during invocation") + m.log.LogAttrs(m.ctx, slog.LevelError, "Unexpected non-completion non-ack message during invocation", log.Type("type", msg)) continue } } diff --git a/internal/state/state.go b/internal/state/state.go index 4cf43c7..dfdd4cd 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -6,18 +6,19 @@ import ( stderrors "errors" "fmt" "io" + "log/slog" "runtime/debug" "sync" + "sync/atomic" "time" restate "github.com/restatedev/sdk-go" "github.com/restatedev/sdk-go/generated/proto/protocol" "github.com/restatedev/sdk-go/internal/errors" "github.com/restatedev/sdk-go/internal/futures" + "github.com/restatedev/sdk-go/internal/log" "github.com/restatedev/sdk-go/internal/wire" - - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" + "github.com/restatedev/sdk-go/rcontext" ) const ( @@ -34,12 +35,17 @@ var ( type Context struct { context.Context - machine *Machine + userLogger *slog.Logger + machine *Machine } var _ restate.ObjectContext = &Context{} var _ restate.Context = &Context{} +func (c *Context) Log() *slog.Logger { + return c.machine.userLog +} + func (c *Context) Set(key string, value []byte) { c.machine.set(key, value) } @@ -103,7 +109,7 @@ func (c *Context) ObjectSend(service, key string, delay time.Duration) restate.S } } -func (c *Context) Run(fn func(ctx context.Context) ([]byte, error)) ([]byte, error) { +func (c *Context) Run(fn func(ctx restate.RunContext) ([]byte, error)) ([]byte, error) { return c.machine.run(fn) } @@ -158,7 +164,9 @@ type Machine struct { entryIndex uint32 entryMutex sync.Mutex - log zerolog.Logger + log *slog.Logger + userLog *slog.Logger + userLogContext atomic.Pointer[rcontext.LogContext] pendingCompletions map[uint32]wire.CompleteableMessage pendingAcks map[uint32]wire.AckableMessage @@ -171,18 +179,18 @@ func NewMachine(handler restate.Handler, conn io.ReadWriter) *Machine { m := &Machine{ handler: handler, current: make(map[string][]byte), - log: log.Logger, pendingAcks: map[uint32]wire.AckableMessage{}, pendingCompletions: map[uint32]wire.CompleteableMessage{}, } - m.protocol = wire.NewProtocol(&m.log, conn) + m.protocol = wire.NewProtocol(conn) return m } +func (m *Machine) Log() *slog.Logger { return m.log } + // Start starts the state machine -func (m *Machine) Start(inner context.Context, trace string) error { - // reader starts a rea - msg, err := m.protocol.Read() +func (m *Machine) Start(inner context.Context, dropReplayLogs bool, logHandler slog.Handler) error { + msg, _, err := m.protocol.Read() if err != nil { return err } @@ -198,12 +206,12 @@ func (m *Machine) Start(inner context.Context, trace string) error { m.id = start.Id m.key = start.Key - m.log = m.log.With().Str("id", start.DebugId).Str("method", trace).Logger() + logHandler = logHandler.WithAttrs([]slog.Attr{slog.String("invocationID", start.DebugId)}) - ctx := newContext(inner, m) + m.log = slog.New(log.NewRestateContextHandler(logHandler)) + m.userLog = slog.New(log.NewUserContextHandler(&m.userLogContext, dropReplayLogs, logHandler)) - m.log.Debug().Msg("start invocation") - defer m.log.Debug().Msg("invocation ended") + ctx := newContext(inner, m) return m.process(ctx, start) } @@ -227,15 +235,14 @@ func (m *Machine) invoke(ctx *Context, input []byte, outputSeen bool) error { expected, _ := json.Marshal(typ.expectedEntry) actual, _ := json.Marshal(typ.actualEntry) - m.log.Error(). - Type("expectedType", typ.expectedEntry). - RawJSON("expectedMessage", expected). - Type("actualType", typ.actualEntry). - RawJSON("actualMessage", actual). - Msg("Journal mismatch: Replayed journal entries did not correspond to the user code. The user code has to be deterministic!") + m.log.LogAttrs(m.ctx, slog.LevelError, "Journal mismatch: Replayed journal entries did not correspond to the user code. The user code has to be deterministic!", + log.Type("expectedType", typ.expectedEntry), + slog.String("expectedMessage", string(expected)), + log.Type("actualType", typ.actualEntry), + slog.String("actualMessage", string(actual))) // journal entry mismatch - if err := m.protocol.Write(&wire.ErrorMessage{ + if err := m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ Code: uint32(errors.ErrJournalMismatch), Message: fmt.Sprintf(`Journal mismatch: Replayed journal entries did not correspond to the user code. The user code has to be deterministic! @@ -243,23 +250,21 @@ The journal entry at position %d was: - In the user code: type: %T, message: %s - In the replayed messages: type: %T, message %s`, typ.entryIndex, typ.expectedEntry, string(expected), typ.actualEntry, string(actual)), - Description: string(debug.Stack()), RelatedEntryIndex: &typ.entryIndex, RelatedEntryType: wire.MessageType(typ.actualEntry).UInt32(), }, }); err != nil { - m.log.Error().Err(err).Msg("error sending failure message") + m.log.LogAttrs(m.ctx, slog.LevelError, "Error sending failure message", log.Error(err)) } return case *writeError: - m.log.Error().Err(typ.err).Msg("Failed to write entry to Restate, shutting down state machine") + m.log.LogAttrs(m.ctx, slog.LevelError, "Failed to write entry to Restate, shutting down state machine", log.Error(typ.err)) // don't even check for failure here because most likely the http2 conn is closed anyhow - _ = m.protocol.Write(&wire.ErrorMessage{ + _ = m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ Code: uint32(errors.ErrProtocolViolation), Message: typ.err.Error(), - Description: string(debug.Stack()), RelatedEntryIndex: &typ.entryIndex, RelatedEntryType: wire.MessageType(typ.entry).UInt32(), }, @@ -267,18 +272,17 @@ The journal entry at position %d was: return case *runFailure: - m.log.Error().Err(typ.err).Msg("Run returned a failure, returning error to Restate") + m.log.LogAttrs(m.ctx, slog.LevelError, "Run returned a failure, returning error to Restate", log.Error(typ.err)) - if err := m.protocol.Write(&wire.ErrorMessage{ + if err := m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ Code: uint32(restate.ErrorCode(typ.err)), Message: typ.err.Error(), - Description: string(debug.Stack()), RelatedEntryIndex: &typ.entryIndex, RelatedEntryType: wire.AwakeableEntryMessageType.UInt32(), }, }); err != nil { - m.log.Error().Err(err).Msg("error sending failure message") + m.log.LogAttrs(m.ctx, slog.LevelError, "Error sending failure message", log.Error(typ.err)) } return @@ -288,48 +292,53 @@ The journal entry at position %d was: return } if stderrors.Is(typ.Err, io.EOF) { - m.log.Info().Uints32("entryIndexes", typ.EntryIndexes).Msg("Suspending") + m.log.LogAttrs(m.ctx, slog.LevelInfo, "Suspending invocation", slog.Any("entryIndexes", typ.EntryIndexes)) - if err := m.protocol.Write(&wire.SuspensionMessage{ + if err := m.protocol.Write(wire.SuspensionMessageType, &wire.SuspensionMessage{ SuspensionMessage: protocol.SuspensionMessage{ EntryIndexes: typ.EntryIndexes, }, }); err != nil { - m.log.Error().Err(err).Msg("error sending suspension message") + m.log.LogAttrs(m.ctx, slog.LevelError, "Error sending suspension message", log.Error(err)) } } else { - m.log.Error().Err(typ.Err).Uints32("entryIndexes", typ.EntryIndexes).Msg("Unexpected error reading completions; shutting down state machine") + m.log.LogAttrs(m.ctx, slog.LevelError, "Unexpected error reading completions; shutting down state machine", log.Error(typ.Err), slog.Any("entryIndexes", typ.EntryIndexes)) // don't check for error here, most likely we will fail to send if we are in such a bad state - _ = m.protocol.Write(&wire.ErrorMessage{ + _ = m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ - Code: uint32(restate.ErrorCode(typ.Err)), - Message: fmt.Sprintf("problem reading completions: %v", typ.Err), - Description: string(debug.Stack()), + Code: uint32(restate.ErrorCode(typ.Err)), + Message: fmt.Sprintf("problem reading completions: %v", typ.Err), }, }) } return default: + m.log.LogAttrs(m.ctx, slog.LevelError, "Invocation panicked, returning error to Restate", slog.Any("err", typ)) + // unknown panic! // send an error message (retryable) - if err := m.protocol.Write(&wire.ErrorMessage{ + if err := m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ Code: 500, Message: fmt.Sprint(typ), Description: string(debug.Stack()), }, }); err != nil { - m.log.Error().Err(err).Msg("error sending failure message") + m.log.LogAttrs(m.ctx, slog.LevelError, "Error sending failure message", log.Error(err)) } return } }() + m.log.InfoContext(m.ctx, "Handling invocation") + if outputSeen { - return m.protocol.Write(&wire.EndMessage{}) + m.log.WarnContext(m.ctx, "Invocation already completed; ending immediately") + + return m.protocol.Write(wire.EndMessageType, &wire.EndMessage{}) } var bytes []byte @@ -341,13 +350,11 @@ The journal entry at position %d was: bytes, err = handler.Call(ctx, input) } - if err != nil { - m.log.Error().Err(err).Msg("failure") - } - if err != nil && restate.IsTerminalError(err) { + m.log.LogAttrs(m.ctx, slog.LevelError, "Invocation returned a terminal failure", log.Error(err)) + // terminal errors. - if err := m.protocol.Write(&wire.OutputEntryMessage{ + if err := m.protocol.Write(wire.OutputEntryMessageType, &wire.OutputEntryMessage{ OutputEntryMessage: protocol.OutputEntryMessage{ Result: &protocol.OutputEntryMessage_Failure{ Failure: &protocol.Failure{ @@ -359,17 +366,21 @@ The journal entry at position %d was: }); err != nil { return err } - return m.protocol.Write(&wire.EndMessage{}) + return m.protocol.Write(wire.EndMessageType, &wire.EndMessage{}) } else if err != nil { + m.log.LogAttrs(m.ctx, slog.LevelError, "Invocation returned a non-terminal failure", log.Error(err)) + // non terminal error - no end message - return m.protocol.Write(&wire.ErrorMessage{ + return m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ Code: uint32(restate.ErrorCode(err)), Message: err.Error(), }, }) } else { - if err := m.protocol.Write(&wire.OutputEntryMessage{ + m.log.InfoContext(m.ctx, "Invocation completed successfully") + + if err := m.protocol.Write(wire.OutputEntryMessageType, &wire.OutputEntryMessage{ OutputEntryMessage: protocol.OutputEntryMessage{ Result: &protocol.OutputEntryMessage_Value{ Value: bytes, @@ -379,7 +390,7 @@ The journal entry at position %d was: return err } - return m.protocol.Write(&wire.EndMessage{}) + return m.protocol.Write(wire.EndMessageType, &wire.EndMessage{}) } } @@ -389,7 +400,7 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error { } // expect input message - msg, err := m.protocol.Read() + msg, _, err := m.protocol.Read() if err != nil { return err } @@ -398,19 +409,27 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error { return wire.ErrUnexpectedMessage } - m.log.Trace().Uint32("known entries", start.KnownEntries).Msg("known entires") + m.log.LogAttrs(m.ctx, log.LevelTrace, "Received input message", slog.Uint64("knownEntries", uint64(start.KnownEntries))) m.entries = make([]wire.Message, 0, start.KnownEntries-1) + if start.KnownEntries > 1 { + // more than just an input message; will be at least one replay + m.userLogContext.Store(&rcontext.LogContext{Source: rcontext.LogSourceUser, IsReplaying: true}) + } else { + // only an input message; no replayed messages + m.userLogContext.Store(&rcontext.LogContext{Source: rcontext.LogSourceUser, IsReplaying: false}) + } outputSeen := false // we don't track the poll input entry for i := uint32(1); i < start.KnownEntries; i++ { - msg, err := m.protocol.Read() + msg, typ, err := m.protocol.Read() if err != nil { return fmt.Errorf("failed to read entry: %w", err) } - m.log.Trace().Type("type", msg).Msg("replay log entry") + m.log.LogAttrs(m.ctx, log.LevelTrace, "Received replay journal entry from runtime", log.Stringer("type", typ), slog.Uint64("index", uint64(i))) + m.entries = append(m.entries, msg) if _, ok := msg.(*wire.OutputEntryMessage); ok { @@ -466,6 +485,10 @@ func replayOrNew[M wire.Message, O any]( } m.entryIndex += 1 + if m.entryIndex == uint32(len(m.entries)) { + // this is a replay, but the next entry will not be a replay; log should now be allowed + m.userLogContext.Store(&rcontext.LogContext{Source: rcontext.LogSourceUser, IsReplaying: false}) + } // check if there is an entry as this index entry, ok := m.currentEntry() diff --git a/internal/state/sys.go b/internal/state/sys.go index 45db016..f5f33a1 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "log/slog" "sort" "time" @@ -274,7 +275,7 @@ func (m *Machine) _sleep(d time.Duration) *wire.SleepEntryMessage { return msg } -func (m *Machine) run(fn func(context.Context) ([]byte, error)) ([]byte, error) { +func (m *Machine) run(fn func(restate.RunContext) ([]byte, error)) ([]byte, error) { entry, entryIndex := replayOrNew( m, func(entry *wire.RunEntryMessage) *wire.RunEntryMessage { @@ -301,8 +302,15 @@ func (m *Machine) run(fn func(context.Context) ([]byte, error)) ([]byte, error) return nil, restate.TerminalError(fmt.Errorf("run entry had invalid result: %v", entry.Result), errors.ErrProtocolViolation) } -func (m *Machine) _run(fn func(context.Context) ([]byte, error)) *wire.RunEntryMessage { - bytes, err := fn(m.ctx) +type runContext struct { + context.Context + log *slog.Logger +} + +func (r runContext) Log() *slog.Logger { return r.log } + +func (m *Machine) _run(fn func(restate.RunContext) ([]byte, error)) *wire.RunEntryMessage { + bytes, err := fn(runContext{m.ctx, m.userLog}) if err != nil { if restate.IsTerminalError(err) { diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 85967b8..e1d4bae 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -13,8 +13,6 @@ import ( _go "github.com/restatedev/sdk-go/generated/proto/go" protocol "github.com/restatedev/sdk-go/generated/proto/protocol" - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" "google.golang.org/protobuf/proto" ) @@ -175,12 +173,11 @@ func (r *Reader) Next() <-chan ReaderMessage { // Note that Protocol is not concurrent safe and it's up to the user // to make sure it's used correctly type Protocol struct { - log *zerolog.Logger stream io.ReadWriter } -func NewProtocol(log *zerolog.Logger, stream io.ReadWriter) *Protocol { - return &Protocol{log, stream} +func NewProtocol(stream io.ReadWriter) *Protocol { + return &Protocol{stream} } // ReadHeader from stream @@ -189,33 +186,32 @@ func (s *Protocol) header() (header Header, err error) { return } -func (s *Protocol) Read() (Message, error) { +func (s *Protocol) Read() (Message, Type, error) { header, err := s.header() if err != nil { - return nil, fmt.Errorf("failed to read message header: %w", err) + return nil, 0, fmt.Errorf("failed to read message header: %w", err) } buf := make([]byte, header.Length) if _, err := io.ReadFull(s.stream, buf); err != nil { - return nil, fmt.Errorf("failed to read message body: %w", err) + return nil, 0, fmt.Errorf("failed to read message body: %w", err) } builder, ok := builders[header.TypeCode] if !ok { - return nil, fmt.Errorf("unknown message type '%d'", header.TypeCode) + return nil, 0, fmt.Errorf("unknown message type '%d'", header.TypeCode) } msg, err := builder(header, buf) if err != nil { - return nil, err + return nil, 0, err } - s.log.Trace().Stringer("type", header.TypeCode).Interface("msg", msg).Msg("received message") - return msg, nil + return msg, header.TypeCode, nil } -func (s *Protocol) Write(message Message) error { +func (s *Protocol) Write(typ Type, message Message) error { var flag Flag if message, ok := message.(CompleteableMessage); ok && message.Completed() { @@ -225,10 +221,6 @@ func (s *Protocol) Write(message Message) error { flag |= FlagRequiresAck } - typ := MessageType(message) - - s.log.Trace().Stringer("type", typ).Interface("msg", message).Msg("sending message to runtime") - bytes, err := proto.Marshal(message) if err != nil { return fmt.Errorf("failed to serialize message: %w", err) @@ -432,7 +424,7 @@ type GetStateEntryMessage struct { var _ CompleteableMessage = (*GetStateEntryMessage)(nil) -func (a *GetStateEntryMessage) Complete(c *protocol.CompletionMessage) { +func (a *GetStateEntryMessage) Complete(c *protocol.CompletionMessage) error { switch result := c.Result.(type) { case *protocol.CompletionMessage_Value: a.Result = &protocol.GetStateEntryMessage_Value{Value: result.Value} @@ -443,6 +435,7 @@ func (a *GetStateEntryMessage) Complete(c *protocol.CompletionMessage) { } a.complete() + return nil } type SetStateEntryMessage struct { @@ -467,25 +460,24 @@ type GetStateKeysEntryMessage struct { var _ CompleteableMessage = (*GetStateKeysEntryMessage)(nil) -func (a *GetStateKeysEntryMessage) Complete(c *protocol.CompletionMessage) { +func (a *GetStateKeysEntryMessage) Complete(c *protocol.CompletionMessage) error { switch result := c.Result.(type) { case *protocol.CompletionMessage_Value: var keys protocol.GetStateKeysEntryMessage_StateKeys if err := proto.Unmarshal(result.Value, &keys); err != nil { - log.Error().Err(err).Msg("received invalid value for getstatekeys") - return + return fmt.Errorf("received invalid value for getstatekeys: %w", err) } a.Result = &protocol.GetStateKeysEntryMessage_Value{Value: &keys} case *protocol.CompletionMessage_Failure: a.Result = &protocol.GetStateKeysEntryMessage_Failure{Failure: result.Failure} case *protocol.CompletionMessage_Empty: - log.Error().Msg("received empty completion for getstatekeys") - return + return fmt.Errorf("received empty completion for getstatekeys") } a.complete() + return nil } type CompletionMessage struct { @@ -500,18 +492,18 @@ type SleepEntryMessage struct { var _ CompleteableMessage = (*SleepEntryMessage)(nil) -func (a *SleepEntryMessage) Complete(c *protocol.CompletionMessage) { +func (a *SleepEntryMessage) Complete(c *protocol.CompletionMessage) error { switch result := c.Result.(type) { case *protocol.CompletionMessage_Empty: a.Result = &protocol.SleepEntryMessage_Empty{Empty: result.Empty} case *protocol.CompletionMessage_Failure: a.Result = &protocol.SleepEntryMessage_Failure{Failure: result.Failure} case *protocol.CompletionMessage_Value: - log.Error().Msg("received value completion for sleep") - return + return fmt.Errorf("received value completion for sleep") } a.complete() + return nil } type CallEntryMessage struct { @@ -521,18 +513,18 @@ type CallEntryMessage struct { var _ CompleteableMessage = (*CallEntryMessage)(nil) -func (a *CallEntryMessage) Complete(c *protocol.CompletionMessage) { +func (a *CallEntryMessage) Complete(c *protocol.CompletionMessage) error { switch result := c.Result.(type) { case *protocol.CompletionMessage_Value: a.Result = &protocol.CallEntryMessage_Value{Value: result.Value} case *protocol.CompletionMessage_Failure: a.Result = &protocol.CallEntryMessage_Failure{Failure: result.Failure} case *protocol.CompletionMessage_Empty: - log.Error().Msg("received empty completion for call") - return + return fmt.Errorf("received empty completion for call") } a.complete() + return nil } type OneWayCallEntryMessage struct { @@ -547,18 +539,18 @@ type AwakeableEntryMessage struct { var _ CompleteableMessage = (*AwakeableEntryMessage)(nil) -func (a *AwakeableEntryMessage) Complete(c *protocol.CompletionMessage) { +func (a *AwakeableEntryMessage) Complete(c *protocol.CompletionMessage) error { switch result := c.Result.(type) { case *protocol.CompletionMessage_Value: a.Result = &protocol.AwakeableEntryMessage_Value{Value: result.Value} case *protocol.CompletionMessage_Failure: a.Result = &protocol.AwakeableEntryMessage_Failure{Failure: result.Failure} case *protocol.CompletionMessage_Empty: - log.Error().Msg("received empty completion for an awakeable") - return + return fmt.Errorf("received empty completion for an awakeable") } a.complete() + return nil } type CompleteAwakeableEntryMessage struct { @@ -586,7 +578,7 @@ type CompleteableMessage interface { Done() <-chan struct{} Completed() bool Await(suspensionCtx context.Context, entryIndex uint32) - Complete(*protocol.CompletionMessage) + Complete(*protocol.CompletionMessage) error } type completable struct { diff --git a/rcontext/rcontext.go b/rcontext/rcontext.go new file mode 100644 index 0000000..cd8e8b5 --- /dev/null +++ b/rcontext/rcontext.go @@ -0,0 +1,28 @@ +package rcontext + +import "context" + +type LogSource int + +const ( + LogSourceRestate = iota + LogSourceUser +) + +type LogContext struct { + Source LogSource + IsReplaying bool +} + +type logContextKey struct{} + +func WithLogContext(parent context.Context, logContext *LogContext) context.Context { + return context.WithValue(parent, logContextKey{}, logContext) +} + +func LogContextFrom(ctx context.Context) *LogContext { + if val, ok := ctx.Value(logContextKey{}).(*LogContext); ok { + return val + } + return nil +} diff --git a/router.go b/router.go index 7bdffb4..6b02c8d 100644 --- a/router.go +++ b/router.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "time" "github.com/restatedev/sdk-go/internal" @@ -47,7 +48,7 @@ type Selector interface { } type Context interface { - context.Context + RunContext // Sleep for the duration d Sleep(d time.Duration) @@ -76,7 +77,7 @@ type Context interface { // this stores the results of the function inside restate runtime so a replay // will produce the same value (think generating a unique id for example) // Note: use the RunAs helper function - Run(fn func(ctx context.Context) ([]byte, error)) ([]byte, error) + Run(fn func(RunContext) ([]byte, error)) ([]byte, error) Awakeable() Awakeable[[]byte] ResolveAwakeable(id string, value []byte) @@ -85,6 +86,16 @@ type Context interface { Selector(futs ...futures.Selectable) (Selector, error) } +// RunContext methods are the only methods safe to call from inside a .Run() +type RunContext interface { + context.Context + + // Log obtains a handle on a slog.Logger which already has some useful fields (invocationID and method) + // By default, this logger will not output messages if the invocation is currently replaying + // The log handler can be set with `.WithLogger()` on the server object + Log() *slog.Logger +} + // Router interface type Router interface { Name() string @@ -245,8 +256,8 @@ func SetAs[T any](ctx ObjectContext, key string, value T) error { // RunAs helper function runs a run function with specific concrete type as a result // it does encoding/decoding of bytes automatically using msgpack -func RunAs[T any](ctx Context, fn func(context.Context) (T, error)) (output T, err error) { - bytes, err := ctx.Run(func(ctx context.Context) ([]byte, error) { +func RunAs[T any](ctx Context, fn func(RunContext) (T, error)) (output T, err error) { + bytes, err := ctx.Run(func(ctx RunContext) ([]byte, error) { out, err := fn(ctx) if err != nil { return nil, err diff --git a/server/restate.go b/server/restate.go index 6d4d2ff..88354bc 100644 --- a/server/restate.go +++ b/server/restate.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "net" "net/http" "runtime/debug" @@ -14,8 +15,8 @@ import ( "github.com/restatedev/sdk-go/generated/proto/discovery" "github.com/restatedev/sdk-go/generated/proto/protocol" "github.com/restatedev/sdk-go/internal" + "github.com/restatedev/sdk-go/internal/log" "github.com/restatedev/sdk-go/internal/state" - "github.com/rs/zerolog/log" "golang.org/x/net/http2" ) @@ -40,16 +41,34 @@ func init() { } type Restate struct { - routers map[string]restate.Router + logHandler slog.Handler + dropReplayLogs bool + systemLog *slog.Logger + routers map[string]restate.Router } // NewRestate creates a new instance of Restate server func NewRestate() *Restate { + handler := slog.Default().Handler() return &Restate{ - routers: make(map[string]restate.Router), + logHandler: handler, + systemLog: slog.New(log.NewRestateContextHandler(handler)), + dropReplayLogs: true, + routers: make(map[string]restate.Router), } } +// WithLogger overrides the slog handler used by the SDK (which defaults to the slog Default()) +// You may specify with dropReplayLogs whether to drop logs that originated from handler code +// while the invocation was replaying. If they are not dropped, you may still determine the replay +// status in a slog.Handler using rcontext.LogContextFrom(ctx) +func (r *Restate) WithLogger(h slog.Handler, dropReplayLogs bool) *Restate { + r.dropReplayLogs = dropReplayLogs + r.systemLog = slog.New(log.NewRestateContextHandler(h)) + r.logHandler = h + return r +} + func (r *Restate) Bind(router restate.Router) *Restate { if _, ok := r.routers[router.Name()]; ok { // panic because this is a programming error @@ -97,7 +116,7 @@ func (r *Restate) discover() (resource *internal.Endpoint, err error) { } func (r *Restate) discoverHandler(writer http.ResponseWriter, req *http.Request) { - log.Trace().Msg("discover called") + r.systemLog.DebugContext(req.Context(), "Processing discovery request") acceptVersionsString := req.Header.Get("accept") if acceptVersionsString == "" { @@ -134,7 +153,7 @@ func (r *Restate) discoverHandler(writer http.ResponseWriter, req *http.Request) writer.Header().Add("Content-Type", serviceDiscoveryProtocolVersionToHeaderValue(serviceDiscoveryProtocolVersion)) writer.WriteHeader(200) if _, err := writer.Write(bytes); err != nil { - log.Error().Err(err).Msg("failed to write discovery information") + r.systemLog.LogAttrs(req.Context(), slog.LevelError, "Failed to write discovery information", log.Error(err)) } } @@ -191,29 +210,35 @@ func serviceProtocolVersionToHeaderValue(serviceProtocolVersion protocol.Service panic(fmt.Sprintf("unexpected service protocol version %d", serviceProtocolVersion)) } +type serviceMethod struct { + service string + method string +} + // takes care of function call -func (r *Restate) callHandler(serviceProtocolVersion protocol.ServiceProtocolVersion, service, fn string, writer http.ResponseWriter, request *http.Request) { - log.Debug().Str("service", service).Str("handler", fn).Msg("new request") +func (r *Restate) callHandler(serviceProtocolVersion protocol.ServiceProtocolVersion, service, method string, writer http.ResponseWriter, request *http.Request) { + logger := r.systemLog.With("method", slog.StringValue(fmt.Sprintf("%s/%s", service, method))) writer.Header().Add("x-restate-server", X_RESTATE_SERVER) writer.Header().Add("content-type", serviceProtocolVersionToHeaderValue(serviceProtocolVersion)) router, ok := r.routers[service] if !ok { + logger.WarnContext(request.Context(), "Service not found") writer.WriteHeader(http.StatusNotFound) return } - handler, ok := router.Handlers()[fn] + handler, ok := router.Handlers()[method] if !ok { + logger.WarnContext(request.Context(), "Method not found on service") writer.WriteHeader(http.StatusNotFound) } writer.WriteHeader(200) conn, err := h2conn.Accept(writer, request) - if err != nil { - log.Error().Err(err).Msg("failed to upgrade connection") + logger.LogAttrs(request.Context(), slog.LevelError, "Failed to upgrade connection", log.Error(err)) return } @@ -221,14 +246,12 @@ func (r *Restate) callHandler(serviceProtocolVersion protocol.ServiceProtocolVer machine := state.NewMachine(handler, conn) - if err := machine.Start(request.Context(), fmt.Sprintf("%s/%s", service, fn)); err != nil { - log.Error().Err(err).Msg("failed to handle invocation") + if err := machine.Start(request.Context(), r.dropReplayLogs, r.logHandler); err != nil { + machine.Log().LogAttrs(request.Context(), slog.LevelError, "Failed to handle invocation", log.Error(err)) } } func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) { - log.Trace().Str("proto", request.Proto).Str("method", request.Method).Str("path", request.RequestURI).Msg("got request") - if request.RequestURI == "/discover" { r.discoverHandler(writer, request) return @@ -236,6 +259,8 @@ func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) { serviceProtocolVersionString := request.Header.Get("content-type") if serviceProtocolVersionString == "" { + r.systemLog.ErrorContext(request.Context(), "Missing content-type header") + writer.Write([]byte("missing content-type header")) writer.WriteHeader(http.StatusUnsupportedMediaType) @@ -245,6 +270,8 @@ func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) { serviceProtocolVersion := parseServiceProtocolVersion(serviceProtocolVersionString) if !isServiceProtocolVersionSupported(serviceProtocolVersion) { + r.systemLog.LogAttrs(request.Context(), slog.LevelError, "Unsupported service protocol version", slog.String("version", serviceProtocolVersionString)) + writer.Write([]byte(fmt.Sprintf("Unsupported service protocol version '%s'", serviceProtocolVersionString))) writer.WriteHeader(http.StatusUnsupportedMediaType) @@ -254,12 +281,14 @@ func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) { // we expecting the uri to be something like `/invoke/{service}/{method}` // so if !strings.HasPrefix(request.RequestURI, "/invoke/") { + r.systemLog.LogAttrs(request.Context(), slog.LevelError, "Invalid request path", slog.String("path", request.RequestURI)) writer.WriteHeader(http.StatusNotFound) return } parts := strings.Split(strings.TrimPrefix(request.RequestURI, "/invoke/"), "/") if len(parts) != 2 { + r.systemLog.LogAttrs(request.Context(), slog.LevelError, "Invalid request path", slog.String("path", request.RequestURI)) writer.WriteHeader(http.StatusNotFound) return } From c6c5a951d84ade751532b31e825e89e1b555de74 Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Sat, 13 Jul 2024 12:05:33 +0200 Subject: [PATCH 2/2] Warn when client goes away --- internal/state/state.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/state/state.go b/internal/state/state.go index dfdd4cd..fa20bee 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -288,7 +288,7 @@ The journal entry at position %d was: return case *wire.SuspensionPanic: if m.ctx.Err() != nil { - // the http2 request has been cancelled; just return because we can't send a response + m.log.WarnContext(m.ctx, "Cancelling invocation as the incoming request was cancelled") return } if stderrors.Is(typ.Err, io.EOF) {