Skip to content

Commit

Permalink
Fix i/o payloads in reflect
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkleeman committed Aug 21, 2024
1 parent e582947 commit 8259ee1
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 17 deletions.
10 changes: 5 additions & 5 deletions encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@ import (
"google.golang.org/protobuf/proto"
)

// Void is a placeholder to signify 'no value' where a type is otherwise needed
// It implements [RestateMarshaler] and [RestateUnmarshaler] to ensure that no marshaling or unmarshaling ever happens
// on this type.
type Void struct{}

var (
// BinaryCodec marshals []byte and unmarshals into *[]byte
// In handlers, it uses a content type of application/octet-stream
Expand All @@ -34,6 +29,11 @@ var (
_ RestateUnmarshaler = &Void{}
)

// Void is a placeholder to signify 'no value' where a type is otherwise needed
// It implements [RestateMarshaler] and [RestateUnmarshaler] to ensure that no marshaling or unmarshaling ever happens
// on this type.
type Void struct{}

func (v Void) RestateUnmarshal(codec Codec, data []byte) error {
return nil
}
Expand Down
32 changes: 21 additions & 11 deletions reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ var (
typeOfContext = reflect.TypeOf((*Context)(nil)).Elem()
typeOfObjectContext = reflect.TypeOf((*ObjectContext)(nil)).Elem()
typeOfSharedObjectContext = reflect.TypeOf((*ObjectSharedContext)(nil)).Elem()
typeOfVoid = reflect.TypeOf((*Void)(nil)).Elem()
typeOfError = reflect.TypeOf((*error)(nil)).Elem()
)

Expand All @@ -36,11 +35,12 @@ var (
// - (ctx) (O)
// - (ctx) (O, error)
// Where ctx is [ObjectContext], [ObjectSharedContext] or [Context]. Other signatures are ignored.
// Signatures without an I or O type will be treated as if [Void] was provided.
// This function will panic if a mixture of object and service method signatures or opts are provided.
//
// Input types will be deserialised with the provided codec (defaults to JSON) except when they are restate.Void,
// Input types will be deserialised with the provided codec (defaults to JSON) except when they are [Void],
// in which case no input bytes or content type may be sent.
// Output types will be serialised with the provided codec (defaults to JSON) except when they are restate.Void,
// Output types will be serialised with the provided codec (defaults to JSON) except when they are [Void],
// in which case no data will be sent and no content type set.
func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinition {
typ := reflect.TypeOf(rcvr)
Expand All @@ -63,8 +63,15 @@ func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinitio
continue
}
// Method needs 2-3 ins: receiver, Context, optionally I
numIn := mtype.NumIn()
if numIn < 2 || numIn > 3 {
var input reflect.Type
switch mtype.NumIn() {
case 2:
// (ctx)
input = nil
case 3:
// (ctx, I)
input = mtype.In(2)
default:
continue
}

Expand Down Expand Up @@ -110,24 +117,21 @@ func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinitio
output = nil
hasError = true
} else {
// (O)
output = returnType
hasError = false
}
case 2:
if returnType := mtype.Out(1); returnType != typeOfError {
continue
}
// (O, error)
output = mtype.Out(0)
hasError = true
default:
continue
}

var input reflect.Type
if numIn > 2 {
input = mtype.In(2)
}

switch def := definition.(type) {
case *service:
def.Handler(mname, &reflectHandler{
Expand All @@ -145,7 +149,7 @@ func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinitio
fn: method.Func,
receiver: val,
input: input,
output: input,
output: output,
hasError: hasError,
options: options.HandlerOptions{},
handlerType: &handlerType,
Expand Down Expand Up @@ -176,10 +180,16 @@ func (h *reflectHandler) GetOptions() *options.HandlerOptions {
}

func (h *reflectHandler) InputPayload() *encoding.InputPayload {
if h.input == nil {
return encoding.InputPayloadFor(h.options.Codec, Void{})
}
return encoding.InputPayloadFor(h.options.Codec, reflect.Zero(h.input).Interface())
}

func (h *reflectHandler) OutputPayload() *encoding.OutputPayload {
if h.output == nil {
return encoding.OutputPayloadFor(h.options.Codec, Void{})
}
return encoding.OutputPayloadFor(h.options.Codec, reflect.Zero(h.output).Interface())
}

Expand Down
10 changes: 9 additions & 1 deletion reflect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

restate "github.com/restatedev/sdk-go"
"github.com/restatedev/sdk-go/internal"
"github.com/restatedev/sdk-go/internal/state"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -62,7 +63,14 @@ func TestReflect(t *testing.T) {
def := restate.Reflect(test.rcvr, test.opts...)
foundMethods := make(map[string]*internal.ServiceHandlerType, len(def.Handlers()))
for k, foundHandler := range def.Handlers() {
foundMethods[k] = foundHandler.HandlerType()
t.Run(k, func(t *testing.T) {
foundMethods[k] = foundHandler.HandlerType()
// check for panics
_ = foundHandler.InputPayload()
_ = foundHandler.OutputPayload()
_, err := foundHandler.Call(&state.Context{}, []byte(`""`))
require.NoError(t, err)
})
}
require.Equal(t, test.expectedMethods, foundMethods)
require.Equal(t, test.serviceName, def.Name())
Expand Down

0 comments on commit 8259ee1

Please sign in to comment.