diff --git a/bind.go b/bind.go index cce399203b..6d4d18ad9f 100644 --- a/bind.go +++ b/bind.go @@ -36,12 +36,13 @@ func (d *fieldCtxDecoder) Decode(ctx Ctx, reqValue reflect.Value) error { } type fieldTextDecoder struct { - index int - fieldName string - tag string // query,param,header,respHeader ... - reqField string - dec bind.TextDecoder - get func(c Ctx, key string, defaultValue ...string) string + index int + parentIndex []int + fieldName string + tag string // query,param,header,respHeader ... + reqField string + dec bind.TextDecoder + get func(c Ctx, key string, defaultValue ...string) string } func (d *fieldTextDecoder) Decode(ctx Ctx, reqValue reflect.Value) error { @@ -50,7 +51,18 @@ func (d *fieldTextDecoder) Decode(ctx Ctx, reqValue reflect.Value) error { return nil } - err := d.dec.UnmarshalString(text, reqValue.Field(d.index)) + var err error + if len(d.parentIndex) > 0 { + for _, i := range d.parentIndex { + reqValue = reqValue.Field(i) + } + + err = d.dec.UnmarshalString(text, reqValue.Field(d.index)) + + } else { + err = d.dec.UnmarshalString(text, reqValue.Field(d.index)) + } + if err != nil { return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.reqField, err) } diff --git a/bind_test.go b/bind_test.go index 21cce7f5a0..d339709950 100644 --- a/bind_test.go +++ b/bind_test.go @@ -38,6 +38,32 @@ func Test_Binder(t *testing.T) { require.Equal(t, "john doe", body.Name) } +func Test_Binder_Nested(t *testing.T) { + t.Parallel() + app := New() + + c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) + c.Request().SetBody([]byte(``)) + c.Request().Header.SetContentType("") + c.Request().URI().SetQueryString("name=tom&nested.and.age=10&nested.and.test=john") + + var req struct { + Name string `query:"name"` + Nested struct { + And struct { + Age int `query:"age"` + Test string `query:"test"` + } `query:"and"` + } `query:"nested"` + } + + err := c.Bind().Req(&req).Err() + require.NoError(t, err) + require.Equal(t, "tom", req.Name) + require.Equal(t, "john", req.Nested.And.Test) + require.Equal(t, 10, req.Nested.And.Age) +} + // go test -run Test_Bind_BasicType -v func Test_Bind_BasicType(t *testing.T) { t.Parallel() diff --git a/binder_compile.go b/binder_compile.go index 2085f50606..38e02b56e7 100644 --- a/binder_compile.go +++ b/binder_compile.go @@ -45,13 +45,13 @@ func compileReqParser(rt reflect.Type, opt bindCompileOption) (Decoder, error) { continue } - dec, err := compileFieldDecoder(el.Field(i), i, opt) + dec, err := compileFieldDecoder(el.Field(i), i, opt, parentStruct{}) if err != nil { return nil, err } if dec != nil { - decoders = append(decoders, dec) + decoders = append(decoders, dec...) } } @@ -67,9 +67,14 @@ func compileReqParser(rt reflect.Type, opt bindCompileOption) (Decoder, error) { }, nil } -func compileFieldDecoder(field reflect.StructField, index int, opt bindCompileOption) (decoder, error) { +type parentStruct struct { + tag string + index []int +} + +func compileFieldDecoder(field reflect.StructField, index int, opt bindCompileOption, parent parentStruct) ([]decoder, error) { if reflect.PtrTo(field.Type).Implements(bindUnmarshalerType) { - return &fieldCtxDecoder{index: index, fieldName: field.Name, fieldType: field.Type}, nil + return []decoder{&fieldCtxDecoder{index: index, fieldName: field.Name, fieldType: field.Type}}, nil } var tags = []string{bindTagRespHeader, bindTagQuery, bindTagParam, bindTagHeader, bindTagCookie} @@ -91,6 +96,10 @@ func compileFieldDecoder(field reflect.StructField, index int, opt bindCompileOp tagContent := field.Tag.Get(tagScope) + if parent.tag != "" { + tagContent = parent.tag + "." + tagContent + } + if reflect.PtrTo(field.Type).Implements(textUnmarshalerType) { return compileTextBasedDecoder(field, index, tagScope, tagContent) } @@ -99,7 +108,38 @@ func compileFieldDecoder(field reflect.StructField, index int, opt bindCompileOp return compileSliceFieldTextBasedDecoder(field, index, tagScope, tagContent) } - return compileTextBasedDecoder(field, index, tagScope, tagContent) + // Nested binding support + if field.Type.Kind() == reflect.Struct { + var decoders []decoder + el := field.Type + + for i := 0; i < el.NumField(); i++ { + if !el.Field(i).IsExported() { + // ignore unexported field + continue + } + var indexes []int + if len(parent.index) > 0 { + indexes = append(indexes, parent.index...) + } + indexes = append(indexes, index) + dec, err := compileFieldDecoder(el.Field(i), i, opt, parentStruct{ + tag: tagContent, + index: indexes, + }) + if err != nil { + return nil, err + } + + if dec != nil { + decoders = append(decoders, dec...) + } + } + + return decoders, nil + } + + return compileTextBasedDecoder(field, index, tagScope, tagContent, parent.index) } func formGetter(ctx Ctx, key string, defaultValue ...string) string { @@ -120,7 +160,7 @@ func multipartGetter(ctx Ctx, key string, defaultValue ...string) string { return v[0] } -func compileTextBasedDecoder(field reflect.StructField, index int, tagScope, tagContent string) (decoder, error) { +func compileTextBasedDecoder(field reflect.StructField, index int, tagScope, tagContent string, parentIndex ...[]int) ([]decoder, error) { var get func(ctx Ctx, key string, defaultValue ...string) string switch tagScope { case bindTagQuery: @@ -146,17 +186,23 @@ func compileTextBasedDecoder(field reflect.StructField, index int, tagScope, tag return nil, err } - return &fieldTextDecoder{ + fieldDecoder := &fieldTextDecoder{ index: index, fieldName: field.Name, tag: tagScope, reqField: tagContent, dec: textDecoder, get: get, - }, nil + } + + if len(parentIndex) > 0 { + fieldDecoder.parentIndex = parentIndex[0] + } + + return []decoder{fieldDecoder}, nil } -func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tagScope string, tagContent string) (decoder, error) { +func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tagScope string, tagContent string) ([]decoder, error) { if field.Type.Kind() != reflect.Slice { panic("BUG: unexpected type, expecting slice " + field.Type.String()) } @@ -190,7 +236,7 @@ func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tag return nil, errors.New("unexpected tag scope " + strconv.Quote(tagScope)) } - return &fieldSliceDecoder{ + return []decoder{&fieldSliceDecoder{ fieldIndex: index, eqBytes: eqBytes, fieldName: field.Name, @@ -199,5 +245,5 @@ func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tag fieldType: field.Type, elementType: et, elementDecoder: elementUnmarshaler, - }, nil + }}, nil }