Skip to content

Commit

Permalink
add support for queries like data[0][name] (not yet supporting deeper…
Browse files Browse the repository at this point in the history
… nested levels)
  • Loading branch information
efectn committed Nov 20, 2022
1 parent 6cb876a commit 3661d33
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 11 deletions.
27 changes: 27 additions & 0 deletions bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,33 @@ func Test_Binder_Nested(t *testing.T) {
require.Equal(t, 10, req.Nested.And.Age)
}

func Test_Binder_Nested_Slice(t *testing.T) {
t.Parallel()
app := New()

c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx)
c.Request().SetBody([]byte(``))
c.Request().Header.SetContentType("")
c.Request().URI().SetQueryString("name=tom&data[0][name]=john&data[0][age]=10&data[1][name]=doe&data[1][age]=12")

var req struct {
Name string `query:"name"`
Data []struct {
Name string `query:"name"`
Age int `query:"age"`
} `query:"data"`
}

err := c.Bind().Req(&req).Err()
require.NoError(t, err)
require.Equal(t, 2, len(req.Data))
require.Equal(t, "john", req.Data[0].Name)
require.Equal(t, 10, req.Data[0].Age)
require.Equal(t, "doe", req.Data[1].Name)
require.Equal(t, 12, req.Data[1].Age)
require.Equal(t, "tom", req.Name)
}

// go test -run Test_Bind_BasicType -v
func Test_Bind_BasicType(t *testing.T) {
t.Parallel()
Expand Down
70 changes: 59 additions & 11 deletions binder_compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,24 +72,28 @@ type parentStruct struct {
index []int
}

func compileFieldDecoder(field reflect.StructField, index int, opt bindCompileOption, parent parentStruct) ([]decoder, error) {
if reflect.PtrTo(field.Type).Implements(bindUnmarshalerType) {
return []decoder{&fieldCtxDecoder{index: index, fieldName: field.Name, fieldType: field.Type}}, nil
}

func lookupTagScope(field reflect.StructField, opt bindCompileOption) (tagScope string) {
var tags = []string{bindTagRespHeader, bindTagQuery, bindTagParam, bindTagHeader, bindTagCookie}
if opt.bodyDecoder {
tags = []string{bindTagForm, bindTagMultipart}
}

var tagScope = ""
for _, loopTagScope := range tags {
if _, ok := field.Tag.Lookup(loopTagScope); ok {
tagScope = loopTagScope
break
}
}

return
}

func compileFieldDecoder(field reflect.StructField, index int, opt bindCompileOption, parent parentStruct) ([]decoder, error) {
if reflect.PtrTo(field.Type).Implements(bindUnmarshalerType) {
return []decoder{&fieldCtxDecoder{index: index, fieldName: field.Name, fieldType: field.Type}}, nil
}

tagScope := lookupTagScope(field, opt)
if tagScope == "" {
return nil, nil
}
Expand Down Expand Up @@ -202,15 +206,56 @@ func compileTextBasedDecoder(field reflect.StructField, index int, tagScope, tag
return []decoder{fieldDecoder}, nil
}

type subElem struct {
et reflect.Type
tag string
index int
elementDecoder bind.TextDecoder
}

func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tagScope string, tagContent string) ([]decoder, error) {
if field.Type.Kind() != reflect.Slice {
panic("BUG: unexpected type, expecting slice " + field.Type.String())
}

var elems []subElem
var elementUnmarshaler bind.TextDecoder
var err error

et := field.Type.Elem()
elementUnmarshaler, err := bind.CompileTextDecoder(et)
if err != nil {
return nil, fmt.Errorf("failed to build slice binder: %w", err)
if et.Kind() == reflect.Struct {
elems = make([]subElem, et.NumField())
for i := 0; i < et.NumField(); i++ {
if !et.Field(i).IsExported() {
// ignore unexported field
continue
}

// Skip different tag scopes (main -> sub)
subScope := lookupTagScope(et.Field(i), bindCompileOption{})
if subScope != tagScope {
continue
}

elementUnmarshaler, err := bind.CompileTextDecoder(et.Field(i).Type)
if err != nil {
return nil, fmt.Errorf("failed to build slice binder: %w", err)
}

elem := subElem{
index: i,
tag: et.Field(i).Tag.Get(subScope),
et: et.Field(i).Type,
elementDecoder: elementUnmarshaler,
}

elems = append(elems, elem)
}
} else {
elementUnmarshaler, err = bind.CompileTextDecoder(et)
if err != nil {
return nil, fmt.Errorf("failed to build slice binder: %w", err)
}
}

var eqBytes = bytes.Equal
Expand All @@ -236,7 +281,8 @@ func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tag
return nil, errors.New("unexpected tag scope " + strconv.Quote(tagScope))
}

return []decoder{&fieldSliceDecoder{
fieldSliceDecoder := &fieldSliceDecoder{
elems: elems,
fieldIndex: index,
eqBytes: eqBytes,
fieldName: field.Name,
Expand All @@ -245,5 +291,7 @@ func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tag
fieldType: field.Type,
elementType: et,
elementDecoder: elementUnmarshaler,
}}, nil
}

return []decoder{fieldSliceDecoder}, nil
}
89 changes: 89 additions & 0 deletions binder_slice.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package fiber

import (
"bytes"
"reflect"
"strconv"

"github.com/gofiber/fiber/v3/internal/bind"
"github.com/gofiber/utils/v2"
Expand All @@ -11,6 +13,7 @@ var _ decoder = (*fieldSliceDecoder)(nil)

type fieldSliceDecoder struct {
fieldIndex int
elems []subElem
fieldName string
fieldType reflect.Type
reqKey []byte
Expand All @@ -22,6 +25,10 @@ type fieldSliceDecoder struct {
}

func (d *fieldSliceDecoder) Decode(ctx Ctx, reqValue reflect.Value) error {
if d.elementType.Kind() == reflect.Struct {
return d.decodeStruct(ctx, reqValue)
}

count := 0
d.visitAll(ctx, func(key, value []byte) {
if d.eqBytes(key, d.reqKey) {
Expand Down Expand Up @@ -59,6 +66,88 @@ func (d *fieldSliceDecoder) Decode(ctx Ctx, reqValue reflect.Value) error {
return nil
}

func (d *fieldSliceDecoder) decodeStruct(ctx Ctx, reqValue reflect.Value) error {
var maxNum int
d.visitAll(ctx, func(key, value []byte) {
start := bytes.IndexByte(key, byte('['))
end := bytes.IndexByte(key, byte(']'))

if start != -1 || end != -1 {
num := utils.UnsafeString(key[start+1 : end])

if len(num) > 0 {
maxNum, _ = strconv.Atoi(num)
}
}
})

if maxNum != 0 {
maxNum += 1
}

rv := reflect.MakeSlice(d.fieldType, maxNum, maxNum)
if maxNum == 0 {
reqValue.Field(d.fieldIndex).Set(rv)
return nil
}

var err error
d.visitAll(ctx, func(key, value []byte) {
if err != nil {
return
}

if bytes.IndexByte(key, byte('[')) == -1 {
return
}

// TODO: support queries like data[0][users][0][name]
ints := make([]int, 0)
elems := make([]string, 0)

// nested
lookupKey := key
for {
start := bytes.IndexByte(lookupKey, byte('['))
end := bytes.IndexByte(lookupKey, byte(']'))

if start == -1 || end == -1 {
break
}

content := utils.UnsafeString(lookupKey[start+1 : end])
num, errElse := strconv.Atoi(content)

if errElse == nil {
ints = append(ints, num)
} else {
elems = append(elems, content)
}

lookupKey = lookupKey[end+1:]
}

for _, elem := range d.elems {
if elems[0] == elem.tag {
ev := reflect.New(elem.et)
if ee := elem.elementDecoder.UnmarshalString(utils.UnsafeString(value), ev.Elem()); ee != nil {
err = ee
}

i := rv.Index(ints[0])
i.Field(elem.index).Set(ev.Elem())
}
}
})

if err != nil {
return err
}

reqValue.Field(d.fieldIndex).Set(rv)
return nil
}

func visitQuery(ctx Ctx, f func(key []byte, value []byte)) {
ctx.Context().QueryArgs().VisitAll(f)
}
Expand Down

0 comments on commit 3661d33

Please sign in to comment.