diff --git a/bind_test.go b/bind_test.go index d339709950..ae135ec83f 100644 --- a/bind_test.go +++ b/bind_test.go @@ -64,6 +64,33 @@ func Test_Binder_Nested(t *testing.T) { require.Equal(t, 10, req.Nested.And.Age) } +func Test_Binder_Nested_Slice(t *testing.T) { + t.Parallel() + app := New() + + c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) + c.Request().SetBody([]byte(``)) + c.Request().Header.SetContentType("") + c.Request().URI().SetQueryString("name=tom&data[0][name]=john&data[0][age]=10&data[1][name]=doe&data[1][age]=12") + + var req struct { + Name string `query:"name"` + Data []struct { + Name string `query:"name"` + Age int `query:"age"` + } `query:"data"` + } + + err := c.Bind().Req(&req).Err() + require.NoError(t, err) + require.Equal(t, 2, len(req.Data)) + require.Equal(t, "john", req.Data[0].Name) + require.Equal(t, 10, req.Data[0].Age) + require.Equal(t, "doe", req.Data[1].Name) + require.Equal(t, 12, req.Data[1].Age) + require.Equal(t, "tom", req.Name) +} + // go test -run Test_Bind_BasicType -v func Test_Bind_BasicType(t *testing.T) { t.Parallel() diff --git a/binder_compile.go b/binder_compile.go index 38e02b56e7..4e0331a014 100644 --- a/binder_compile.go +++ b/binder_compile.go @@ -72,17 +72,12 @@ type parentStruct struct { index []int } -func compileFieldDecoder(field reflect.StructField, index int, opt bindCompileOption, parent parentStruct) ([]decoder, error) { - if reflect.PtrTo(field.Type).Implements(bindUnmarshalerType) { - return []decoder{&fieldCtxDecoder{index: index, fieldName: field.Name, fieldType: field.Type}}, nil - } - +func lookupTagScope(field reflect.StructField, opt bindCompileOption) (tagScope string) { var tags = []string{bindTagRespHeader, bindTagQuery, bindTagParam, bindTagHeader, bindTagCookie} if opt.bodyDecoder { tags = []string{bindTagForm, bindTagMultipart} } - var tagScope = "" for _, loopTagScope := range tags { if _, ok := field.Tag.Lookup(loopTagScope); ok { tagScope = loopTagScope @@ -90,6 +85,15 @@ func compileFieldDecoder(field reflect.StructField, index int, opt bindCompileOp } } + return +} + +func compileFieldDecoder(field reflect.StructField, index int, opt bindCompileOption, parent parentStruct) ([]decoder, error) { + if reflect.PtrTo(field.Type).Implements(bindUnmarshalerType) { + return []decoder{&fieldCtxDecoder{index: index, fieldName: field.Name, fieldType: field.Type}}, nil + } + + tagScope := lookupTagScope(field, opt) if tagScope == "" { return nil, nil } @@ -202,15 +206,56 @@ func compileTextBasedDecoder(field reflect.StructField, index int, tagScope, tag return []decoder{fieldDecoder}, nil } +type subElem struct { + et reflect.Type + tag string + index int + elementDecoder bind.TextDecoder +} + func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tagScope string, tagContent string) ([]decoder, error) { if field.Type.Kind() != reflect.Slice { panic("BUG: unexpected type, expecting slice " + field.Type.String()) } + var elems []subElem + var elementUnmarshaler bind.TextDecoder + var err error + et := field.Type.Elem() - elementUnmarshaler, err := bind.CompileTextDecoder(et) - if err != nil { - return nil, fmt.Errorf("failed to build slice binder: %w", err) + if et.Kind() == reflect.Struct { + elems = make([]subElem, et.NumField()) + for i := 0; i < et.NumField(); i++ { + if !et.Field(i).IsExported() { + // ignore unexported field + continue + } + + // Skip different tag scopes (main -> sub) + subScope := lookupTagScope(et.Field(i), bindCompileOption{}) + if subScope != tagScope { + continue + } + + elementUnmarshaler, err := bind.CompileTextDecoder(et.Field(i).Type) + if err != nil { + return nil, fmt.Errorf("failed to build slice binder: %w", err) + } + + elem := subElem{ + index: i, + tag: et.Field(i).Tag.Get(subScope), + et: et.Field(i).Type, + elementDecoder: elementUnmarshaler, + } + + elems = append(elems, elem) + } + } else { + elementUnmarshaler, err = bind.CompileTextDecoder(et) + if err != nil { + return nil, fmt.Errorf("failed to build slice binder: %w", err) + } } var eqBytes = bytes.Equal @@ -236,7 +281,8 @@ func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tag return nil, errors.New("unexpected tag scope " + strconv.Quote(tagScope)) } - return []decoder{&fieldSliceDecoder{ + fieldSliceDecoder := &fieldSliceDecoder{ + elems: elems, fieldIndex: index, eqBytes: eqBytes, fieldName: field.Name, @@ -245,5 +291,7 @@ func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tag fieldType: field.Type, elementType: et, elementDecoder: elementUnmarshaler, - }}, nil + } + + return []decoder{fieldSliceDecoder}, nil } diff --git a/binder_slice.go b/binder_slice.go index e3eb828c0e..409c19079c 100644 --- a/binder_slice.go +++ b/binder_slice.go @@ -1,7 +1,9 @@ package fiber import ( + "bytes" "reflect" + "strconv" "github.com/gofiber/fiber/v3/internal/bind" "github.com/gofiber/utils/v2" @@ -11,6 +13,7 @@ var _ decoder = (*fieldSliceDecoder)(nil) type fieldSliceDecoder struct { fieldIndex int + elems []subElem fieldName string fieldType reflect.Type reqKey []byte @@ -22,6 +25,10 @@ type fieldSliceDecoder struct { } func (d *fieldSliceDecoder) Decode(ctx Ctx, reqValue reflect.Value) error { + if d.elementType.Kind() == reflect.Struct { + return d.decodeStruct(ctx, reqValue) + } + count := 0 d.visitAll(ctx, func(key, value []byte) { if d.eqBytes(key, d.reqKey) { @@ -59,6 +66,88 @@ func (d *fieldSliceDecoder) Decode(ctx Ctx, reqValue reflect.Value) error { return nil } +func (d *fieldSliceDecoder) decodeStruct(ctx Ctx, reqValue reflect.Value) error { + var maxNum int + d.visitAll(ctx, func(key, value []byte) { + start := bytes.IndexByte(key, byte('[')) + end := bytes.IndexByte(key, byte(']')) + + if start != -1 || end != -1 { + num := utils.UnsafeString(key[start+1 : end]) + + if len(num) > 0 { + maxNum, _ = strconv.Atoi(num) + } + } + }) + + if maxNum != 0 { + maxNum += 1 + } + + rv := reflect.MakeSlice(d.fieldType, maxNum, maxNum) + if maxNum == 0 { + reqValue.Field(d.fieldIndex).Set(rv) + return nil + } + + var err error + d.visitAll(ctx, func(key, value []byte) { + if err != nil { + return + } + + if bytes.IndexByte(key, byte('[')) == -1 { + return + } + + // TODO: support queries like data[0][users][0][name] + ints := make([]int, 0) + elems := make([]string, 0) + + // nested + lookupKey := key + for { + start := bytes.IndexByte(lookupKey, byte('[')) + end := bytes.IndexByte(lookupKey, byte(']')) + + if start == -1 || end == -1 { + break + } + + content := utils.UnsafeString(lookupKey[start+1 : end]) + num, errElse := strconv.Atoi(content) + + if errElse == nil { + ints = append(ints, num) + } else { + elems = append(elems, content) + } + + lookupKey = lookupKey[end+1:] + } + + for _, elem := range d.elems { + if elems[0] == elem.tag { + ev := reflect.New(elem.et) + if ee := elem.elementDecoder.UnmarshalString(utils.UnsafeString(value), ev.Elem()); ee != nil { + err = ee + } + + i := rv.Index(ints[0]) + i.Field(elem.index).Set(ev.Elem()) + } + } + }) + + if err != nil { + return err + } + + reqValue.Field(d.fieldIndex).Set(rv) + return nil +} + func visitQuery(ctx Ctx, f func(key []byte, value []byte)) { ctx.Context().QueryArgs().VisitAll(f) }