Skip to content

Commit

Permalink
add basic nested binding support (not yet for slices)
Browse files Browse the repository at this point in the history
  • Loading branch information
efectn committed Nov 17, 2022
1 parent d52652e commit 6cb876a
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 18 deletions.
26 changes: 19 additions & 7 deletions bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@ func (d *fieldCtxDecoder) Decode(ctx Ctx, reqValue reflect.Value) error {
}

type fieldTextDecoder struct {
index int
fieldName string
tag string // query,param,header,respHeader ...
reqField string
dec bind.TextDecoder
get func(c Ctx, key string, defaultValue ...string) string
index int
parentIndex []int
fieldName string
tag string // query,param,header,respHeader ...
reqField string
dec bind.TextDecoder
get func(c Ctx, key string, defaultValue ...string) string
}

func (d *fieldTextDecoder) Decode(ctx Ctx, reqValue reflect.Value) error {
Expand All @@ -50,7 +51,18 @@ func (d *fieldTextDecoder) Decode(ctx Ctx, reqValue reflect.Value) error {
return nil
}

err := d.dec.UnmarshalString(text, reqValue.Field(d.index))
var err error
if len(d.parentIndex) > 0 {
for _, i := range d.parentIndex {
reqValue = reqValue.Field(i)
}

err = d.dec.UnmarshalString(text, reqValue.Field(d.index))

} else {
err = d.dec.UnmarshalString(text, reqValue.Field(d.index))
}

if err != nil {
return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.reqField, err)
}
Expand Down
26 changes: 26 additions & 0 deletions bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,32 @@ func Test_Binder(t *testing.T) {
require.Equal(t, "john doe", body.Name)
}

func Test_Binder_Nested(t *testing.T) {
t.Parallel()
app := New()

c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx)
c.Request().SetBody([]byte(``))
c.Request().Header.SetContentType("")
c.Request().URI().SetQueryString("name=tom&nested.and.age=10&nested.and.test=john")

var req struct {
Name string `query:"name"`
Nested struct {
And struct {
Age int `query:"age"`
Test string `query:"test"`
} `query:"and"`
} `query:"nested"`
}

err := c.Bind().Req(&req).Err()
require.NoError(t, err)
require.Equal(t, "tom", req.Name)
require.Equal(t, "john", req.Nested.And.Test)
require.Equal(t, 10, req.Nested.And.Age)
}

// go test -run Test_Bind_BasicType -v
func Test_Bind_BasicType(t *testing.T) {
t.Parallel()
Expand Down
68 changes: 57 additions & 11 deletions binder_compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ func compileReqParser(rt reflect.Type, opt bindCompileOption) (Decoder, error) {
continue
}

dec, err := compileFieldDecoder(el.Field(i), i, opt)
dec, err := compileFieldDecoder(el.Field(i), i, opt, parentStruct{})
if err != nil {
return nil, err
}

if dec != nil {
decoders = append(decoders, dec)
decoders = append(decoders, dec...)
}
}

Expand All @@ -67,9 +67,14 @@ func compileReqParser(rt reflect.Type, opt bindCompileOption) (Decoder, error) {
}, nil
}

func compileFieldDecoder(field reflect.StructField, index int, opt bindCompileOption) (decoder, error) {
type parentStruct struct {
tag string
index []int
}

func compileFieldDecoder(field reflect.StructField, index int, opt bindCompileOption, parent parentStruct) ([]decoder, error) {
if reflect.PtrTo(field.Type).Implements(bindUnmarshalerType) {
return &fieldCtxDecoder{index: index, fieldName: field.Name, fieldType: field.Type}, nil
return []decoder{&fieldCtxDecoder{index: index, fieldName: field.Name, fieldType: field.Type}}, nil
}

var tags = []string{bindTagRespHeader, bindTagQuery, bindTagParam, bindTagHeader, bindTagCookie}
Expand All @@ -91,6 +96,10 @@ func compileFieldDecoder(field reflect.StructField, index int, opt bindCompileOp

tagContent := field.Tag.Get(tagScope)

if parent.tag != "" {
tagContent = parent.tag + "." + tagContent
}

if reflect.PtrTo(field.Type).Implements(textUnmarshalerType) {
return compileTextBasedDecoder(field, index, tagScope, tagContent)
}
Expand All @@ -99,7 +108,38 @@ func compileFieldDecoder(field reflect.StructField, index int, opt bindCompileOp
return compileSliceFieldTextBasedDecoder(field, index, tagScope, tagContent)
}

return compileTextBasedDecoder(field, index, tagScope, tagContent)
// Nested binding support
if field.Type.Kind() == reflect.Struct {
var decoders []decoder
el := field.Type

for i := 0; i < el.NumField(); i++ {
if !el.Field(i).IsExported() {
// ignore unexported field
continue
}
var indexes []int
if len(parent.index) > 0 {
indexes = append(indexes, parent.index...)
}
indexes = append(indexes, index)
dec, err := compileFieldDecoder(el.Field(i), i, opt, parentStruct{
tag: tagContent,
index: indexes,
})
if err != nil {
return nil, err
}

if dec != nil {
decoders = append(decoders, dec...)
}
}

return decoders, nil
}

return compileTextBasedDecoder(field, index, tagScope, tagContent, parent.index)
}

func formGetter(ctx Ctx, key string, defaultValue ...string) string {
Expand All @@ -120,7 +160,7 @@ func multipartGetter(ctx Ctx, key string, defaultValue ...string) string {
return v[0]
}

func compileTextBasedDecoder(field reflect.StructField, index int, tagScope, tagContent string) (decoder, error) {
func compileTextBasedDecoder(field reflect.StructField, index int, tagScope, tagContent string, parentIndex ...[]int) ([]decoder, error) {
var get func(ctx Ctx, key string, defaultValue ...string) string
switch tagScope {
case bindTagQuery:
Expand All @@ -146,17 +186,23 @@ func compileTextBasedDecoder(field reflect.StructField, index int, tagScope, tag
return nil, err
}

return &fieldTextDecoder{
fieldDecoder := &fieldTextDecoder{
index: index,
fieldName: field.Name,
tag: tagScope,
reqField: tagContent,
dec: textDecoder,
get: get,
}, nil
}

if len(parentIndex) > 0 {
fieldDecoder.parentIndex = parentIndex[0]
}

return []decoder{fieldDecoder}, nil
}

func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tagScope string, tagContent string) (decoder, error) {
func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tagScope string, tagContent string) ([]decoder, error) {
if field.Type.Kind() != reflect.Slice {
panic("BUG: unexpected type, expecting slice " + field.Type.String())
}
Expand Down Expand Up @@ -190,7 +236,7 @@ func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tag
return nil, errors.New("unexpected tag scope " + strconv.Quote(tagScope))
}

return &fieldSliceDecoder{
return []decoder{&fieldSliceDecoder{
fieldIndex: index,
eqBytes: eqBytes,
fieldName: field.Name,
Expand All @@ -199,5 +245,5 @@ func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tag
fieldType: field.Type,
elementType: et,
elementDecoder: elementUnmarshaler,
}, nil
}}, nil
}

0 comments on commit 6cb876a

Please sign in to comment.