diff --git a/encoding/encoding.go b/encoding/encoding.go index 30e4d65..563e23d 100644 --- a/encoding/encoding.go +++ b/encoding/encoding.go @@ -9,11 +9,6 @@ import ( "google.golang.org/protobuf/proto" ) -// Void is a placeholder to signify 'no value' where a type is otherwise needed -// It implements [RestateMarshaler] and [RestateUnmarshaler] to ensure that no marshaling or unmarshaling ever happens -// on this type. -type Void struct{} - var ( // BinaryCodec marshals []byte and unmarshals into *[]byte // In handlers, it uses a content type of application/octet-stream @@ -34,6 +29,11 @@ var ( _ RestateUnmarshaler = &Void{} ) +// Void is a placeholder to signify 'no value' where a type is otherwise needed +// It implements [RestateMarshaler] and [RestateUnmarshaler] to ensure that no marshaling or unmarshaling ever happens +// on this type. +type Void struct{} + func (v Void) RestateUnmarshal(codec Codec, data []byte) error { return nil } diff --git a/reflect.go b/reflect.go index 08d40b4..5daf551 100644 --- a/reflect.go +++ b/reflect.go @@ -19,20 +19,28 @@ var ( typeOfContext = reflect.TypeOf((*Context)(nil)).Elem() typeOfObjectContext = reflect.TypeOf((*ObjectContext)(nil)).Elem() typeOfSharedObjectContext = reflect.TypeOf((*ObjectSharedContext)(nil)).Elem() - typeOfVoid = reflect.TypeOf((*Void)(nil)).Elem() typeOfError = reflect.TypeOf((*error)(nil)).Elem() ) // Reflect converts a struct with methods into a service definition where each correctly-typed // and exported method of the struct will become a handler in the definition. The service name // defaults to the name of the struct, but this can be overidden by providing a `ServiceName() string` method. -// The handler name is the name of the method. Handler methods should be of the type `ServiceHandlerFn[I,O]`, -// `ObjectHandlerFn[I, O]` or `ObjectSharedHandlerFn[I, O]`. This function will panic if a mixture of -// object and service method signatures or opts are provided. +// The handler name is the name of the method. Handler methods should have one of the following signatures: +// - (ctx, I) (O, error) +// - (ctx, I) (O) +// - (ctx, I) (error) +// - (ctx, I) +// - (ctx) +// - (ctx) (error) +// - (ctx) (O) +// - (ctx) (O, error) +// Where ctx is [ObjectContext], [ObjectSharedContext] or [Context]. Other signatures are ignored. +// Signatures without an I or O type will be treated as if [Void] was provided. +// This function will panic if a mixture of object and service method signatures or opts are provided. // -// Input types will be deserialised with the provided codec (defaults to JSON) except when they are restate.Void, +// Input types will be deserialised with the provided codec (defaults to JSON) except when they are [Void], // in which case no input bytes or content type may be sent. -// Output types will be serialised with the provided codec (defaults to JSON) except when they are restate.Void, +// Output types will be serialised with the provided codec (defaults to JSON) except when they are [Void], // in which case no data will be sent and no content type set. func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinition { typ := reflect.TypeOf(rcvr) @@ -54,8 +62,16 @@ func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinitio if !method.IsExported() { continue } - // Method needs three ins: receiver, Context, I - if mtype.NumIn() != 3 { + // Method needs 2-3 ins: receiver, Context, optionally I + var input reflect.Type + switch mtype.NumIn() { + case 2: + // (ctx) + input = nil + case 3: + // (ctx, I) + input = mtype.In(2) + default: continue } @@ -87,42 +103,58 @@ func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinitio continue } - // Method needs two outs: O, and error - if mtype.NumOut() != 2 { - continue - } - - // The second return type of the method must be error. - if returnType := mtype.Out(1); returnType != typeOfError { + // Method needs 0-2 outs: (), (O), (error), (O, error) are all valid + var output reflect.Type + var hasError bool + switch mtype.NumOut() { + case 0: + // () + output = nil + hasError = false + case 1: + if returnType := mtype.Out(0); returnType == typeOfError { + // (error) + output = nil + hasError = true + } else { + // (O) + output = returnType + hasError = false + } + case 2: + if returnType := mtype.Out(1); returnType != typeOfError { + continue + } + // (O, error) + output = mtype.Out(0) + hasError = true + default: continue } - input := mtype.In(2) - output := mtype.Out(0) - switch def := definition.(type) { case *service: - def.Handler(mname, &serviceReflectHandler{ - reflectHandler{ - fn: method.Func, - receiver: val, - input: input, - output: output, - options: options.HandlerOptions{}, - handlerType: nil, - }, - }) + def.Handler(mname, &reflectHandler{ + fn: method.Func, + receiver: val, + input: input, + output: output, + hasError: hasError, + options: options.HandlerOptions{}, + handlerType: nil, + }, + ) case *object: - def.Handler(mname, &objectReflectHandler{ - reflectHandler{ - fn: method.Func, - receiver: val, - input: input, - output: input, - options: options.HandlerOptions{}, - handlerType: &handlerType, - }, - }) + def.Handler(mname, &reflectHandler{ + fn: method.Func, + receiver: val, + input: input, + output: output, + hasError: hasError, + options: options.HandlerOptions{}, + handlerType: &handlerType, + }, + ) } } @@ -138,6 +170,7 @@ type reflectHandler struct { receiver reflect.Value input reflect.Type output reflect.Type + hasError bool options options.HandlerOptions handlerType *internal.ServiceHandlerType } @@ -147,10 +180,16 @@ func (h *reflectHandler) GetOptions() *options.HandlerOptions { } func (h *reflectHandler) InputPayload() *encoding.InputPayload { + if h.input == nil { + return encoding.InputPayloadFor(h.options.Codec, Void{}) + } return encoding.InputPayloadFor(h.options.Codec, reflect.Zero(h.input).Interface()) } func (h *reflectHandler) OutputPayload() *encoding.OutputPayload { + if h.output == nil { + return encoding.OutputPayloadFor(h.options.Codec, Void{}) + } return encoding.OutputPayloadFor(h.options.Codec, reflect.Zero(h.output).Interface()) } @@ -158,30 +197,44 @@ func (h *reflectHandler) HandlerType() *internal.ServiceHandlerType { return h.handlerType } -type objectReflectHandler struct { - reflectHandler -} - -var _ state.Handler = (*objectReflectHandler)(nil) +func (h *reflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, error) { + var args []reflect.Value + if h.input != nil { + input := reflect.New(h.input) -func (h *objectReflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, error) { - input := reflect.New(h.input) + if err := encoding.Unmarshal(h.options.Codec, bytes, input.Interface()); err != nil { + return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) + } - if err := encoding.Unmarshal(h.options.Codec, bytes, input.Interface()); err != nil { - return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) + args = []reflect.Value{h.receiver, reflect.ValueOf(ctxWrapper{ctx}), input.Elem()} + } else { + args = []reflect.Value{h.receiver, reflect.ValueOf(ctxWrapper{ctx})} } - // we are sure about the fn signature so it's safe to do this - output := h.fn.Call([]reflect.Value{ - h.receiver, - reflect.ValueOf(ctxWrapper{ctx}), - input.Elem(), - }) - - outI := output[0].Interface() - errI := output[1].Interface() - if errI != nil { - return nil, errI.(error) + output := h.fn.Call(args) + var outI any + + switch [2]bool{h.output != nil, h.hasError} { + case [2]bool{false, false}: + // () + return nil, nil + case [2]bool{false, true}: + // (error) + errI := output[0].Interface() + if errI != nil { + return nil, errI.(error) + } + return nil, nil + case [2]bool{true, false}: + // (O) + outI = output[0].Interface() + case [2]bool{true, true}: + // (O, error) + errI := output[1].Interface() + if errI != nil { + return nil, errI.(error) + } + outI = output[0].Interface() } bytes, err := encoding.Marshal(h.options.Codec, outI) @@ -193,37 +246,4 @@ func (h *objectReflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, e return bytes, nil } -type serviceReflectHandler struct { - reflectHandler -} - -var _ state.Handler = (*serviceReflectHandler)(nil) - -func (h *serviceReflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, error) { - input := reflect.New(h.input) - - if err := encoding.Unmarshal(h.options.Codec, bytes, input.Interface()); err != nil { - return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) - } - - // we are sure about the fn signature so it's safe to do this - output := h.fn.Call([]reflect.Value{ - h.receiver, - reflect.ValueOf(ctxWrapper{ctx}), - input.Elem(), - }) - - outI := output[0].Interface() - errI := output[1].Interface() - if errI != nil { - return nil, errI.(error) - } - - bytes, err := encoding.Marshal(h.options.Codec, outI) - if err != nil { - // we don't use a terminal error here as this is hot-fixable by changing the return type - return nil, fmt.Errorf("failed to serialize output: %w", err) - } - - return bytes, nil -} +var _ state.Handler = (*reflectHandler)(nil) diff --git a/reflect_test.go b/reflect_test.go index 3121480..4e4725a 100644 --- a/reflect_test.go +++ b/reflect_test.go @@ -6,6 +6,7 @@ import ( restate "github.com/restatedev/sdk-go" "github.com/restatedev/sdk-go/internal" + "github.com/restatedev/sdk-go/internal/state" "github.com/stretchr/testify/require" ) @@ -25,8 +26,15 @@ var shared = internal.ServiceHandlerType_SHARED var tests []reflectTestParams = []reflectTestParams{ {rcvr: validObject{}, serviceName: "validObject", expectedMethods: expectedMethods{ - "Greet": &exclusive, - "GreetShared": &shared, + "Greet": &exclusive, + "GreetShared": &shared, + "NoInput": &exclusive, + "NoError": &exclusive, + "NoOutput": &exclusive, + "NoOutputNoError": &exclusive, + "NoInputNoError": &exclusive, + "NoInputNoOutput": &exclusive, + "NoInputNoOutputNoError": &exclusive, }}, {rcvr: validService{}, serviceName: "validService", expectedMethods: expectedMethods{ "Greet": nil, @@ -53,17 +61,18 @@ func TestReflect(t *testing.T) { } }() def := restate.Reflect(test.rcvr, test.opts...) - foundMethods := make([]string, 0, len(def.Handlers())) - for k := range def.Handlers() { - foundMethods = append(foundMethods, k) - } - for k, expectedTyp := range test.expectedMethods { - handler, ok := def.Handlers()[k] - if !ok { - t.Fatalf("missing handler %s", k) - } - require.Equal(t, expectedTyp, handler.HandlerType(), "mismatched handler type") + foundMethods := make(map[string]*internal.ServiceHandlerType, len(def.Handlers())) + for k, foundHandler := range def.Handlers() { + t.Run(k, func(t *testing.T) { + foundMethods[k] = foundHandler.HandlerType() + // check for panics + _ = foundHandler.InputPayload() + _ = foundHandler.OutputPayload() + _, err := foundHandler.Call(&state.Context{}, []byte(`""`)) + require.NoError(t, err) + }) } + require.Equal(t, test.expectedMethods, foundMethods) require.Equal(t, test.serviceName, def.Name()) }) } @@ -79,22 +88,40 @@ func (validObject) GreetShared(ctx restate.ObjectSharedContext, _ string) (strin return "", nil } -func (validObject) SkipInvalidArgCount(ctx restate.ObjectContext) (string, error) { +func (validObject) NoInput(ctx restate.ObjectContext) (string, error) { return "", nil } -func (validObject) SkipInvalidCtx(ctx context.Context, _ string) (string, error) { - return "", nil +func (validObject) NoError(ctx restate.ObjectContext, _ string) string { + return "" } -func (validObject) SkipInvalidError(ctx restate.ObjectContext, _ string) (string, string) { - return "", "" +func (validObject) NoOutput(ctx restate.ObjectContext, _ string) error { + return nil } -func (validObject) SkipInvalidRetCount(ctx restate.ObjectContext, _ string) string { +func (validObject) NoOutputNoError(ctx restate.ObjectContext, _ string) { +} + +func (validObject) NoInputNoError(ctx restate.ObjectContext) string { return "" } +func (validObject) NoInputNoOutput(ctx restate.ObjectContext) error { + return nil +} + +func (validObject) NoInputNoOutputNoError(ctx restate.ObjectContext) { +} + +func (validObject) SkipInvalidCtx(ctx context.Context, _ string) (string, error) { + return "", nil +} + +func (validObject) SkipInvalidError(ctx restate.ObjectContext, _ string) (error, string) { + return nil, "" +} + func (validObject) skipUnexported(_ string) (string, error) { return "", nil }