Skip to content

Commit

Permalink
Implement workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkleeman committed Sep 23, 2024
1 parent da87f6a commit 13f654a
Show file tree
Hide file tree
Showing 16 changed files with 720 additions and 42 deletions.
17 changes: 16 additions & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type Context interface {
inner() *state.Context
}

// ObjectContext is an extension of [Context] which is passed to shared-mode Virtual Object handlers,
// ObjectSharedContext is an extension of [Context] which is passed to shared-mode Virtual Object handlers,
// giving read-only access to a snapshot of state.
type ObjectSharedContext interface {
Context
Expand All @@ -40,3 +40,18 @@ type ObjectContext interface {
ObjectSharedContext
exclusiveObject()
}

// WorkflowSharedContext is an extension of [ObjectSharedContext] which is passed to shared-mode Workflow handlers,
// giving read-only access to a snapshot of state.
type WorkflowSharedContext interface {
ObjectSharedContext
workflow()
}

// WorkflowContext is an extension of [WorkflowSharedContext] and [ObjectContext] which is passed to Workflow 'run' handlers,
// giving mutable access to state.
type WorkflowContext interface {
WorkflowSharedContext
ObjectContext
runWorkflow()
}
47 changes: 45 additions & 2 deletions facilitators.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ func Awakeable[T any](ctx Context, options ...options.AwakeableOption) Awakeable
type AwakeableFuture[T any] interface {
// Id returns the awakeable ID, which can be stored or sent to a another service
Id() string
// Result blocks on receiving the result of the awakeable, storing the value it was
// resolved with in output or otherwise returning the error it was rejected with.
// Result blocks on receiving the result of the awakeable, returning the value it was
// resolved or otherwise returning the error it was rejected with.
// It is *not* safe to call this in a goroutine - use Context.Select if you
// want to wait on multiple results at once.
Result() (T, error)
Expand Down Expand Up @@ -237,3 +237,46 @@ func Clear(ctx ObjectContext, key string) {
func ClearAll(ctx ObjectContext) {
ctx.inner().ClearAll()
}

// Promise returns a named Restate durable Promise that can be resolved or rejected during the workflow execution.
// The promise is bound to the workflow and will be persisted across suspensions and retries.
func Promise[T any](ctx WorkflowSharedContext, name string, options ...options.PromiseOption) DurablePromise[T] {
return durablePromise[T]{ctx.inner().Promise(name, options...)}
}

type DurablePromise[T any] interface {
// Result blocks on receiving the result of the Promise, returning the value it was
// resolved or otherwise returning the error it was rejected with or a cancellation error.
// It is *not* safe to call this in a goroutine - use Context.Select if you
// want to wait on multiple results at once.
Result() (T, error)
// Peek returns the value of the promise if it has been resolved. If it has not been resolved,
// the zero value of T is returned. To check explicitly for this case pass a pointer eg *string as T.
// If the promise was rejected or the invocation was cancelled, an error is returned.
Peek() (T, error)
// Resolve resolves the promise with a value, returning an error if it was already completed
// or if the invocation was cancelled.
Resolve(value T) error
// Reject rejects the promise with an error, returning an error if it was already completed
// or if the invocation was cancelled.
Reject(reason error) error
futures.Selectable
}

type durablePromise[T any] struct {
state.DecodingPromise
}

func (t durablePromise[T]) Result() (output T, err error) {
err = t.DecodingPromise.Result(&output)
return
}

func (t durablePromise[T]) Peek() (output T, err error) {
_, err = t.DecodingPromise.Peek(&output)
return
}

func (t durablePromise[T]) Resolve(value T) (err error) {
return t.DecodingPromise.Resolve(value)
}
99 changes: 98 additions & 1 deletion handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,15 @@ type ServiceHandlerFn[I any, O any] func(ctx Context, input I) (O, error)
// ObjectHandlerFn is the signature for a Virtual Object exclusive-mode handler function
type ObjectHandlerFn[I any, O any] func(ctx ObjectContext, input I) (O, error)

// ObjectHandlerFn is the signature for a Virtual Object shared-mode handler function
// ObjectSharedHandlerFn is the signature for a Virtual Object shared-mode handler function
type ObjectSharedHandlerFn[I any, O any] func(ctx ObjectSharedContext, input I) (O, error)

// ObjectHandlerFn is the signature for a Workflow 'Run' handler function
type WorkflowHandlerFn[I any, O any] func(ctx WorkflowContext, input I) (O, error)

// WorkflowSharedHandlerFn is the signature for a Workflow shared-mode handler function
type WorkflowSharedHandlerFn[I any, O any] func(ctx WorkflowSharedContext, input I) (O, error)

type serviceHandler[I any, O any] struct {
fn ServiceHandlerFn[I, O]
options options.HandlerOptions
Expand Down Expand Up @@ -135,6 +141,8 @@ func (o ctxWrapper) inner() *state.Context {
}
func (o ctxWrapper) object() {}
func (o ctxWrapper) exclusiveObject() {}
func (o ctxWrapper) workflow() {}
func (o ctxWrapper) runWorkflow() {}

func (h *objectHandler[I, O]) Call(ctx *state.Context, bytes []byte) ([]byte, error) {
var input I
Expand Down Expand Up @@ -186,3 +194,92 @@ func (h *objectHandler[I, O]) GetOptions() *options.HandlerOptions {
func (h *objectHandler[I, O]) HandlerType() *internal.ServiceHandlerType {
return &h.handlerType
}

type workflowHandler[I any, O any] struct {
// only one of workflowFn or sharedFn should be set, as indicated by handlerType
workflowFn WorkflowHandlerFn[I, O]
sharedFn WorkflowSharedHandlerFn[I, O]
options options.HandlerOptions
handlerType internal.ServiceHandlerType
}

var _ state.Handler = (*workflowHandler[struct{}, struct{}])(nil)

// NewWorkflowHandler converts a function of signature [WorkflowHandlerFn] into the 'Run' handler on a Workflow.
// The handler will have access to a full [WorkflowContext] which may mutate state.
func NewWorkflowHandler[I any, O any](fn WorkflowHandlerFn[I, O], opts ...options.HandlerOption) *workflowHandler[I, O] {
o := options.HandlerOptions{}
for _, opt := range opts {
opt.BeforeHandler(&o)
}
return &workflowHandler[I, O]{
workflowFn: fn,
options: o,
handlerType: internal.ServiceHandlerType_WORKFLOW,
}
}

// NewWorkflowSharedHandler converts a function of signature [ObjectSharedHandlerFn] into a shared-mode handler on a Workflow.
// The handler will only have access to a [WorkflowSharedContext] which can only read a snapshot of state.
func NewWorkflowSharedHandler[I any, O any](fn WorkflowSharedHandlerFn[I, O], opts ...options.HandlerOption) *workflowHandler[I, O] {
o := options.HandlerOptions{}
for _, opt := range opts {
opt.BeforeHandler(&o)
}
return &workflowHandler[I, O]{
sharedFn: fn,
options: o,
handlerType: internal.ServiceHandlerType_SHARED,
}
}

func (h *workflowHandler[I, O]) Call(ctx *state.Context, bytes []byte) ([]byte, error) {
var input I
if err := encoding.Unmarshal(h.options.Codec, bytes, &input); err != nil {
return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest)
}

var output O
var err error
switch h.handlerType {
case internal.ServiceHandlerType_EXCLUSIVE:
output, err = h.workflowFn(
ctxWrapper{ctx},
input,
)
case internal.ServiceHandlerType_SHARED:
output, err = h.sharedFn(
ctxWrapper{ctx},
input,
)
}
if err != nil {
return nil, err
}

bytes, err = encoding.Marshal(h.options.Codec, output)
if err != nil {
// we don't use a terminal error here as this is hot-fixable by changing the return type
return nil, fmt.Errorf("failed to serialize output: %w", err)
}

return bytes, nil
}

func (h *workflowHandler[I, O]) InputPayload() *encoding.InputPayload {
var i I
return encoding.InputPayloadFor(h.options.Codec, i)
}

func (h *workflowHandler[I, O]) OutputPayload() *encoding.OutputPayload {
var o O
return encoding.OutputPayloadFor(h.options.Codec, o)
}

func (h *workflowHandler[I, O]) GetOptions() *options.HandlerOptions {
return &h.options
}

func (h *workflowHandler[I, O]) HandlerType() *internal.ServiceHandlerType {
return &h.handlerType
}
35 changes: 35 additions & 0 deletions internal/futures/futures.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,38 @@ func (r *ResponseFuture) Response() ([]byte, error) {
func (r *ResponseFuture) getEntry() (wire.CompleteableMessage, uint32) {
return r.entry, r.entryIndex
}

type Promise struct {
suspensionCtx context.Context
invocationID []byte
entry *wire.GetPromiseEntryMessage
entryIndex uint32
getPromise func() (*wire.GetPromiseEntryMessage, uint32)
}

func NewPromise(suspensionCtx context.Context, invocationID []byte, getPromise func() (*wire.GetPromiseEntryMessage, uint32)) *Promise {
return &Promise{suspensionCtx, invocationID, nil, 0, getPromise}
}

func (c *Promise) Result() ([]byte, error) {
c.getEntry()

c.entry.Await(c.suspensionCtx, c.entryIndex)

switch result := c.entry.Result.(type) {
case *protocol.GetPromiseEntryMessage_Value:
return result.Value, nil
case *protocol.GetPromiseEntryMessage_Failure:
return nil, errors.ErrorFromFailure(result.Failure)
default:
return nil, fmt.Errorf("unexpected result in completed get promise entry: %v", c.entry.Result)
}
}

func (c *Promise) getEntry() (wire.CompleteableMessage, uint32) {
if c.entry == nil {
c.entry, c.entryIndex = c.getPromise()
}

return c.entry, c.entryIndex
}
8 changes: 8 additions & 0 deletions internal/options/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ type AwakeableOption interface {
BeforeAwakeable(*AwakeableOptions)
}

type PromiseOptions struct {
Codec encoding.Codec
}

type PromiseOption interface {
BeforePromise(*PromiseOptions)
}

type ResolveAwakeableOptions struct {
Codec encoding.Codec
}
Expand Down
6 changes: 3 additions & 3 deletions internal/state/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (c *Client) RequestFuture(input any, opts ...options.RequestOption) Decodin

bytes, err := encoding.Marshal(c.options.Codec, input)
if err != nil {
panic(c.machine.newCodecFailure(fmt.Errorf("failed to marshal RequestFuture input: %w", err)))
panic(c.machine.newCodecFailure(wire.CallEntryMessageType, fmt.Errorf("failed to marshal RequestFuture input: %w", err)))
}

entry, entryIndex := c.machine.doCall(c.service, c.key, c.method, o.Headers, bytes)
Expand All @@ -56,7 +56,7 @@ func (d DecodingResponseFuture) Response(output any) (err error) {
}

if err := encoding.Unmarshal(d.options.Codec, bytes, output); err != nil {
panic(d.machine.newCodecFailure(fmt.Errorf("failed to unmarshal Call response into output: %w", err)))
panic(d.machine.newCodecFailure(wire.CallEntryMessageType, fmt.Errorf("failed to unmarshal Call response into output: %w", err)))
}

return nil
Expand All @@ -76,7 +76,7 @@ func (c *Client) Send(input any, opts ...options.SendOption) {

bytes, err := encoding.Marshal(c.options.Codec, input)
if err != nil {
panic(c.machine.newCodecFailure(fmt.Errorf("failed to marshal Send input: %w", err)))
panic(c.machine.newCodecFailure(wire.OneWayCallEntryMessageType, fmt.Errorf("failed to marshal Send input: %w", err)))
}
c.machine.sendCall(c.service, c.key, c.method, o.Headers, bytes, o.Delay)
return
Expand Down
Loading

0 comments on commit 13f654a

Please sign in to comment.