Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow more method signatures in .Reflect() #33

Merged
merged 2 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@ import (
"google.golang.org/protobuf/proto"
)

// Void is a placeholder to signify 'no value' where a type is otherwise needed
// It implements [RestateMarshaler] and [RestateUnmarshaler] to ensure that no marshaling or unmarshaling ever happens
// on this type.
type Void struct{}

var (
// BinaryCodec marshals []byte and unmarshals into *[]byte
// In handlers, it uses a content type of application/octet-stream
Expand All @@ -34,6 +29,11 @@ var (
_ RestateUnmarshaler = &Void{}
)

// Void is a placeholder to signify 'no value' where a type is otherwise needed
// It implements [RestateMarshaler] and [RestateUnmarshaler] to ensure that no marshaling or unmarshaling ever happens
// on this type.
type Void struct{}

func (v Void) RestateUnmarshal(codec Codec, data []byte) error {
return nil
}
Expand Down
204 changes: 112 additions & 92 deletions reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,28 @@ var (
typeOfContext = reflect.TypeOf((*Context)(nil)).Elem()
typeOfObjectContext = reflect.TypeOf((*ObjectContext)(nil)).Elem()
typeOfSharedObjectContext = reflect.TypeOf((*ObjectSharedContext)(nil)).Elem()
typeOfVoid = reflect.TypeOf((*Void)(nil)).Elem()
typeOfError = reflect.TypeOf((*error)(nil)).Elem()
)

// Reflect converts a struct with methods into a service definition where each correctly-typed
// and exported method of the struct will become a handler in the definition. The service name
// defaults to the name of the struct, but this can be overidden by providing a `ServiceName() string` method.
// The handler name is the name of the method. Handler methods should be of the type `ServiceHandlerFn[I,O]`,
// `ObjectHandlerFn[I, O]` or `ObjectSharedHandlerFn[I, O]`. This function will panic if a mixture of
// object and service method signatures or opts are provided.
// The handler name is the name of the method. Handler methods should have one of the following signatures:
// - (ctx, I) (O, error)
// - (ctx, I) (O)
// - (ctx, I) (error)
// - (ctx, I)
// - (ctx)
// - (ctx) (error)
// - (ctx) (O)
// - (ctx) (O, error)
// Where ctx is [ObjectContext], [ObjectSharedContext] or [Context]. Other signatures are ignored.
// Signatures without an I or O type will be treated as if [Void] was provided.
// This function will panic if a mixture of object and service method signatures or opts are provided.
//
// Input types will be deserialised with the provided codec (defaults to JSON) except when they are restate.Void,
// Input types will be deserialised with the provided codec (defaults to JSON) except when they are [Void],
// in which case no input bytes or content type may be sent.
// Output types will be serialised with the provided codec (defaults to JSON) except when they are restate.Void,
// Output types will be serialised with the provided codec (defaults to JSON) except when they are [Void],
// in which case no data will be sent and no content type set.
func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinition {
typ := reflect.TypeOf(rcvr)
Expand All @@ -54,8 +62,16 @@ func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinitio
if !method.IsExported() {
continue
}
// Method needs three ins: receiver, Context, I
if mtype.NumIn() != 3 {
// Method needs 2-3 ins: receiver, Context, optionally I
var input reflect.Type
switch mtype.NumIn() {
case 2:
// (ctx)
input = nil
case 3:
// (ctx, I)
input = mtype.In(2)
default:
continue
}

Expand Down Expand Up @@ -87,42 +103,58 @@ func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinitio
continue
}

// Method needs two outs: O, and error
if mtype.NumOut() != 2 {
continue
}

// The second return type of the method must be error.
if returnType := mtype.Out(1); returnType != typeOfError {
// Method needs 0-2 outs: (), (O), (error), (O, error) are all valid
var output reflect.Type
var hasError bool
switch mtype.NumOut() {
case 0:
// ()
output = nil
hasError = false
case 1:
if returnType := mtype.Out(0); returnType == typeOfError {
// (error)
output = nil
hasError = true
} else {
// (O)
output = returnType
hasError = false
}
case 2:
if returnType := mtype.Out(1); returnType != typeOfError {
continue
}
// (O, error)
output = mtype.Out(0)
hasError = true
default:
continue
}

input := mtype.In(2)
output := mtype.Out(0)

switch def := definition.(type) {
case *service:
def.Handler(mname, &serviceReflectHandler{
reflectHandler{
fn: method.Func,
receiver: val,
input: input,
output: output,
options: options.HandlerOptions{},
handlerType: nil,
},
})
def.Handler(mname, &reflectHandler{
fn: method.Func,
receiver: val,
input: input,
output: output,
hasError: hasError,
options: options.HandlerOptions{},
handlerType: nil,
},
)
case *object:
def.Handler(mname, &objectReflectHandler{
reflectHandler{
fn: method.Func,
receiver: val,
input: input,
output: input,
options: options.HandlerOptions{},
handlerType: &handlerType,
},
})
def.Handler(mname, &reflectHandler{
fn: method.Func,
receiver: val,
input: input,
output: output,
hasError: hasError,
options: options.HandlerOptions{},
handlerType: &handlerType,
},
)
}
}

Expand All @@ -138,6 +170,7 @@ type reflectHandler struct {
receiver reflect.Value
input reflect.Type
output reflect.Type
hasError bool
options options.HandlerOptions
handlerType *internal.ServiceHandlerType
}
Expand All @@ -147,41 +180,61 @@ func (h *reflectHandler) GetOptions() *options.HandlerOptions {
}

func (h *reflectHandler) InputPayload() *encoding.InputPayload {
if h.input == nil {
return encoding.InputPayloadFor(h.options.Codec, Void{})
}
return encoding.InputPayloadFor(h.options.Codec, reflect.Zero(h.input).Interface())
}

func (h *reflectHandler) OutputPayload() *encoding.OutputPayload {
if h.output == nil {
return encoding.OutputPayloadFor(h.options.Codec, Void{})
}
return encoding.OutputPayloadFor(h.options.Codec, reflect.Zero(h.output).Interface())
}

func (h *reflectHandler) HandlerType() *internal.ServiceHandlerType {
return h.handlerType
}

type objectReflectHandler struct {
reflectHandler
}

var _ state.Handler = (*objectReflectHandler)(nil)
func (h *reflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, error) {
var args []reflect.Value
if h.input != nil {
input := reflect.New(h.input)

func (h *objectReflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, error) {
input := reflect.New(h.input)
if err := encoding.Unmarshal(h.options.Codec, bytes, input.Interface()); err != nil {
return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest)
}

if err := encoding.Unmarshal(h.options.Codec, bytes, input.Interface()); err != nil {
return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest)
args = []reflect.Value{h.receiver, reflect.ValueOf(ctxWrapper{ctx}), input.Elem()}
} else {
args = []reflect.Value{h.receiver, reflect.ValueOf(ctxWrapper{ctx})}
}

// we are sure about the fn signature so it's safe to do this
output := h.fn.Call([]reflect.Value{
h.receiver,
reflect.ValueOf(ctxWrapper{ctx}),
input.Elem(),
})

outI := output[0].Interface()
errI := output[1].Interface()
if errI != nil {
return nil, errI.(error)
output := h.fn.Call(args)
var outI any

switch [2]bool{h.output != nil, h.hasError} {
case [2]bool{false, false}:
// ()
return nil, nil
case [2]bool{false, true}:
// (error)
errI := output[0].Interface()
if errI != nil {
return nil, errI.(error)
}
return nil, nil
case [2]bool{true, false}:
// (O)
outI = output[0].Interface()
case [2]bool{true, true}:
// (O, error)
errI := output[1].Interface()
if errI != nil {
return nil, errI.(error)
}
outI = output[0].Interface()
}

bytes, err := encoding.Marshal(h.options.Codec, outI)
Expand All @@ -193,37 +246,4 @@ func (h *objectReflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, e
return bytes, nil
}

type serviceReflectHandler struct {
reflectHandler
}

var _ state.Handler = (*serviceReflectHandler)(nil)

func (h *serviceReflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, error) {
input := reflect.New(h.input)

if err := encoding.Unmarshal(h.options.Codec, bytes, input.Interface()); err != nil {
return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest)
}

// we are sure about the fn signature so it's safe to do this
output := h.fn.Call([]reflect.Value{
h.receiver,
reflect.ValueOf(ctxWrapper{ctx}),
input.Elem(),
})

outI := output[0].Interface()
errI := output[1].Interface()
if errI != nil {
return nil, errI.(error)
}

bytes, err := encoding.Marshal(h.options.Codec, outI)
if err != nil {
// we don't use a terminal error here as this is hot-fixable by changing the return type
return nil, fmt.Errorf("failed to serialize output: %w", err)
}

return bytes, nil
}
var _ state.Handler = (*reflectHandler)(nil)
63 changes: 45 additions & 18 deletions reflect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

restate "github.com/restatedev/sdk-go"
"github.com/restatedev/sdk-go/internal"
"github.com/restatedev/sdk-go/internal/state"
"github.com/stretchr/testify/require"
)

Expand All @@ -25,8 +26,15 @@ var shared = internal.ServiceHandlerType_SHARED

var tests []reflectTestParams = []reflectTestParams{
{rcvr: validObject{}, serviceName: "validObject", expectedMethods: expectedMethods{
"Greet": &exclusive,
"GreetShared": &shared,
"Greet": &exclusive,
"GreetShared": &shared,
"NoInput": &exclusive,
"NoError": &exclusive,
"NoOutput": &exclusive,
"NoOutputNoError": &exclusive,
"NoInputNoError": &exclusive,
"NoInputNoOutput": &exclusive,
"NoInputNoOutputNoError": &exclusive,
}},
{rcvr: validService{}, serviceName: "validService", expectedMethods: expectedMethods{
"Greet": nil,
Expand All @@ -53,17 +61,18 @@ func TestReflect(t *testing.T) {
}
}()
def := restate.Reflect(test.rcvr, test.opts...)
foundMethods := make([]string, 0, len(def.Handlers()))
for k := range def.Handlers() {
foundMethods = append(foundMethods, k)
}
for k, expectedTyp := range test.expectedMethods {
handler, ok := def.Handlers()[k]
if !ok {
t.Fatalf("missing handler %s", k)
}
require.Equal(t, expectedTyp, handler.HandlerType(), "mismatched handler type")
foundMethods := make(map[string]*internal.ServiceHandlerType, len(def.Handlers()))
for k, foundHandler := range def.Handlers() {
t.Run(k, func(t *testing.T) {
foundMethods[k] = foundHandler.HandlerType()
// check for panics
_ = foundHandler.InputPayload()
_ = foundHandler.OutputPayload()
_, err := foundHandler.Call(&state.Context{}, []byte(`""`))
require.NoError(t, err)
})
}
require.Equal(t, test.expectedMethods, foundMethods)
require.Equal(t, test.serviceName, def.Name())
})
}
Expand All @@ -79,22 +88,40 @@ func (validObject) GreetShared(ctx restate.ObjectSharedContext, _ string) (strin
return "", nil
}

func (validObject) SkipInvalidArgCount(ctx restate.ObjectContext) (string, error) {
func (validObject) NoInput(ctx restate.ObjectContext) (string, error) {
return "", nil
}

func (validObject) SkipInvalidCtx(ctx context.Context, _ string) (string, error) {
return "", nil
func (validObject) NoError(ctx restate.ObjectContext, _ string) string {
return ""
}

func (validObject) SkipInvalidError(ctx restate.ObjectContext, _ string) (string, string) {
return "", ""
func (validObject) NoOutput(ctx restate.ObjectContext, _ string) error {
return nil
}

func (validObject) SkipInvalidRetCount(ctx restate.ObjectContext, _ string) string {
func (validObject) NoOutputNoError(ctx restate.ObjectContext, _ string) {
}

func (validObject) NoInputNoError(ctx restate.ObjectContext) string {
return ""
}

func (validObject) NoInputNoOutput(ctx restate.ObjectContext) error {
return nil
}

func (validObject) NoInputNoOutputNoError(ctx restate.ObjectContext) {
}

func (validObject) SkipInvalidCtx(ctx context.Context, _ string) (string, error) {
return "", nil
}

func (validObject) SkipInvalidError(ctx restate.ObjectContext, _ string) (error, string) {
return nil, ""
}

func (validObject) skipUnexported(_ string) (string, error) {
return "", nil
}
Expand Down
Loading