diff --git a/README.md b/README.md index 45b2f78..28b703f 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,8 @@ - [X] Sleep - [x] Run - [x] Awakeable +- [x] Shared object handlers +- [ ] Workflows ## Basic usage @@ -63,8 +65,6 @@ Trying adding the same tickets again should return `false` since they are alread Finally checkout ```bash -curl localhost:8080/UserSession/azmy/Checkout \ - -H 'content-type: application/json' \ - -d 'null' +curl localhost:8080/UserSession/azmy/Checkout #{"response":true} ``` diff --git a/context.go b/context.go new file mode 100644 index 0000000..3bf3d65 --- /dev/null +++ b/context.go @@ -0,0 +1,159 @@ +package restate + +import ( + "context" + "log/slog" + "time" + + "github.com/restatedev/sdk-go/internal/futures" + "github.com/restatedev/sdk-go/internal/options" + "github.com/restatedev/sdk-go/internal/rand" +) + +type Context interface { + RunContext + + // Rand returns a random source which will give deterministic results for a given invocation + // The source wraps the stdlib rand.Rand but with some extra helper methods + // This source is not safe for use inside .Run() + Rand() *rand.Rand + + // Sleep for the duration d + Sleep(d time.Duration) + // After is an alternative to Context.Sleep which allows you to complete other tasks concurrently + // with the sleep. This is particularly useful when combined with Context.Select to race between + // the sleep and other Selectable operations. + After(d time.Duration) After + + // Service gets a Service accessor by service and method name + // Note: use the CallAs helper function to deserialise return values + Service(service, method string, opts ...options.CallOption) CallClient + + // Object gets a Object accessor by name, key and method name + // Note: use the CallAs helper function to receive serialised values + Object(object, key, method string, opts ...options.CallOption) CallClient + + // Run runs the function (fn), storing final results (including terminal errors) + // durably in the journal, or otherwise for transient errors stopping execution + // so Restate can retry the invocation. Replays will produce the same value, so + // all non-deterministic operations (eg, generating a unique ID) *must* happen + // inside Run blocks. + // Note: use the RunAs helper function to get typed output values instead of providing an output pointer + Run(fn func(RunContext) (any, error), output any, opts ...options.RunOption) error + + // Awakeable returns a Restate awakeable; a 'promise' to a future + // value or error, that can be resolved or rejected by other services. + // Note: use the AwakeableAs helper function to avoid having to pass a output pointer to Awakeable.Result() + Awakeable(options ...options.AwakeableOption) Awakeable + // ResolveAwakeable allows an awakeable (not necessarily from this service) to be + // resolved with a particular value. + ResolveAwakeable(id string, value any, options ...options.ResolveAwakeableOption) error + // ResolveAwakeable allows an awakeable (not necessarily from this service) to be + // rejected with a particular error. + RejectAwakeable(id string, reason error) + + // Select returns an iterator over blocking Restate operations (sleep, call, awakeable) + // which allows you to safely run them in parallel. The Selector will store the order + // that things complete in durably inside Restate, so that on replay the same order + // can be used. This avoids non-determinism. It is *not* safe to use goroutines or channels + // outside of Context.Run functions, as they do not behave deterministically. + Select(futs ...futures.Selectable) Selector +} + +// Awakeable is the Go representation of a Restate awakeable; a 'promise' to a future +// value or error, that can be resolved or rejected by other services. +type Awakeable interface { + // Id returns the awakeable ID, which can be stored or sent to a another service + Id() string + // Result blocks on receiving the result of the awakeable, storing the value it was + // resolved with in output or otherwise returning the error it was rejected with. + // It is *not* safe to call this in a goroutine - use Context.Select if you + // want to wait on multiple results at once. + // Note: use the AwakeableAs helper function to avoid having to pass a output pointer + Result(output any) error + futures.Selectable +} + +type CallClient interface { + // RequestFuture makes a call and returns a handle on a future response + RequestFuture(input any) (ResponseFuture, error) + // Request makes a call and blocks on getting the response which is stored in output + Request(input any, output any) error + SendClient +} + +type SendClient interface { + // Send makes a one-way call which is executed in the background + Send(input any, delay time.Duration) error +} + +type ResponseFuture interface { + // Response blocks on the response to the call and stores it in output, or returns the associated error + // It is *not* safe to call this in a goroutine - use Context.Select if you + // want to wait on multiple results at once. + Response(output any) error + futures.Selectable +} + +// Selector is an iterator over a list of blocking Restate operations that are running +// in the background. +type Selector interface { + // Remaining returns whether there are still operations that haven't been returned by Select(). + // There will always be exactly the same number of results as there were operations + // given to Context.Select + Remaining() bool + // Select blocks on the next completed operation + Select() futures.Selectable +} + +// RunContext methods are the only methods safe to call from inside a .Run() +type RunContext interface { + context.Context + + // Log obtains a handle on a slog.Logger which already has some useful fields (invocationID and method) + // By default, this logger will not output messages if the invocation is currently replaying + // The log handler can be set with `.WithLogger()` on the server object + Log() *slog.Logger +} + +// After is a handle on a Sleep operation which allows you to do other work concurrently +// with the sleep. +type After interface { + // Done blocks waiting on the remaining duration of the sleep. + // It is *not* safe to call this in a goroutine - use Context.Select if you + // want to wait on multiple results at once. + Done() + futures.Selectable +} + +type ObjectContext interface { + Context + KeyValueReader + KeyValueWriter +} + +type ObjectSharedContext interface { + Context + KeyValueReader +} + +type KeyValueReader interface { + // Get gets value associated with key and stores it in value + // If key does not exist, this function returns ErrKeyNotFound + // Note: Use GetAs generic helper function to avoid passing in a value pointer + Get(key string, value any, options ...options.GetOption) error + // Keys returns a list of all associated key + Keys() []string + // Key retrieves the key for this virtual object invocation. This is a no-op and is + // always safe to call. + Key() string +} + +type KeyValueWriter interface { + // Set sets a value against a key, using the provided codec (defaults to JSON) + Set(key string, value any, options ...options.SetOption) error + // Clear deletes a key + Clear(key string) + // ClearAll drops all stored state associated with key + ClearAll() +} diff --git a/encoding/encoding.go b/encoding/encoding.go index c7dd5eb..12286f1 100644 --- a/encoding/encoding.go +++ b/encoding/encoding.go @@ -2,10 +2,32 @@ package encoding import ( "encoding/json" + "fmt" + "reflect" "google.golang.org/protobuf/proto" ) +var ( + BinaryCodec PayloadCodec = binaryCodec{} + VoidCodec PayloadCodec = voidCodec{} + ProtoCodec PayloadCodec = protoCodec{} + JSONCodec PayloadCodec = jsonCodec{} + _ PayloadCodec = PairCodec{} +) + +type Void struct{} + +type Codec interface { + Marshal(v any) ([]byte, error) + Unmarshal(data []byte, v any) error +} + +type PayloadCodec interface { + Codec + InputPayload() *InputPayload + OutputPayload() *OutputPayload +} type InputPayload struct { Required bool `json:"required"` ContentType *string `json:"contentType,omitempty"` @@ -18,53 +40,185 @@ type OutputPayload struct { JsonSchema interface{} `json:"jsonSchema,omitempty"` } -type JSONDecoder[I any] struct{} +type voidCodec struct{} -func (j JSONDecoder[I]) InputPayload() *InputPayload { - return &InputPayload{Required: true, ContentType: proto.String("application/json")} +func (j voidCodec) InputPayload() *InputPayload { + return &InputPayload{} } -func (j JSONDecoder[I]) Decode(data []byte) (input I, err error) { - err = json.Unmarshal(data, &input) - return +func (j voidCodec) OutputPayload() *OutputPayload { + return &OutputPayload{} } -type JSONEncoder[O any] struct{} +func (j voidCodec) Unmarshal(data []byte, input any) (err error) { + return nil +} -func (j JSONEncoder[O]) OutputPayload() *OutputPayload { - return &OutputPayload{ContentType: proto.String("application/json")} +func (j voidCodec) Marshal(output any) ([]byte, error) { + return nil, nil } -func (j JSONEncoder[O]) Encode(output O) ([]byte, error) { - return json.Marshal(output) +type PairCodec struct { + Input PayloadCodec + Output PayloadCodec } -type MessagePointer[I any] interface { - proto.Message - *I +func (w PairCodec) InputPayload() *InputPayload { + return w.Input.InputPayload() } -type ProtoDecoder[I any, IP MessagePointer[I]] struct{} +func (w PairCodec) OutputPayload() *OutputPayload { + return w.Output.OutputPayload() +} -func (p ProtoDecoder[I, IP]) InputPayload() *InputPayload { - return &InputPayload{Required: true, ContentType: proto.String("application/proto")} +func (w PairCodec) Unmarshal(data []byte, v any) error { + return w.Input.Unmarshal(data, v) +} + +func (w PairCodec) Marshal(v any) ([]byte, error) { + return w.Output.Marshal(v) +} + +func MergeCodec(base, overlay PayloadCodec) PayloadCodec { + switch { + case base == nil && overlay == nil: + return nil + case base == nil: + return overlay + case overlay == nil: + return base + } + + basePair, baseOk := base.(PairCodec) + overlayPair, overlayOk := overlay.(PairCodec) + + switch { + case baseOk && overlayOk: + return PairCodec{ + Input: MergeCodec(basePair.Input, overlayPair.Input), + Output: MergeCodec(basePair.Output, overlayPair.Output), + } + case baseOk: + return PairCodec{ + Input: MergeCodec(basePair.Input, overlay), + Output: MergeCodec(basePair.Output, overlay), + } + case overlayOk: + return PairCodec{ + Input: MergeCodec(base, overlayPair.Input), + Output: MergeCodec(base, overlayPair.Output), + } + default: + // just two non-pairs; keep base + return base + } +} + +func PartialVoidCodec[I any, O any]() PayloadCodec { + var input I + var output O + _, inputVoid := any(input).(Void) + _, outputVoid := any(output).(Void) + switch { + case inputVoid && outputVoid: + return VoidCodec + case inputVoid: + return PairCodec{Input: VoidCodec, Output: nil} + case outputVoid: + return PairCodec{Input: nil, Output: VoidCodec} + default: + return nil + } +} + +type binaryCodec struct{} + +func (j binaryCodec) InputPayload() *InputPayload { + return &InputPayload{Required: true, ContentType: proto.String("application/octet-stream")} } -func (p ProtoDecoder[I, IP]) Decode(data []byte) (input IP, err error) { - // Unmarshal expects a non-nil pointer to a proto.Message implementing struct - // hence we must have a type parameter for the struct itself (I) and here we allocate - // a non-nil pointer of type IP - input = IP(new(I)) - err = proto.Unmarshal(data, input) - return +func (j binaryCodec) OutputPayload() *OutputPayload { + return &OutputPayload{ContentType: proto.String("application/octet-stream")} } -type ProtoEncoder[O proto.Message] struct{} +func (j binaryCodec) Unmarshal(data []byte, input any) (err error) { + switch input := input.(type) { + case *[]byte: + *input = data + return nil + default: + return fmt.Errorf("BinaryCodec.Unmarshal called with a type that is not *[]byte") + } +} + +func (j binaryCodec) Marshal(output any) ([]byte, error) { + switch output := output.(type) { + case []byte: + return output, nil + default: + return nil, fmt.Errorf("BinaryCodec.Marshal called with a type that is not []byte") + } +} -func (p ProtoEncoder[O]) OutputPayload() *OutputPayload { +type jsonCodec struct{} + +func (j jsonCodec) InputPayload() *InputPayload { + return &InputPayload{Required: true, ContentType: proto.String("application/json")} +} + +func (j jsonCodec) OutputPayload() *OutputPayload { + return &OutputPayload{ContentType: proto.String("application/json")} +} + +func (j jsonCodec) Unmarshal(data []byte, input any) (err error) { + return json.Unmarshal(data, &input) +} + +func (j jsonCodec) Marshal(output any) ([]byte, error) { + return json.Marshal(output) +} + +type protoCodec struct{} + +func (p protoCodec) InputPayload() *InputPayload { + return &InputPayload{Required: true, ContentType: proto.String("application/proto")} +} + +func (p protoCodec) OutputPayload() *OutputPayload { return &OutputPayload{ContentType: proto.String("application/proto")} } -func (p ProtoEncoder[O]) Encode(output O) ([]byte, error) { - return proto.Marshal(output) +func (p protoCodec) Unmarshal(data []byte, input any) (err error) { + switch input := input.(type) { + case proto.Message: + // called with a *Message + return proto.Unmarshal(data, input) + default: + // we must support being called with a **Message where *Message is nil because this is the result of new(I) where I is a proto.Message + // and calling with new(I) is really the only generic approach. + value := reflect.ValueOf(input) + if value.Kind() != reflect.Pointer || value.IsNil() || value.Elem().Kind() != reflect.Pointer { + return fmt.Errorf("ProtoCodec.Unmarshal called with neither a proto.Message nor a non-nil pointer to a type that implements proto.Message.") + } + elem := value.Elem() // hopefully a *Message + if elem.IsNil() { + // allocate a &Message and swap this in + elem.Set(reflect.New(elem.Type().Elem())) + } + switch elemI := elem.Interface().(type) { + case proto.Message: + return proto.Unmarshal(data, elemI) + default: + return fmt.Errorf("ProtoCodec.Unmarshal called with neither a proto.Message nor a non-nil pointer to a type that implements proto.Message.") + } + } +} + +func (p protoCodec) Marshal(output any) (data []byte, err error) { + switch output := output.(type) { + case proto.Message: + return proto.Marshal(output) + default: + return nil, fmt.Errorf("ProtoCodec.Marshal called with a type that is not a proto.Message") + } } diff --git a/encoding/encoding_test.go b/encoding/encoding_test.go new file mode 100644 index 0000000..d3bbbae --- /dev/null +++ b/encoding/encoding_test.go @@ -0,0 +1,73 @@ +package encoding + +import ( + "testing" + + "github.com/restatedev/sdk-go/generated/proto/protocol" +) + +func willPanic(t *testing.T, do func()) { + defer func() { + switch recover() { + case nil: + t.Fatalf("expected panic but didn't find one") + default: + return + } + }() + + do() +} + +func willSucceed(t *testing.T, err error) { + if err != nil { + t.Fatal(err) + } +} + +func checkMessage(t *testing.T, msg *protocol.AwakeableEntryMessage) { + if msg.Name != "foobar" { + t.Fatalf("unexpected msg.Name: %s", msg.Name) + } +} + +func TestProto(t *testing.T) { + p := ProtoCodec + + _, err := p.Marshal(protocol.AwakeableEntryMessage{Name: "foobar"}) + if err == nil { + t.Fatalf("expected error when marshaling non-pointer proto Message") + } + + bytes, err := p.Marshal(&protocol.AwakeableEntryMessage{Name: "foobar"}) + if err != nil { + t.Fatal(err) + } + + { + msg := &protocol.AwakeableEntryMessage{} + willSucceed(t, p.Unmarshal(bytes, msg)) + checkMessage(t, msg) + } + + { + inner := &protocol.AwakeableEntryMessage{} + msg := &inner + willSucceed(t, p.Unmarshal(bytes, msg)) + checkMessage(t, *msg) + } + + { + msg := new(*protocol.AwakeableEntryMessage) + willSucceed(t, p.Unmarshal(bytes, msg)) + checkMessage(t, *msg) + } + + { + var msg *protocol.AwakeableEntryMessage + willPanic(t, func() { + p.Unmarshal(bytes, msg) + }) + } + +} diff --git a/error.go b/error.go index e6b1683..0d69679 100644 --- a/error.go +++ b/error.go @@ -6,6 +6,10 @@ import ( "github.com/restatedev/sdk-go/internal/errors" ) +var ( + ErrKeyNotFound = errors.ErrKeyNotFound +) + // WithErrorCode returns an error with specific func WithErrorCode(err error, code errors.Code) error { if err == nil { diff --git a/example/checkout.go b/example/checkout.go index d6cd8a4..f222540 100644 --- a/example/checkout.go +++ b/example/checkout.go @@ -19,7 +19,7 @@ type PaymentResponse struct { type checkout struct{} -func (c *checkout) Name() string { +func (c *checkout) ServiceName() string { return CheckoutServiceName } diff --git a/example/main.go b/example/main.go index 12574af..f6422f3 100644 --- a/example/main.go +++ b/example/main.go @@ -12,9 +12,13 @@ import ( func main() { server := server.NewRestate(). + // Handlers can be inferred from object methods Bind(restate.Object(&userSession{})). Bind(restate.Object(&ticketService{})). - Bind(restate.Service(&checkout{})) + Bind(restate.Service(&checkout{})). + // Or created and registered explicitly + Bind(health). + Bind(bigCounter) if err := server.Start(context.Background(), ":9080"); err != nil { slog.Error("application exited unexpectedly", "err", err.Error()) diff --git a/example/ticket_service.go b/example/ticket_service.go index aa98ae8..1544157 100644 --- a/example/ticket_service.go +++ b/example/ticket_service.go @@ -18,7 +18,7 @@ const TicketServiceName = "TicketService" type ticketService struct{} -func (t *ticketService) Name() string { return TicketServiceName } +func (t *ticketService) ServiceName() string { return TicketServiceName } func (t *ticketService) Reserve(ctx restate.ObjectContext, _ restate.Void) (bool, error) { status, err := restate.GetAs[TicketStatus](ctx, "status") @@ -27,7 +27,7 @@ func (t *ticketService) Reserve(ctx restate.ObjectContext, _ restate.Void) (bool } if status == TicketAvailable { - return true, restate.SetAs(ctx, "status", TicketReserved) + return true, ctx.Set("status", TicketReserved) } return false, nil @@ -59,8 +59,20 @@ func (t *ticketService) MarkAsSold(ctx restate.ObjectContext, _ restate.Void) (v } if status == TicketReserved { - return void, restate.SetAs(ctx, "status", TicketSold) + return void, ctx.Set("status", TicketSold) } return void, nil } + +func (t *ticketService) Status(ctx restate.ObjectSharedContext, _ restate.Void) (TicketStatus, error) { + ticketId := ctx.Key() + ctx.Log().Info("mark ticket as sold", "ticket", ticketId) + + status, err := restate.GetAs[TicketStatus](ctx, "status") + if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { + return status, err + } + + return status, nil +} diff --git a/example/user_session.go b/example/user_session.go index 54b2a86..6919b50 100644 --- a/example/user_session.go +++ b/example/user_session.go @@ -12,15 +12,15 @@ const UserSessionServiceName = "UserSession" type userSession struct{} -func (u *userSession) Name() string { +func (u *userSession) ServiceName() string { return UserSessionServiceName } func (u *userSession) AddTicket(ctx restate.ObjectContext, ticketId string) (bool, error) { userId := ctx.Key() - var success bool - if err := ctx.Object(TicketServiceName, ticketId).Method("Reserve").Request(userId).Response(&success); err != nil { + success, err := restate.CallAs[bool](ctx.Object(TicketServiceName, ticketId, "Reserve")).Request(userId) + if err != nil { return false, err } @@ -37,11 +37,11 @@ func (u *userSession) AddTicket(ctx restate.ObjectContext, ticketId string) (boo tickets = append(tickets, ticketId) - if err := restate.SetAs(ctx, "tickets", tickets); err != nil { + if err := ctx.Set("tickets", tickets); err != nil { return false, err } - if err := ctx.ObjectSend(UserSessionServiceName, ticketId, 15*time.Minute).Method("ExpireTicket").Request(ticketId); err != nil { + if err := ctx.Object(UserSessionServiceName, ticketId, "ExpireTicket").Send(ticketId, 15*time.Minute); err != nil { return false, err } @@ -66,11 +66,11 @@ func (u *userSession) ExpireTicket(ctx restate.ObjectContext, ticketId string) ( return void, nil } - if err := restate.SetAs(ctx, "tickets", tickets); err != nil { + if err := ctx.Set("tickets", tickets); err != nil { return void, err } - return void, ctx.ObjectSend(TicketServiceName, ticketId, 0).Method("Unreserve").Request(nil) + return void, ctx.Object(TicketServiceName, ticketId, "Unreserve").Send(nil, 0) } func (u *userSession) Checkout(ctx restate.ObjectContext, _ restate.Void) (bool, error) { @@ -86,19 +86,34 @@ func (u *userSession) Checkout(ctx restate.ObjectContext, _ restate.Void) (bool, return false, nil } - var response PaymentResponse - if err := ctx.Object(CheckoutServiceName, ""). - Method("Payment"). - Request(PaymentRequest{UserID: userId, Tickets: tickets}). - Response(&response); err != nil { + timeout := ctx.After(time.Minute) + + request, err := restate.CallAs[PaymentResponse](ctx.Object(CheckoutServiceName, "", "Payment")). + RequestFuture(PaymentRequest{UserID: userId, Tickets: tickets}) + if err != nil { + return false, err + } + + // race between the request and the timeout + switch ctx.Select(timeout, request).Select() { + case request: + // happy path + case timeout: + // we could choose to fail here with terminal error, but we'd also have to refund the payment! + ctx.Log().Warn("slow payment") + } + + // block on the eventual response + response, err := request.Response() + if err != nil { return false, err } ctx.Log().Info("payment details", "id", response.ID, "price", response.Price) for _, ticket := range tickets { - call := ctx.ObjectSend(TicketServiceName, ticket, 0).Method("MarkAsSold") - if err := call.Request(nil); err != nil { + call := ctx.Object(TicketServiceName, ticket, "MarkAsSold") + if err := call.Send(nil, 0); err != nil { return false, err } } diff --git a/example/utils.go b/example/utils.go new file mode 100644 index 0000000..6d05c7d --- /dev/null +++ b/example/utils.go @@ -0,0 +1,45 @@ +package main + +import ( + "fmt" + "math/big" + + restate "github.com/restatedev/sdk-go" +) + +var health = restate. + NewServiceRouter("health"). + Handler("ping", restate.NewServiceHandler( + func(restate.Context, struct{}) (restate.Void, error) { + return restate.Void{}, nil + })) + +var bigCounter = restate. + NewObjectRouter("bigCounter"). + Handler("add", restate.NewObjectHandler( + func(ctx restate.ObjectContext, deltaText string) (string, error) { + delta, ok := big.NewInt(0).SetString(deltaText, 10) + if !ok { + return "", restate.TerminalError(fmt.Errorf("input must be a valid integer string: %s", deltaText)) + } + + bytes, err := restate.GetAs[[]byte](ctx, "counter", restate.WithBinary) + if err != nil && err != restate.ErrKeyNotFound { + return "", err + } + newCount := big.NewInt(0).Add(big.NewInt(0).SetBytes(bytes), delta) + if err := ctx.Set("counter", newCount.Bytes(), restate.WithBinary); err != nil { + return "", err + } + + return newCount.String(), nil + })). + Handler("get", restate.NewObjectSharedHandler( + func(ctx restate.ObjectSharedContext, input restate.Void) (string, error) { + bytes, err := restate.GetAs[[]byte](ctx, "counter", restate.WithBinary) + if err != nil { + return "", err + } + + return big.NewInt(0).SetBytes(bytes).String(), err + })) diff --git a/facilitators.go b/facilitators.go new file mode 100644 index 0000000..6af240c --- /dev/null +++ b/facilitators.go @@ -0,0 +1,101 @@ +package restate + +import ( + "github.com/restatedev/sdk-go/internal/futures" + "github.com/restatedev/sdk-go/internal/options" +) + +// GetAs helper function to get a key, returning a typed response instead of accepting a pointer. +// If there is no associated value with key, an error ErrKeyNotFound is returned +func GetAs[T any](ctx ObjectSharedContext, key string, options ...options.GetOption) (output T, err error) { + err = ctx.Get(key, &output, options...) + return +} + +// RunAs helper function runs a Run function, returning a typed response instead of accepting a pointer +func RunAs[T any](ctx Context, fn func(RunContext) (T, error), options ...options.RunOption) (output T, err error) { + err = ctx.Run(func(ctx RunContext) (any, error) { + return fn(ctx) + }, &output, options...) + + return +} + +// TypedAwakeable is an extension of Awakeable which returns typed responses instead of accepting a pointer +type TypedAwakeable[T any] interface { + // Id returns the awakeable ID, which can be stored or sent to a another service + Id() string + // Result blocks on receiving the result of the awakeable, storing the value it was + // resolved with in output or otherwise returning the error it was rejected with. + // It is *not* safe to call this in a goroutine - use Context.Select if you + // want to wait on multiple results at once. + Result() (T, error) + futures.Selectable +} + +type typedAwakeable[T any] struct { + Awakeable +} + +func (t typedAwakeable[T]) Result() (output T, err error) { + err = t.Awakeable.Result(&output) + return +} + +// AwakeableAs helper function to treat awakeable results as a particular type. +func AwakeableAs[T any](ctx Context, options ...options.AwakeableOption) TypedAwakeable[T] { + return typedAwakeable[T]{ctx.Awakeable(options...)} +} + +// TypedCallClient is an extension of CallClient which returns typed responses instead of accepting a pointer +type TypedCallClient[O any] interface { + // RequestFuture makes a call and returns a handle on a future response + RequestFuture(input any) (TypedResponseFuture[O], error) + // Request makes a call and blocks on getting the response + Request(input any) (O, error) + SendClient +} + +type typedCallClient[O any] struct { + CallClient +} + +func (t typedCallClient[O]) Request(input any) (output O, err error) { + fut, err := t.CallClient.RequestFuture(input) + if err != nil { + return output, err + } + err = fut.Response(&output) + return +} + +func (t typedCallClient[O]) RequestFuture(input any) (TypedResponseFuture[O], error) { + fut, err := t.CallClient.RequestFuture(input) + if err != nil { + return nil, err + } + return typedResponseFuture[O]{fut}, nil +} + +// TypedResponseFuture is an extension of ResponseFuture which returns typed responses instead of accepting a pointer +type TypedResponseFuture[O any] interface { + // Response blocks on the response to the call and returns it or the associated error + // It is *not* safe to call this in a goroutine - use Context.Select if you + // want to wait on multiple results at once. + Response() (O, error) + futures.Selectable +} + +type typedResponseFuture[O any] struct { + ResponseFuture +} + +func (t typedResponseFuture[O]) Response() (output O, err error) { + err = t.ResponseFuture.Response(&output) + return +} + +// CallAs helper function to get typed responses instead of passing in a pointer +func CallAs[O any](client CallClient) TypedCallClient[O] { + return typedCallClient[O]{client} +} diff --git a/handler.go b/handler.go index af66b28..37a8b6d 100644 --- a/handler.go +++ b/handler.go @@ -1,78 +1,80 @@ package restate import ( - "encoding/json" "fmt" + "net/http" "github.com/restatedev/sdk-go/encoding" + "github.com/restatedev/sdk-go/internal" ) -// Void is a placeholder used usually for functions that their signature require that -// you accept an input or return an output but the function implementation does not -// require them -type Void struct{} +// Void is a placeholder to signify 'no value' where a type is otherwise needed. It can be used in several contexts: +// 1. Input types for handlers - the request payload codec will default to a encoding.VoidCodec which will reject input at the ingress +// 2. Output types for handlers - the response payload codec will default to a encoding.VoidCodec which will send no bytes and set no content-type +type Void = encoding.Void -type VoidDecoder struct{} +type ObjectHandler interface { + Call(ctx ObjectContext, request []byte) (output []byte, err error) + getOptions() *objectHandlerOptions + Handler +} -func (v VoidDecoder) InputPayload() *encoding.InputPayload { - return &encoding.InputPayload{} +type ServiceHandler interface { + Call(ctx Context, request []byte) (output []byte, err error) + getOptions() *serviceHandlerOptions + Handler } -func (v VoidDecoder) Decode(data []byte) (input Void, err error) { - if len(data) > 0 { - err = fmt.Errorf("restate.Void decoder expects no request data") - } - return +type Handler interface { + sealed() + InputPayload() *encoding.InputPayload + OutputPayload() *encoding.OutputPayload + HandlerType() *internal.ServiceHandlerType } -type VoidEncoder struct{} +// ServiceHandlerFn signature of service (unkeyed) handler function +type ServiceHandlerFn[I any, O any] func(ctx Context, input I) (O, error) -func (v VoidEncoder) OutputPayload() *encoding.OutputPayload { - return &encoding.OutputPayload{} -} +// ObjectHandlerFn signature for object (keyed) handler function +type ObjectHandlerFn[I any, O any] func(ctx ObjectContext, input I) (O, error) -func (v VoidEncoder) Encode(output Void) ([]byte, error) { - return nil, nil +// ObjectHandlerFn signature for object (keyed) handler function that can run concurrently with other handlers against a snapshot of state +type ObjectSharedHandlerFn[I any, O any] func(ctx ObjectSharedContext, input I) (O, error) + +type serviceHandlerOptions struct { + codec encoding.PayloadCodec } type serviceHandler[I any, O any] struct { fn ServiceHandlerFn[I, O] - decoder Decoder[I] - encoder Encoder[O] + options serviceHandlerOptions } -// NewJSONServiceHandler create a new handler for a service using JSON encoding -func NewJSONServiceHandler[I any, O any](fn ServiceHandlerFn[I, O]) *serviceHandler[I, O] { - return &serviceHandler[I, O]{ - fn: fn, - decoder: encoding.JSONDecoder[I]{}, - encoder: encoding.JSONEncoder[O]{}, - } -} +var _ ServiceHandler = (*serviceHandler[struct{}, struct{}])(nil) -// NewProtoServiceHandler create a new handler for a service using protobuf encoding -// Input and output type must both be pointers that satisfy proto.Message -func NewProtoServiceHandler[I any, O any, IP encoding.MessagePointer[I], OP encoding.MessagePointer[O]](fn ServiceHandlerFn[IP, OP]) *serviceHandler[IP, OP] { - return &serviceHandler[IP, OP]{ - fn: fn, - decoder: encoding.ProtoDecoder[I, IP]{}, - encoder: encoding.ProtoEncoder[OP]{}, - } +type ServiceHandlerOption interface { + beforeServiceHandler(*serviceHandlerOptions) } -// NewServiceHandlerWithEncoders create a new handler for a service using a custom encoder/decoder implementation -func NewServiceHandlerWithEncoders[I any, O any](fn ServiceHandlerFn[I, O], decoder Decoder[I], encoder Encoder[O]) *serviceHandler[I, O] { +// NewServiceHandler create a new handler for a service, defaulting to JSON encoding +func NewServiceHandler[I any, O any](fn ServiceHandlerFn[I, O], options ...ServiceHandlerOption) *serviceHandler[I, O] { + opts := serviceHandlerOptions{} + for _, opt := range options { + opt.beforeServiceHandler(&opts) + } + if opts.codec == nil { + opts.codec = encoding.PartialVoidCodec[I, O]() + } return &serviceHandler[I, O]{ fn: fn, - decoder: decoder, - encoder: encoder, + options: opts, } } func (h *serviceHandler[I, O]) Call(ctx Context, bytes []byte) ([]byte, error) { - input, err := h.decoder.Decode(bytes) - if err != nil { - return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err)) + var input I + if err := h.options.codec.Unmarshal(bytes, &input); err != nil { + return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) } output, err := h.fn( @@ -83,7 +85,7 @@ func (h *serviceHandler[I, O]) Call(ctx Context, bytes []byte) ([]byte, error) { return nil, err } - bytes, err = h.encoder.Encode(output) + bytes, err = h.options.codec.Marshal(output) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) } @@ -92,44 +94,96 @@ func (h *serviceHandler[I, O]) Call(ctx Context, bytes []byte) ([]byte, error) { } func (h *serviceHandler[I, O]) InputPayload() *encoding.InputPayload { - return h.decoder.InputPayload() + return h.options.codec.InputPayload() } func (h *serviceHandler[I, O]) OutputPayload() *encoding.OutputPayload { - return h.encoder.OutputPayload() + return h.options.codec.OutputPayload() +} + +func (h *serviceHandler[I, O]) HandlerType() *internal.ServiceHandlerType { + return nil +} + +func (h *serviceHandler[I, O]) getOptions() *serviceHandlerOptions { + return &h.options } func (h *serviceHandler[I, O]) sealed() {} +type objectHandlerOptions struct { + codec encoding.PayloadCodec +} + type objectHandler[I any, O any] struct { - fn ObjectHandlerFn[I, O] + // only one of exclusiveFn or sharedFn should be set, as indicated by handlerType + exclusiveFn ObjectHandlerFn[I, O] + sharedFn ObjectSharedHandlerFn[I, O] + options objectHandlerOptions + handlerType internal.ServiceHandlerType +} + +var _ ObjectHandler = (*objectHandler[struct{}, struct{}])(nil) + +type ObjectHandlerOption interface { + beforeObjectHandler(*objectHandlerOptions) } -func NewObjectHandler[I any, O any](fn ObjectHandlerFn[I, O]) *objectHandler[I, O] { +func NewObjectHandler[I any, O any](fn ObjectHandlerFn[I, O], options ...ObjectHandlerOption) *objectHandler[I, O] { + opts := objectHandlerOptions{} + for _, opt := range options { + opt.beforeObjectHandler(&opts) + } + if opts.codec == nil { + opts.codec = encoding.PartialVoidCodec[I, O]() + } return &objectHandler[I, O]{ - fn: fn, + exclusiveFn: fn, + options: opts, + handlerType: internal.ServiceHandlerType_EXCLUSIVE, } } -func (h *objectHandler[I, O]) Call(ctx ObjectContext, bytes []byte) ([]byte, error) { - input := new(I) +func NewObjectSharedHandler[I any, O any](fn ObjectSharedHandlerFn[I, O], options ...ObjectHandlerOption) *objectHandler[I, O] { + opts := objectHandlerOptions{} + for _, opt := range options { + opt.beforeObjectHandler(&opts) + } + if opts.codec == nil { + opts.codec = encoding.PartialVoidCodec[I, O]() + } + return &objectHandler[I, O]{ + sharedFn: fn, + options: opts, + handlerType: internal.ServiceHandlerType_SHARED, + } +} - if len(bytes) > 0 { - // use the zero value if there is no input data at all - if err := json.Unmarshal(bytes, input); err != nil { - return nil, TerminalError(fmt.Errorf("request doesn't match handler signature: %w", err)) - } +func (h *objectHandler[I, O]) Call(ctx ObjectContext, bytes []byte) ([]byte, error) { + var input I + if err := h.options.codec.Unmarshal(bytes, &input); err != nil { + return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) } - output, err := h.fn( - ctx, - *input, - ) + var output O + var err error + switch h.handlerType { + case internal.ServiceHandlerType_EXCLUSIVE: + output, err = h.exclusiveFn( + ctx, + input, + ) + case internal.ServiceHandlerType_SHARED: + output, err = h.sharedFn( + ctx, + input, + ) + } if err != nil { return nil, err } - bytes, err = json.Marshal(output) + bytes, err = h.options.codec.Marshal(output) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) } @@ -137,4 +191,20 @@ func (h *objectHandler[I, O]) Call(ctx ObjectContext, bytes []byte) ([]byte, err return bytes, nil } +func (h *objectHandler[I, O]) InputPayload() *encoding.InputPayload { + return h.options.codec.InputPayload() +} + +func (h *objectHandler[I, O]) OutputPayload() *encoding.OutputPayload { + return h.options.codec.OutputPayload() +} + +func (h *objectHandler[I, O]) getOptions() *objectHandlerOptions { + return &h.options +} + +func (h *objectHandler[I, O]) HandlerType() *internal.ServiceHandlerType { + return &h.handlerType +} + func (h *objectHandler[I, O]) sealed() {} diff --git a/internal/errors/error.go b/internal/errors/error.go index e7ca7ff..f0ef08f 100644 --- a/internal/errors/error.go +++ b/internal/errors/error.go @@ -13,13 +13,17 @@ const ( ErrProtocolViolation Code = 571 ) +var ( + ErrKeyNotFound = NewTerminalError(fmt.Errorf("key not found"), 404) +) + type CodeError struct { Code Code Inner error } func (e *CodeError) Error() string { - return fmt.Sprintf("[CODE %04X] %s", e.Code, e.Inner) + return fmt.Sprintf("[%d] %s", e.Code, e.Inner) } func (e *CodeError) Unwrap() error { diff --git a/internal/futures/futures.go b/internal/futures/futures.go index 148f078..732ccd7 100644 --- a/internal/futures/futures.go +++ b/internal/futures/futures.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "encoding/binary" - "encoding/json" "fmt" "github.com/restatedev/sdk-go/generated/proto/protocol" @@ -32,8 +31,8 @@ func (a *After) Done() { a.entry.Await(a.suspensionCtx, a.entryIndex) } -func (a *After) getEntry() (wire.CompleteableMessage, uint32, error) { - return a.entry, a.entryIndex, nil +func (a *After) getEntry() (wire.CompleteableMessage, uint32) { + return a.entry, a.entryIndex } const AWAKEABLE_IDENTIFIER_PREFIX = "prom_1" @@ -62,8 +61,8 @@ func (c *Awakeable) Result() ([]byte, error) { return nil, fmt.Errorf("unexpected result in completed awakeable entry: %v", c.entry.Result) } } -func (c *Awakeable) getEntry() (wire.CompleteableMessage, uint32, error) { - return c.entry, c.entryIndex, nil +func (c *Awakeable) getEntry() (wire.CompleteableMessage, uint32) { + return c.entry, c.entryIndex } func awakeableID(invocationID []byte, entryIndex uint32) string { @@ -74,46 +73,29 @@ func awakeableID(invocationID []byte, entryIndex uint32) string { } type ResponseFuture struct { - suspensionCtx context.Context - err error - entry *wire.CallEntryMessage - entryIndex uint32 -} - -func NewResponseFuture(suspensionCtx context.Context, entry *wire.CallEntryMessage, entryIndex uint32) *ResponseFuture { - return &ResponseFuture{suspensionCtx, nil, entry, entryIndex} + suspensionCtx context.Context + entry *wire.CallEntryMessage + entryIndex uint32 + newProtocolViolation func(error) any } -func NewFailedResponseFuture(err error) *ResponseFuture { - return &ResponseFuture{nil, err, nil, 0} +func NewResponseFuture(suspensionCtx context.Context, entry *wire.CallEntryMessage, entryIndex uint32, newProtocolViolation func(error) any) *ResponseFuture { + return &ResponseFuture{suspensionCtx, entry, entryIndex, newProtocolViolation} } -func (r *ResponseFuture) Response(output any) error { - if r.err != nil { - return r.err - } - +func (r *ResponseFuture) Response() ([]byte, error) { r.entry.Await(r.suspensionCtx, r.entryIndex) - var bytes []byte switch result := r.entry.Result.(type) { case *protocol.CallEntryMessage_Failure: - return errors.ErrorFromFailure(result.Failure) + return nil, errors.ErrorFromFailure(result.Failure) case *protocol.CallEntryMessage_Value: - bytes = result.Value + return result.Value, nil default: - return errors.NewTerminalError(fmt.Errorf("sync call had invalid result: %v", r.entry.Result), 571) - + panic(r.newProtocolViolation(fmt.Errorf("call entry had invalid result: %v", r.entry.Result))) } - - if err := json.Unmarshal(bytes, output); err != nil { - // TODO: is this should be a terminal error or not? - return errors.NewTerminalError(fmt.Errorf("failed to decode response (%s): %w", string(bytes), err)) - } - - return nil } -func (r *ResponseFuture) getEntry() (wire.CompleteableMessage, uint32, error) { - return r.entry, r.entryIndex, r.err +func (r *ResponseFuture) getEntry() (wire.CompleteableMessage, uint32) { + return r.entry, r.entryIndex } diff --git a/internal/futures/select.go b/internal/futures/select.go index fc3674d..314cb45 100644 --- a/internal/futures/select.go +++ b/internal/futures/select.go @@ -9,7 +9,7 @@ import ( ) type Selectable interface { - getEntry() (wire.CompleteableMessage, uint32, error) + getEntry() (wire.CompleteableMessage, uint32) } type Selector struct { @@ -56,10 +56,7 @@ func (s *Selector) Take(winningEntryIndex uint32) Selectable { if selectable == nil { return nil } - entry, _, err := selectable.getEntry() - if err != nil { - return nil - } + entry, _ := selectable.getEntry() if !entry.Completed() { return nil } @@ -81,19 +78,16 @@ func (s *Selector) Indexes() []uint32 { return indexes } -func Select(suspensionCtx context.Context, futs ...Selectable) (*Selector, error) { +func Select(suspensionCtx context.Context, futs ...Selectable) *Selector { s := &Selector{ suspensionCtx: suspensionCtx, indexedFuts: make(map[uint32]Selectable, len(futs)), indexedChans: make(map[uint32]<-chan struct{}, len(futs)), } for i := range futs { - entry, entryIndex, err := futs[i].getEntry() - if err != nil { - return nil, err - } + entry, entryIndex := futs[i].getEntry() s.indexedFuts[entryIndex] = futs[i] s.indexedChans[entryIndex] = entry.Done() } - return s, nil + return s } diff --git a/internal/options/options.go b/internal/options/options.go new file mode 100644 index 0000000..f3d2299 --- /dev/null +++ b/internal/options/options.go @@ -0,0 +1,51 @@ +package options + +import "github.com/restatedev/sdk-go/encoding" + +type AwakeableOptions struct { + Codec encoding.Codec +} + +type AwakeableOption interface { + BeforeAwakeable(*AwakeableOptions) +} + +type ResolveAwakeableOptions struct { + Codec encoding.Codec +} + +type ResolveAwakeableOption interface { + BeforeResolveAwakeable(*ResolveAwakeableOptions) +} + +type GetOptions struct { + Codec encoding.Codec +} + +type GetOption interface { + BeforeGet(*GetOptions) +} + +type SetOptions struct { + Codec encoding.Codec +} + +type SetOption interface { + BeforeSet(*SetOptions) +} + +type CallOptions struct { + Codec encoding.Codec +} + +type CallOption interface { + BeforeCall(*CallOptions) +} + +type RunOptions struct { + Codec encoding.Codec +} + +type RunOption interface { + BeforeRun(*RunOptions) +} diff --git a/internal/state/awakeable.go b/internal/state/awakeable.go index 05b575d..85f6671 100644 --- a/internal/state/awakeable.go +++ b/internal/state/awakeable.go @@ -9,7 +9,7 @@ import ( "github.com/restatedev/sdk-go/internal/wire" ) -func (c *Machine) awakeable() restate.Awakeable[[]byte] { +func (c *Machine) awakeable() *futures.Awakeable { entry, entryIndex := replayOrNew( c, func(entry *wire.AwakeableEntryMessage) *wire.AwakeableEntryMessage { diff --git a/internal/state/call.go b/internal/state/call.go index a0e6281..8c9e5cf 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -2,92 +2,74 @@ package state import ( "bytes" - "encoding/json" + "fmt" "time" restate "github.com/restatedev/sdk-go" "github.com/restatedev/sdk-go/generated/proto/protocol" + "github.com/restatedev/sdk-go/internal/errors" "github.com/restatedev/sdk-go/internal/futures" + "github.com/restatedev/sdk-go/internal/options" "github.com/restatedev/sdk-go/internal/wire" ) -var ( - _ restate.ServiceClient = (*serviceProxy)(nil) - _ restate.ServiceSendClient = (*serviceSendProxy)(nil) - _ restate.CallClient = (*serviceCall)(nil) - _ restate.SendClient = (*serviceSend)(nil) -) - -type serviceProxy struct { +type serviceCall struct { + options options.CallOptions machine *Machine service string key string + method string } -func (c *serviceProxy) Method(fn string) restate.CallClient { - return &serviceCall{ - machine: c.machine, - service: c.service, - key: c.key, - method: fn, +// RequestFuture makes a call and returns a handle on the response +func (c *serviceCall) RequestFuture(input any) (restate.ResponseFuture, error) { + bytes, err := c.options.Codec.Marshal(input) + if err != nil { + return nil, errors.NewTerminalError(fmt.Errorf("failed to marshal RequestFuture input: %w", err)) } -} + entry, entryIndex := c.machine.doCall(c.service, c.key, c.method, bytes) -type serviceSendProxy struct { - machine *Machine - service string - key string - delay time.Duration + return decodingResponseFuture{ + futures.NewResponseFuture(c.machine.suspensionCtx, entry, entryIndex, func(err error) any { return c.machine.newProtocolViolation(entry, err) }), + c.options, + }, nil } -func (c *serviceSendProxy) Method(fn string) restate.SendClient { - return &serviceSend{ - machine: c.machine, - service: c.service, - key: c.key, - method: fn, - delay: c.delay, - } +type decodingResponseFuture struct { + *futures.ResponseFuture + options options.CallOptions } -type serviceCall struct { - machine *Machine - service string - key string - method string -} - -// Do makes a call and wait for the response -func (c *serviceCall) Request(input any) restate.ResponseFuture { - if entry, entryIndex, err := c.machine.doDynCall(c.service, c.key, c.method, input); err != nil { - return futures.NewFailedResponseFuture(err) - } else { - return futures.NewResponseFuture(c.machine.suspensionCtx, entry, entryIndex) +func (d decodingResponseFuture) Response(output any) (err error) { + bytes, err := d.ResponseFuture.Response() + if err != nil { + return err } -} -type serviceSend struct { - machine *Machine - service string - key string - method string + if err := d.options.Codec.Unmarshal(bytes, output); err != nil { + return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Call response into O: %w", err)) + } - delay time.Duration + return nil } -// Send runs a call in the background after delay duration -func (c *serviceSend) Request(input any) error { - return c.machine.sendCall(c.service, c.key, c.method, input, c.delay) +// Request makes a call and blocks on the response +func (c *serviceCall) Request(input any, output any) error { + fut, err := c.RequestFuture(input) + if err != nil { + return err + } + return fut.Response(output) } -func (m *Machine) doDynCall(service, key, method string, input any) (*wire.CallEntryMessage, uint32, error) { - params, err := json.Marshal(input) +// Send runs a call in the background after delay duration +func (c *serviceCall) Send(input any, delay time.Duration) error { + bytes, err := c.options.Codec.Marshal(input) if err != nil { - return nil, 0, err + return errors.NewTerminalError(fmt.Errorf("failed to marshal Send input: %w", err)) } - - entry, entryIndex := m.doCall(service, key, method, params) - return entry, entryIndex, nil + c.machine.sendCall(c.service, c.key, c.method, bytes, delay) + return nil } func (m *Machine) doCall(service, key, method string, params []byte) (*wire.CallEntryMessage, uint32) { @@ -129,24 +111,19 @@ func (m *Machine) _doCall(service, key, method string, params []byte) *wire.Call return msg } -func (m *Machine) sendCall(service, key, method string, body any, delay time.Duration) error { - params, err := json.Marshal(body) - if err != nil { - return err - } - +func (m *Machine) sendCall(service, key, method string, body []byte, delay time.Duration) { _, _ = replayOrNew( m, func(entry *wire.OneWayCallEntryMessage) restate.Void { if entry.ServiceName != service || entry.Key != key || entry.HandlerName != method || - !bytes.Equal(entry.Parameter, params) { + !bytes.Equal(entry.Parameter, body) { panic(m.newEntryMismatch(&wire.OneWayCallEntryMessage{ OneWayCallEntryMessage: protocol.OneWayCallEntryMessage{ ServiceName: service, HandlerName: method, - Parameter: params, + Parameter: body, Key: key, }, }, entry)) @@ -155,12 +132,10 @@ func (m *Machine) sendCall(service, key, method string, body any, delay time.Dur return restate.Void{} }, func() restate.Void { - m._sendCall(service, key, method, params, delay) + m._sendCall(service, key, method, body, delay) return restate.Void{} }, ) - - return nil } func (c *Machine) _sendCall(service, key, method string, params []byte, delay time.Duration) { diff --git a/internal/state/select.go b/internal/state/select.go index 1ea21e1..1be3f4c 100644 --- a/internal/state/select.go +++ b/internal/state/select.go @@ -13,12 +13,9 @@ type selector struct { inner *futures.Selector } -func (m *Machine) selector(futs ...futures.Selectable) (*selector, error) { - inner, err := futures.Select(m.suspensionCtx, futs...) - if err != nil { - return nil, err - } - return &selector{m, inner}, nil +func (m *Machine) selector(futs ...futures.Selectable) *selector { + inner := futures.Select(m.suspensionCtx, futs...) + return &selector{m, inner} } func (s *selector) Select() futures.Selectable { diff --git a/internal/state/state.go b/internal/state/state.go index b270988..8d6f3cf 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -13,10 +13,12 @@ import ( "time" restate "github.com/restatedev/sdk-go" + "github.com/restatedev/sdk-go/encoding" "github.com/restatedev/sdk-go/generated/proto/protocol" "github.com/restatedev/sdk-go/internal/errors" "github.com/restatedev/sdk-go/internal/futures" "github.com/restatedev/sdk-go/internal/log" + "github.com/restatedev/sdk-go/internal/options" "github.com/restatedev/sdk-go/internal/rand" "github.com/restatedev/sdk-go/internal/wire" "github.com/restatedev/sdk-go/rcontext" @@ -30,10 +32,6 @@ var ( ErrInvalidVersion = fmt.Errorf("invalid version number") ) -var ( - _ restate.Context = (*Context)(nil) -) - type Context struct { context.Context userLogger *slog.Logger @@ -41,7 +39,9 @@ type Context struct { } var _ restate.ObjectContext = &Context{} +var _ restate.ObjectSharedContext = &Context{} var _ restate.Context = &Context{} +var _ restate.RunContext = &Context{} func (c *Context) Log() *slog.Logger { return c.machine.userLog @@ -51,8 +51,22 @@ func (c *Context) Rand() *rand.Rand { return c.machine.rand } -func (c *Context) Set(key string, value []byte) { - c.machine.set(key, value) +func (c *Context) Set(key string, value any, opts ...options.SetOption) error { + o := options.SetOptions{} + for _, opt := range opts { + opt.BeforeSet(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + + bytes, err := o.Codec.Marshal(value) + if err != nil { + return errors.NewTerminalError(fmt.Errorf("failed to marshal Set value: %w", err)) + } + + c.machine.set(key, bytes) + return nil } func (c *Context) Clear(key string) { @@ -66,11 +80,28 @@ func (c *Context) ClearAll() { } -func (c *Context) Get(key string) ([]byte, error) { - return c.machine.get(key) +func (c *Context) Get(key string, output any, opts ...options.GetOption) error { + o := options.GetOptions{} + for _, opt := range opts { + opt.BeforeGet(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + + bytes := c.machine.get(key) + if len(bytes) == 0 { + return errors.ErrKeyNotFound + } + + if err := o.Codec.Unmarshal(bytes, output); err != nil { + return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Get state into output: %w", err)) + } + + return nil } -func (c *Context) Keys() ([]string, error) { +func (c *Context) Keys() []string { return c.machine.keys() } @@ -82,55 +113,131 @@ func (c *Context) After(d time.Duration) restate.After { return c.machine.after(d) } -func (c *Context) Service(service string) restate.ServiceClient { - return &serviceProxy{ - machine: c.machine, - service: service, +func (c *Context) Service(service, method string, opts ...options.CallOption) restate.CallClient { + o := options.CallOptions{} + for _, opt := range opts { + opt.BeforeCall(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec } -} -func (c *Context) ServiceSend(service string, delay time.Duration) restate.ServiceSendClient { - return &serviceSendProxy{ + return &serviceCall{ + options: o, machine: c.machine, service: service, - delay: delay, + method: method, } } -func (c *Context) Object(service, key string) restate.ServiceClient { - return &serviceProxy{ +func (c *Context) Object(service, key, method string, opts ...options.CallOption) restate.CallClient { + o := options.CallOptions{} + for _, opt := range opts { + opt.BeforeCall(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + + return &serviceCall{ + options: o, machine: c.machine, service: service, key: key, + method: method, } } -func (c *Context) ObjectSend(service, key string, delay time.Duration) restate.ServiceSendClient { - return &serviceSendProxy{ - machine: c.machine, - service: service, - key: key, - delay: delay, +func (c *Context) Run(fn func(ctx restate.RunContext) (any, error), output any, opts ...options.RunOption) error { + o := options.RunOptions{} + for _, opt := range opts { + opt.BeforeRun(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + + bytes, err := c.machine.run(func(ctx restate.RunContext) ([]byte, error) { + output, err := fn(ctx) + if err != nil { + return nil, err + } + + bytes, err := o.Codec.Marshal(output) + if err != nil { + return nil, errors.NewTerminalError(fmt.Errorf("failed to marshal Run output: %w", err)) + } + + return bytes, nil + }) + if err != nil { + return err } + + if err := o.Codec.Unmarshal(bytes, output); err != nil { + return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Run output: %w", err)) + } + + return nil } -func (c *Context) Run(fn func(ctx restate.RunContext) ([]byte, error)) ([]byte, error) { - return c.machine.run(fn) +type awakeableOptions struct { + codec encoding.Codec } -func (c *Context) Awakeable() restate.Awakeable[[]byte] { - return c.machine.awakeable() +type AwakeableOption interface { + beforeAwakeable(*awakeableOptions) } -func (c *Context) ResolveAwakeable(id string, value []byte) { - c.machine.resolveAwakeable(id, value) +func (c *Context) Awakeable(opts ...options.AwakeableOption) restate.Awakeable { + o := options.AwakeableOptions{} + for _, opt := range opts { + opt.BeforeAwakeable(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + return decodingAwakeable{c.machine.awakeable(), o.Codec} +} + +type decodingAwakeable struct { + *futures.Awakeable + codec encoding.Codec +} + +func (d decodingAwakeable) Id() string { return d.Awakeable.Id() } +func (d decodingAwakeable) Result(output any) (err error) { + bytes, err := d.Awakeable.Result() + if err != nil { + return err + } + if err := d.codec.Unmarshal(bytes, output); err != nil { + return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Awakeable result into output: %w", err)) + } + return +} + +func (c *Context) ResolveAwakeable(id string, value any, opts ...options.ResolveAwakeableOption) error { + o := options.ResolveAwakeableOptions{} + for _, opt := range opts { + opt.BeforeResolveAwakeable(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + bytes, err := o.Codec.Marshal(value) + if err != nil { + return errors.NewTerminalError(fmt.Errorf("failed to marshal ResolveAwakeable value: %w", err)) + } + c.machine.resolveAwakeable(id, bytes) + return nil } func (c *Context) RejectAwakeable(id string, reason error) { c.machine.rejectAwakeable(id, reason) } -func (c *Context) Selector(futs ...futures.Selectable) (restate.Selector, error) { +func (c *Context) Select(futs ...futures.Selectable) restate.Selector { return c.machine.selector(futs...) } @@ -239,6 +346,19 @@ func (m *Machine) invoke(ctx *Context, input []byte, outputSeen bool) error { case nil: // nothing to do, just exit return + case *protocolViolation: + m.log.LogAttrs(m.ctx, slog.LevelError, "Protocol violation", log.Error(typ.err)) + + if err := m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ + ErrorMessage: protocol.ErrorMessage{ + Code: uint32(errors.ErrProtocolViolation), + Message: fmt.Sprintf("Protocol violation: %v", typ.err), + RelatedEntryIndex: &typ.entryIndex, + RelatedEntryType: wire.MessageType(typ.entry).UInt32(), + }, + }); err != nil { + m.log.LogAttrs(m.ctx, slog.LevelError, "Error sending failure message", log.Error(err)) + } case *entryMismatch: expected, _ := json.Marshal(typ.expectedEntry) actual, _ := json.Marshal(typ.actualEntry) diff --git a/internal/state/sys.go b/internal/state/sys.go index f5f33a1..0409049 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -29,6 +29,18 @@ func (m *Machine) newEntryMismatch(expectedEntry wire.Message, actualEntry wire. return e } +type protocolViolation struct { + entryIndex uint32 + entry wire.Message + err error +} + +func (m *Machine) newProtocolViolation(entry wire.Message, err error) *protocolViolation { + e := &protocolViolation{m.entryIndex, entry, err} + m.failure = e + return e +} + func (m *Machine) set(key string, value []byte) { _, _ = replayOrNew( m, @@ -113,7 +125,7 @@ func (m *Machine) _clearAll() { ) } -func (m *Machine) get(key string) ([]byte, error) { +func (m *Machine) get(key string) []byte { entry, entryIndex := replayOrNew( m, func(entry *wire.GetStateEntryMessage) *wire.GetStateEntryMessage { @@ -133,18 +145,13 @@ func (m *Machine) get(key string) ([]byte, error) { switch value := entry.Result.(type) { case *protocol.GetStateEntryMessage_Empty: - return nil, nil - case *protocol.GetStateEntryMessage_Failure: - // the get state entry message is not failable so this should - // never happen - // TODO terminal? - return nil, fmt.Errorf("[%d] %s", value.Failure.Code, value.Failure.Message) + return nil case *protocol.GetStateEntryMessage_Value: m.current[key] = value.Value - return value.Value, nil + return value.Value + default: + panic(m.newProtocolViolation(entry, fmt.Errorf("get state entry had invalid result: %v", entry.Result))) } - - return nil, restate.TerminalError(fmt.Errorf("get state had invalid result: %v", entry.Result), errors.ErrProtocolViolation) } func (m *Machine) _get(key string) *wire.GetStateEntryMessage { @@ -184,7 +191,7 @@ func (m *Machine) _get(key string) *wire.GetStateEntryMessage { return msg } -func (m *Machine) keys() ([]string, error) { +func (m *Machine) keys() []string { entry, entryIndex := replayOrNew( m, func(entry *wire.GetStateKeysEntryMessage) *wire.GetStateKeysEntryMessage { @@ -196,20 +203,16 @@ func (m *Machine) keys() ([]string, error) { entry.Await(m.suspensionCtx, entryIndex) switch value := entry.Result.(type) { - case *protocol.GetStateKeysEntryMessage_Failure: - // the get state entry message is not failable so this should - // never happen - return nil, fmt.Errorf("[%d] %s", value.Failure.Code, value.Failure.Message) case *protocol.GetStateKeysEntryMessage_Value: values := make([]string, 0, len(value.Value.Keys)) for _, key := range value.Value.Keys { values = append(values, string(key)) } - return values, nil + return values + default: + panic(m.newProtocolViolation(entry, fmt.Errorf("get state keys entry had invalid result: %v", entry.Result))) } - - return nil, nil } func (m *Machine) _keys() *wire.GetStateKeysEntryMessage { @@ -297,9 +300,9 @@ func (m *Machine) run(fn func(restate.RunContext) ([]byte, error)) ([]byte, erro case nil: // Empty result is valid return nil, nil + default: + panic(m.newProtocolViolation(entry, fmt.Errorf("run entry had invalid result: %v", entry.Result))) } - - return nil, restate.TerminalError(fmt.Errorf("run entry had invalid result: %v", entry.Result), errors.ErrProtocolViolation) } type runContext struct { diff --git a/options.go b/options.go new file mode 100644 index 0000000..58c2ae6 --- /dev/null +++ b/options.go @@ -0,0 +1,55 @@ +package restate + +import ( + "github.com/restatedev/sdk-go/encoding" + "github.com/restatedev/sdk-go/internal/options" +) + +type withCodec struct { + codec encoding.Codec +} + +var _ options.GetOption = withCodec{} +var _ options.SetOption = withCodec{} +var _ options.RunOption = withCodec{} +var _ options.AwakeableOption = withCodec{} +var _ options.ResolveAwakeableOption = withCodec{} +var _ options.CallOption = withCodec{} + +func (w withCodec) BeforeGet(opts *options.GetOptions) { opts.Codec = w.codec } +func (w withCodec) BeforeSet(opts *options.SetOptions) { opts.Codec = w.codec } +func (w withCodec) BeforeRun(opts *options.RunOptions) { opts.Codec = w.codec } +func (w withCodec) BeforeAwakeable(opts *options.AwakeableOptions) { opts.Codec = w.codec } +func (w withCodec) BeforeResolveAwakeable(opts *options.ResolveAwakeableOptions) { + opts.Codec = w.codec +} +func (w withCodec) BeforeCall(opts *options.CallOptions) { opts.Codec = w.codec } + +func WithCodec(codec encoding.Codec) withCodec { + return withCodec{codec} +} + +type withPayloadCodec struct { + withCodec + codec encoding.PayloadCodec +} + +var _ ServiceHandlerOption = withPayloadCodec{} +var _ ServiceRouterOption = withPayloadCodec{} +var _ ObjectHandlerOption = withPayloadCodec{} +var _ ObjectRouterOption = withPayloadCodec{} + +func (w withPayloadCodec) beforeServiceHandler(opts *serviceHandlerOptions) { opts.codec = w.codec } +func (w withPayloadCodec) beforeObjectHandler(opts *objectHandlerOptions) { opts.codec = w.codec } +func (w withPayloadCodec) beforeServiceRouter(opts *serviceRouterOptions) { + opts.defaultCodec = w.codec +} +func (w withPayloadCodec) beforeObjectRouter(opts *objectRouterOptions) { opts.defaultCodec = w.codec } + +func WithPayloadCodec(codec encoding.PayloadCodec) withPayloadCodec { + return withPayloadCodec{withCodec{codec}, codec} +} + +var WithProto = WithPayloadCodec(encoding.ProtoCodec) +var WithBinary = WithPayloadCodec(encoding.BinaryCodec) +var WithJSON = WithPayloadCodec(encoding.JSONCodec) diff --git a/reflect.go b/reflect.go index 6469f0a..24c5b4e 100644 --- a/reflect.go +++ b/reflect.go @@ -1,12 +1,12 @@ package restate import ( - "encoding/json" "fmt" + "net/http" "reflect" "github.com/restatedev/sdk-go/encoding" - "google.golang.org/protobuf/proto" + "github.com/restatedev/sdk-go/internal" ) type serviceNamer interface { @@ -14,21 +14,22 @@ type serviceNamer interface { } var ( - typeOfContext = reflect.TypeOf((*Context)(nil)).Elem() - typeOfObjectContext = reflect.TypeOf((*ObjectContext)(nil)).Elem() - typeOfVoid = reflect.TypeOf((*Void)(nil)) - typeOfError = reflect.TypeOf((*error)(nil)) + typeOfContext = reflect.TypeOf((*Context)(nil)).Elem() + typeOfObjectContext = reflect.TypeOf((*ObjectContext)(nil)).Elem() + typeOfSharedObjectContext = reflect.TypeOf((*ObjectSharedContext)(nil)).Elem() + typeOfVoid = reflect.TypeOf((*Void)(nil)).Elem() + typeOfError = reflect.TypeOf((*error)(nil)).Elem() ) // Object converts a struct with methods into a Virtual Object where each correctly-typed // and exported method of the struct will become a handler on the Object. The Object name defaults // to the name of the struct, but this can be overidden by providing a `ServiceName() string` method. -// The handler name is the name of the method. Handler methods should be of the type `ObjectHandlerFn[I, O]`. -// Input types I will be deserialised from JSON except when they are restate.Void, -// in which case no input bytes or content type may be sent. Output types O will be serialised -// to JSON except when they are restate.Void, in which case no data will be sent and no content type -// set. -func Object(object any) *ObjectRouter { +// The handler name is the name of the method. Handler methods should be of the type `ObjectHandlerFn[I, O]` or `ObjectSharedHandlerFn[I, O]`. +// Input types I will be deserialised with the provided codec (defaults to JSON) except when they are restate.Void, +// in which case no input bytes or content type may be sent. +// Output types O will be serialised with the provided codec (defaults to JSON) except when they are restate.Void, +// in which case no data will be sent and no content type set. +func Object(object any, options ...ObjectRouterOption) *ObjectRouter { typ := reflect.TypeOf(object) val := reflect.ValueOf(object) var name string @@ -37,7 +38,7 @@ func Object(object any) *ObjectRouter { } else { name = reflect.Indirect(val).Type().Name() } - router := NewObjectRouter(name) + router := NewObjectRouter(name, options...) for m := 0; m < typ.NumMethod(); m++ { method := typ.Method(m) @@ -52,7 +53,15 @@ func Object(object any) *ObjectRouter { continue } - if ctxType := mtype.In(1); ctxType != typeOfObjectContext { + var handlerType internal.ServiceHandlerType + + switch mtype.In(1) { + case typeOfObjectContext: + handlerType = internal.ServiceHandlerType_EXCLUSIVE + case typeOfSharedObjectContext: + handlerType = internal.ServiceHandlerType_SHARED + default: + // first parameter is not an object context continue } @@ -66,12 +75,29 @@ func Object(object any) *ObjectRouter { continue } + input := mtype.In(2) + output := mtype.Out(0) + + var codec encoding.PayloadCodec + switch { + case input == typeOfVoid && output == typeOfVoid: + codec = encoding.VoidCodec + case input == typeOfVoid: + codec = encoding.PairCodec{Input: encoding.VoidCodec, Output: nil} + case output == typeOfVoid: + codec = encoding.PairCodec{Input: nil, Output: encoding.VoidCodec} + default: + codec = nil + } + router.Handler(mname, &objectReflectHandler{ + objectHandlerOptions{codec}, + handlerType, reflectHandler{ fn: method.Func, receiver: val, - input: mtype.In(2), - output: mtype.Out(0), + input: input, + output: output, }, }) } @@ -83,11 +109,11 @@ func Object(object any) *ObjectRouter { // and exported method of the struct will become a handler on the Service. The Service name defaults // to the name of the struct, but this can be overidden by providing a `ServiceName() string` method. // The handler name is the name of the method. Handler methods should be of the type `ServiceHandlerFn[I, O]`. -// Input types I will be deserialised from JSON except when they are restate.Void, -// in which case no input bytes or content type may be sent. Output types O will be serialised -// to JSON except when they are restate.Void, in which case no data will be sent and no content type -// set. -func Service(service any) *ServiceRouter { +// Input types I will be deserialised with the provided codec (defaults to JSON) except when they are restate.Void, +// in which case no input bytes or content type may be sent. +// Output types O will be serialised with the provided codec (defaults to JSON) except when they are restate.Void, +// in which case no data will be sent and no content type set. +func Service(service any, options ...ServiceRouterOption) *ServiceRouter { typ := reflect.TypeOf(service) val := reflect.ValueOf(service) var name string @@ -96,7 +122,7 @@ func Service(service any) *ServiceRouter { } else { name = reflect.Indirect(val).Type().Name() } - router := NewServiceRouter(name) + router := NewServiceRouter(name, options...) for m := 0; m < typ.NumMethod(); m++ { method := typ.Method(m) @@ -127,12 +153,28 @@ func Service(service any) *ServiceRouter { continue } + input := mtype.In(2) + output := mtype.Out(0) + + var codec encoding.PayloadCodec + switch { + case input == typeOfVoid && output == typeOfVoid: + codec = encoding.VoidCodec + case input == typeOfVoid: + codec = encoding.PairCodec{Input: encoding.VoidCodec, Output: nil} + case output == typeOfVoid: + codec = encoding.PairCodec{Input: nil, Output: encoding.VoidCodec} + default: + codec = nil + } + router.Handler(mname, &serviceReflectHandler{ + serviceHandlerOptions{codec: codec}, reflectHandler{ fn: method.Func, receiver: val, - input: mtype.In(2), - output: mtype.Out(0), + input: input, + output: output, }, }) } @@ -147,42 +189,21 @@ type reflectHandler struct { output reflect.Type } -func (h *reflectHandler) InputPayload() *encoding.InputPayload { - if h.input == typeOfVoid { - return &encoding.InputPayload{} - } else { - return &encoding.InputPayload{ - Required: true, - ContentType: proto.String("application/json"), - } - } -} - -func (h *reflectHandler) OutputPayload() *encoding.OutputPayload { - if h.output == typeOfVoid { - return &encoding.OutputPayload{} - } else { - return &encoding.OutputPayload{ - ContentType: proto.String("application/json"), - } - } -} - func (h *reflectHandler) sealed() {} type objectReflectHandler struct { + options objectHandlerOptions + handlerType internal.ServiceHandlerType reflectHandler } -var _ Handler = (*objectReflectHandler)(nil) +var _ ObjectHandler = (*objectReflectHandler)(nil) func (h *objectReflectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, error) { input := reflect.New(h.input) - if h.input != typeOfVoid { - if err := json.Unmarshal(bytes, input.Interface()); err != nil { - return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err)) - } + if err := h.options.codec.Unmarshal(bytes, input.Interface()); err != nil { + return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) } // we are sure about the fn signature so it's safe to do this @@ -198,11 +219,7 @@ func (h *objectReflectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, er return nil, errI.(error) } - if h.output == typeOfVoid { - return nil, nil - } - - bytes, err := json.Marshal(outI) + bytes, err := h.options.codec.Marshal(outI) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) } @@ -210,17 +227,34 @@ func (h *objectReflectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, er return bytes, nil } +func (h *objectReflectHandler) getOptions() *objectHandlerOptions { + return &h.options +} + +func (h *objectReflectHandler) InputPayload() *encoding.InputPayload { + return h.options.codec.InputPayload() +} + +func (h *objectReflectHandler) OutputPayload() *encoding.OutputPayload { + return h.options.codec.OutputPayload() +} + +func (h *objectReflectHandler) HandlerType() *internal.ServiceHandlerType { + return &h.handlerType +} + type serviceReflectHandler struct { + options serviceHandlerOptions reflectHandler } -var _ Handler = (*serviceReflectHandler)(nil) +var _ ServiceHandler = (*serviceReflectHandler)(nil) func (h *serviceReflectHandler) Call(ctx Context, bytes []byte) ([]byte, error) { input := reflect.New(h.input) - if err := json.Unmarshal(bytes, input.Interface()); err != nil { - return nil, TerminalError(fmt.Errorf("request doesn't match handler signature: %w", err)) + if err := h.options.codec.Unmarshal(bytes, input.Interface()); err != nil { + return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) } // we are sure about the fn signature so it's safe to do this @@ -236,10 +270,26 @@ func (h *serviceReflectHandler) Call(ctx Context, bytes []byte) ([]byte, error) return nil, errI.(error) } - bytes, err := json.Marshal(outI) + bytes, err := h.options.codec.Marshal(outI) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) } return bytes, nil } + +func (h *serviceReflectHandler) getOptions() *serviceHandlerOptions { + return &h.options +} + +func (h *serviceReflectHandler) InputPayload() *encoding.InputPayload { + return h.options.codec.InputPayload() +} + +func (h *serviceReflectHandler) OutputPayload() *encoding.OutputPayload { + return h.options.codec.OutputPayload() +} + +func (h *serviceReflectHandler) HandlerType() *internal.ServiceHandlerType { + return nil +} diff --git a/router.go b/router.go index 7f07fd7..9cd52b4 100644 --- a/router.go +++ b/router.go @@ -1,107 +1,10 @@ package restate import ( - "context" - "encoding/json" - "fmt" - "log/slog" - "time" - "github.com/restatedev/sdk-go/encoding" "github.com/restatedev/sdk-go/internal" - "github.com/restatedev/sdk-go/internal/futures" - "github.com/restatedev/sdk-go/internal/rand" -) - -var ( - ErrKeyNotFound = fmt.Errorf("key not found") ) -type CallClient interface { - // Request makes a call and returns a handle on a future response - Request(input any) ResponseFuture -} - -type SendClient interface { - // Send makes a call in the background (doesn't wait for response) - Request(input any) error -} - -type ResponseFuture interface { - // Response waits for the response to the call and unmarshals it into output - Response(output any) error - futures.Selectable -} - -type ServiceClient interface { - // Method creates a call to method with name - Method(method string) CallClient -} - -type ServiceSendClient interface { - // Method creates a call to method with name - Method(method string) SendClient -} - -type Selector interface { - Remaining() bool - Select() futures.Selectable -} - -type Context interface { - RunContext - - // Returns a random source which will give deterministic results for a given invocation - // The source wraps the stdlib rand.Rand but with some extra helper methods - // This source is not safe for use inside .Run() - Rand() *rand.Rand - - // Sleep for the duration d - Sleep(d time.Duration) - // Return a handle on a sleep duration which can be combined - After(d time.Duration) After - - // Service gets a Service accessor by name where service - // must be another service known by restate runtime - Service(service string) ServiceClient - // Service gets a Service send accessor by name where service - // must be another service known by restate runtime - // and delay is the duration with which to delay requests - ServiceSend(service string, delay time.Duration) ServiceSendClient - - // Object gets a Object accessor by name where object - // must be another object known by restate runtime and - // key is any string representing the key for the object - Object(object, key string) ServiceClient - // Object gets a Object accessor by name where object - // must be another object known by restate runtime, - // key is any string representing the key for the object, - // and delay is the duration with which to delay requests - ObjectSend(object, key string, delay time.Duration) ServiceSendClient - - // Run runs the function (fn) until it succeeds or permanently fails. - // this stores the results of the function inside restate runtime so a replay - // will produce the same value (think generating a unique id for example) - // Note: use the RunAs helper function - Run(fn func(RunContext) ([]byte, error)) ([]byte, error) - - Awakeable() Awakeable[[]byte] - ResolveAwakeable(id string, value []byte) - RejectAwakeable(id string, reason error) - - Selector(futs ...futures.Selectable) (Selector, error) -} - -// RunContext methods are the only methods safe to call from inside a .Run() -type RunContext interface { - context.Context - - // Log obtains a handle on a slog.Logger which already has some useful fields (invocationID and method) - // By default, this logger will not output messages if the invocation is currently replaying - // The log handler can be set with `.WithLogger()` on the server object - Log() *slog.Logger -} - // Router interface type Router interface { Name() string @@ -110,83 +13,36 @@ type Router interface { Handlers() map[string]Handler } -type ObjectHandler interface { - Call(ctx ObjectContext, request []byte) (output []byte, err error) - Handler -} - -type ServiceHandler interface { - Call(ctx Context, request []byte) (output []byte, err error) - Handler -} - -type Handler interface { - sealed() - InputPayload() *encoding.InputPayload - OutputPayload() *encoding.OutputPayload -} - -type ServiceType string - -const ( - ServiceType_VIRTUAL_OBJECT ServiceType = "VIRTUAL_OBJECT" - ServiceType_SERVICE ServiceType = "SERVICE" -) - -type KeyValueStore interface { - // Set sets key value to bytes array. You can - // Note: Use SetAs helper function to seamlessly store - // a value of specific type. - Set(key string, value []byte) - // Get gets value (bytes array) associated with key - // If key does not exist, this function return a nil bytes array - // and a nil error - // Note: Use GetAs helper function to seamlessly get value - // as specific type. - Get(key string) ([]byte, error) - // Clear deletes a key - Clear(key string) - // ClearAll drops all stored state associated with key - ClearAll() - // Keys returns a list of all associated key - Keys() ([]string, error) -} - -type ObjectContext interface { - Context - KeyValueStore - Key() string -} - -// ServiceHandlerFn signature of service (unkeyed) handler function -type ServiceHandlerFn[I any, O any] func(ctx Context, input I) (output O, err error) - -// ObjectHandlerFn signature for object (keyed) handler function -type ObjectHandlerFn[I any, O any] func(ctx ObjectContext, input I) (output O, err error) - -type Decoder[I any] interface { - InputPayload() *encoding.InputPayload - Decode(data []byte) (input I, err error) +type serviceRouterOptions struct { + defaultCodec encoding.PayloadCodec } -type Encoder[O any] interface { - OutputPayload() *encoding.OutputPayload - Encode(output O) ([]byte, error) +type ServiceRouterOption interface { + beforeServiceRouter(*serviceRouterOptions) } // ServiceRouter implements Router type ServiceRouter struct { name string handlers map[string]Handler + options serviceRouterOptions } var _ Router = &ServiceRouter{} // NewServiceRouter creates a new ServiceRouter -func NewServiceRouter(name string) *ServiceRouter { +func NewServiceRouter(name string, options ...ServiceRouterOption) *ServiceRouter { + opts := serviceRouterOptions{} + for _, opt := range options { + opt.beforeServiceRouter(&opts) + } + if opts.defaultCodec == nil { + opts.defaultCodec = encoding.JSONCodec + } return &ServiceRouter{ name: name, handlers: make(map[string]Handler), + options: opts, } } @@ -196,6 +52,7 @@ func (r *ServiceRouter) Name() string { // Handler registers a new handler by name func (r *ServiceRouter) Handler(name string, handler ServiceHandler) *ServiceRouter { + handler.getOptions().codec = encoding.MergeCodec(handler.getOptions().codec, r.options.defaultCodec) r.handlers[name] = handler return r } @@ -208,18 +65,35 @@ func (r *ServiceRouter) Type() internal.ServiceType { return internal.ServiceType_SERVICE } +type objectRouterOptions struct { + defaultCodec encoding.PayloadCodec +} + +type ObjectRouterOption interface { + beforeObjectRouter(*objectRouterOptions) +} + // ObjectRouter type ObjectRouter struct { name string handlers map[string]Handler + options objectRouterOptions } var _ Router = &ObjectRouter{} -func NewObjectRouter(name string) *ObjectRouter { +func NewObjectRouter(name string, options ...ObjectRouterOption) *ObjectRouter { + opts := objectRouterOptions{} + for _, opt := range options { + opt.beforeObjectRouter(&opts) + } + if opts.defaultCodec == nil { + opts.defaultCodec = encoding.JSONCodec + } return &ObjectRouter{ name: name, handlers: make(map[string]Handler), + options: opts, } } @@ -228,6 +102,7 @@ func (r *ObjectRouter) Name() string { } func (r *ObjectRouter) Handler(name string, handler ObjectHandler) *ObjectRouter { + handler.getOptions().codec = encoding.MergeCodec(handler.getOptions().codec, r.options.defaultCodec) r.handlers[name] = handler return r } @@ -239,97 +114,3 @@ func (r *ObjectRouter) Handlers() map[string]Handler { func (r *ObjectRouter) Type() internal.ServiceType { return internal.ServiceType_VIRTUAL_OBJECT } - -// GetAs helper function to get a key as specific type. Note that -// if there is no associated value with key, an error ErrKeyNotFound is -// returned -// it does encoding/decoding of bytes automatically using json -func GetAs[T any](ctx ObjectContext, key string) (output T, err error) { - bytes, err := ctx.Get(key) - if err != nil { - return output, err - } - - if bytes == nil { - // key does not exit. - return output, ErrKeyNotFound - } - - err = json.Unmarshal(bytes, &output) - - return -} - -// SetAs helper function to set a key value with a generic type T. -// it does encoding/decoding of bytes automatically using json -func SetAs[T any](ctx ObjectContext, key string, value T) error { - bytes, err := json.Marshal(value) - if err != nil { - return err - } - - ctx.Set(key, bytes) - return nil -} - -// RunAs helper function runs a run function with specific concrete type as a result -// it does encoding/decoding of bytes automatically using json -func RunAs[T any](ctx Context, fn func(RunContext) (T, error)) (output T, err error) { - bytes, err := ctx.Run(func(ctx RunContext) ([]byte, error) { - out, err := fn(ctx) - if err != nil { - return nil, err - } - - bytes, err := json.Marshal(out) - return bytes, TerminalError(err) - }) - - if err != nil { - return output, err - } - - err = json.Unmarshal(bytes, &output) - - return output, TerminalError(err) -} - -type Awakeable[T any] interface { - Id() string - Result() (T, error) - futures.Selectable -} - -type decodingAwakeable[T any] struct { - Awakeable[[]byte] -} - -func (d decodingAwakeable[T]) Id() string { return d.Awakeable.Id() } -func (d decodingAwakeable[T]) Result() (out T, err error) { - bytes, err := d.Awakeable.Result() - if err != nil { - return out, err - } - if err := json.Unmarshal(bytes, &out); err != nil { - return out, err - } - return -} - -func AwakeableAs[T any](ctx Context) Awakeable[T] { - return decodingAwakeable[T]{Awakeable: ctx.Awakeable()} -} - -func ResolveAwakeableAs[T any](ctx Context, id string, value T) error { - bytes, err := json.Marshal(value) - if err != nil { - return TerminalError(err) - } - ctx.ResolveAwakeable(id, bytes) - return nil -} - -type After interface { - Done() - futures.Selectable -} diff --git a/server/restate.go b/server/restate.go index 366717f..8b60b78 100644 --- a/server/restate.go +++ b/server/restate.go @@ -109,6 +109,7 @@ func (r *Restate) discover() (resource *internal.Endpoint, err error) { Name: name, Input: handler.InputPayload(), Output: handler.OutputPayload(), + Ty: handler.HandlerType(), }) } resource.Services = append(resource.Services, service)