Skip to content

Commit

Permalink
Add a mechanism to obtain immutable request parameters (#14)
Browse files Browse the repository at this point in the history
ID, headers, attempt headers and body should all be obtainable if you
need them.
  • Loading branch information
jackkleeman authored Jul 19, 2024
1 parent 0a30b99 commit 460d0cd
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 15 deletions.
18 changes: 18 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,24 @@ type RunContext interface {
// By default, this logger will not output messages if the invocation is currently replaying
// The log handler can be set with `.WithLogger()` on the server object
Log() *slog.Logger

// Request gives extra information about the request that started this invocation
Request() *Request
}

type Request struct {
// The unique id that identifies the current function invocation. This id is guaranteed to be
// unique across invocations, but constant across reties and suspensions.
ID []byte
// Request headers - the following headers capture the original invocation headers, as provided to
// the ingress.
Headers map[string]string
// Attempt headers - the following headers are sent by the restate runtime.
// These headers are attempt specific, generated by the restate runtime uniquely for each attempt.
// These headers might contain information such as the W3C trace context, and attempt specific information.
AttemptHeaders map[string][]string
// Raw unparsed request body
Body []byte
}

// After is a handle on a Sleep operation which allows you to do other work concurrently
Expand Down
2 changes: 1 addition & 1 deletion internal/state/awakeable.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func (c *Machine) awakeable() *futures.Awakeable {
c._awakeable,
)

return futures.NewAwakeable(c.suspensionCtx, c.id, entry, entryIndex)
return futures.NewAwakeable(c.suspensionCtx, c.request.ID, entry, entryIndex)
}

func (c *Machine) _awakeable() *wire.AwakeableEntryMessage {
Expand Down
34 changes: 24 additions & 10 deletions internal/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ func (c *Context) Log() *slog.Logger {
return c.machine.userLog
}

func (c *Context) Request() *restate.Request {
return &c.machine.request
}

func (c *Context) Rand() *rand.Rand {
return c.machine.rand
}
Expand Down Expand Up @@ -266,8 +270,8 @@ type Machine struct {
protocol *wire.Protocol

// state
id []byte
key string
key string
request restate.Request

partial bool
current map[string][]byte
Expand All @@ -289,12 +293,15 @@ type Machine struct {
failure any
}

func NewMachine(handler restate.Handler, conn io.ReadWriter) *Machine {
func NewMachine(handler restate.Handler, conn io.ReadWriter, attemptHeaders map[string][]string) *Machine {
m := &Machine{
handler: handler,
current: make(map[string][]byte),
pendingAcks: map[uint32]wire.AckableMessage{},
pendingCompletions: map[uint32]wire.CompleteableMessage{},
request: restate.Request{
AttemptHeaders: attemptHeaders,
},
}
m.protocol = wire.NewProtocol(conn)
return m
Expand All @@ -317,8 +324,8 @@ func (m *Machine) Start(inner context.Context, dropReplayLogs bool, logHandler s

m.ctx = inner
m.suspensionCtx, m.suspend = context.WithCancelCause(m.ctx)
m.id = start.Id
m.rand = rand.New(m.id)
m.request.ID = start.Id
m.rand = rand.New(m.request.ID)
m.key = start.Key

logHandler = logHandler.WithAttrs([]slog.Attr{slog.String("invocationID", start.DebugId)})
Expand All @@ -331,7 +338,7 @@ func (m *Machine) Start(inner context.Context, dropReplayLogs bool, logHandler s
return m.process(ctx, start)
}

func (m *Machine) invoke(ctx *Context, input []byte, outputSeen bool) error {
func (m *Machine) invoke(ctx *Context, outputSeen bool) error {
// always terminate the invocation with
// an end message.
// this will always terminate the connection
Expand Down Expand Up @@ -485,9 +492,9 @@ The journal entry at position %d was:
var err error
switch handler := m.handler.(type) {
case restate.ObjectHandler:
bytes, err = handler.Call(ctx, input)
bytes, err = handler.Call(ctx, m.request.Body)
case restate.ServiceHandler:
bytes, err = handler.Call(ctx, input)
bytes, err = handler.Call(ctx, m.request.Body)
}

if err != nil && restate.IsTerminalError(err) {
Expand Down Expand Up @@ -580,9 +587,16 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error {
go m.handleCompletionsAcks()

inputMsg := msg.(*wire.InputEntryMessage)
value := inputMsg.GetValue()
return m.invoke(ctx, value, outputSeen)
m.request.Body = inputMsg.GetValue()

if len(inputMsg.GetHeaders()) > 0 {
m.request.Headers = make(map[string]string, len(inputMsg.Headers))
for _, header := range inputMsg.Headers {
m.request.Headers[header.Key] = header.Value
}
}

return m.invoke(ctx, outputSeen)
}

func (c *Machine) currentEntry() (wire.Message, bool) {
Expand Down
8 changes: 5 additions & 3 deletions internal/state/sys.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,13 +307,15 @@ func (m *Machine) run(fn func(restate.RunContext) ([]byte, error)) ([]byte, erro

type runContext struct {
context.Context
log *slog.Logger
log *slog.Logger
request *restate.Request
}

func (r runContext) Log() *slog.Logger { return r.log }
func (r runContext) Log() *slog.Logger { return r.log }
func (r runContext) Request() *restate.Request { return r.request }

func (m *Machine) _run(fn func(restate.RunContext) ([]byte, error)) *wire.RunEntryMessage {
bytes, err := fn(runContext{m.ctx, m.userLog})
bytes, err := fn(runContext{m.ctx, m.userLog, &m.request})

if err != nil {
if restate.IsTerminalError(err) {
Expand Down
2 changes: 1 addition & 1 deletion server/restate.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ func (r *Restate) callHandler(serviceProtocolVersion protocol.ServiceProtocolVer

defer conn.Close()

machine := state.NewMachine(handler, conn)
machine := state.NewMachine(handler, conn, request.Header)

if err := machine.Start(request.Context(), r.dropReplayLogs, r.logHandler); err != nil {
r.systemLog.LogAttrs(request.Context(), slog.LevelError, "Failed to handle invocation", log.Error(err))
Expand Down

0 comments on commit 460d0cd

Please sign in to comment.