diff --git a/CHANGELOG.md b/CHANGELOG.md index cc78dd3..e72470e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ How to release a new version: - Manually release new version. ## [Unreleased] +### Added +- package `http/param`: can parse into embedded structs. + +### Removed +- package `http/param`: can no longer change the tag value prefix the parser reacts to (e.g. from `param:"query=q"` to `param:"myPrefix=q"`) ## [0.7.1] - 2024-07-11 ### Changed diff --git a/http/param/param.go b/http/param/param.go index 0d417ca..407af93 100644 --- a/http/param/param.go +++ b/http/param/param.go @@ -9,41 +9,29 @@ import ( "strings" ) -// TagResolver is a function that decides from a field type what key of http parameter should be searched. -// Second return value should return whether the key should be searched in http parameter at all. -type TagResolver func(fieldTag reflect.StructTag) (string, bool) +const ( + defaultTagName = "param" + queryTagValuePrefix = "query" + pathTagValuePrefix = "path" +) -// FixedTagNameParamTagResolver returns a TagResolver, that matches struct params by specific tag. -// Example: FixedTagNameParamTagResolver("mytag") matches a field tagged with `mytag:"param_name"` -func FixedTagNameParamTagResolver(tagName string) TagResolver { - return func(fieldTag reflect.StructTag) (string, bool) { - taggedParamName := fieldTag.Get(tagName) - return taggedParamName, taggedParamName != "" - } -} +// TagResolver is a function that decides from a field tag what parameter should be searched. +// Second return value should return whether the parameter should be searched at all. +type TagResolver func(fieldTag reflect.StructTag) (string, bool) -// TagWithModifierTagResolver returns a TagResolver, that matches struct params by specific tag and -// by a value before a '=' separator. -// Example: FixedTagNameParamTagResolver("mytag", "mymodifier") matches a field tagged with `mytag:"mymodifier=param_name"` -func TagWithModifierTagResolver(tagName string, tagModifier string) TagResolver { +// TagNameResolver returns a TagResolver that returns the value of tag with tagName, and whether the tag exists at all. +// It can be used to replace Parser.ParamTagResolver to change what tag name the Parser reacts to. +func TagNameResolver(tagName string) TagResolver { return func(fieldTag reflect.StructTag) (string, bool) { tagValue := fieldTag.Get(tagName) if tagValue == "" { return "", false } - splits := strings.Split(tagValue, "=") - //nolint:gomnd // 2 not really that magic number - one value before '=', one after - if len(splits) != 2 { - return "", false - } - if splits[0] == tagModifier { - return splits[1], true - } - return "", false + return tagValue, true } } -// PathParamFunc is a function that returns value of specified http path parameter +// PathParamFunc is a function that returns value of specified http path parameter. type PathParamFunc func(r *http.Request, key string) string // Parser can Parse query and path parameters from http.Request into a struct. @@ -53,18 +41,16 @@ type PathParamFunc func(r *http.Request, key string) string // PathParamFunc is for getting path parameter from http.Request, as each http router handles it in different way (if at all). // For example for chi, use WithPathParamFunc(chi.URLParam) to be able to use tags for path parameters. type Parser struct { - QueryParamTagResolver TagResolver - PathParamTagResolver TagResolver - PathParamFunc PathParamFunc + ParamTagResolver TagResolver + PathParamFunc PathParamFunc } // DefaultParser returns query and path parameter Parser with intended struct tags // `param:"query=param_name"` for query parameters and `param:"path=param_name"` for path parameters func DefaultParser() Parser { return Parser{ - QueryParamTagResolver: TagWithModifierTagResolver("param", "query"), - PathParamTagResolver: TagWithModifierTagResolver("param", "path"), - PathParamFunc: nil, // keep nil, as there is no sensible default of how to get value of path parameter + ParamTagResolver: TagNameResolver(defaultTagName), + PathParamFunc: nil, // keep nil, as there is no sensible default of how to get value of path parameter } } @@ -75,7 +61,8 @@ func (p Parser) WithPathParamFunc(f PathParamFunc) Parser { return p } -// Parse accepts the request and a pointer to struct that is tagged with appropriate tags set in Parser. +// Parse accepts the request and a pointer to struct with its fields tagged with appropriate tags set in Parser. +// Such tagged fields must be in top level struct, or in exported struct embedded in top-level struct. // All such tagged fields are assigned the respective parameter from the actual request. // // Fields are assigned their zero value if the field was tagged but request did not contain such parameter. @@ -100,13 +87,20 @@ func (p Parser) Parse(r *http.Request, dest any) error { return fmt.Errorf("can only parse into struct, but got %s", v.Type().Name()) } - for i := 0; i < v.NumField(); i++ { - typeField := v.Type().Field(i) - if !typeField.IsExported() { - continue + fieldIndexPaths := p.findTaggedIndexPaths(v.Type(), []int{}, []taggedFieldIndexPath{}) + + for i := range fieldIndexPaths { + // Zero the value, even if it would not be set by following path or query parameter. + // This will cause potential partial result from previous parser (e.g. json.Unmarshal) to be discarded on + // fields that are tagged for path or query parameter. + err := zeroPath(v, &fieldIndexPaths[i]) + if err != nil { + return err } - valueField := v.Field(i) - err := p.parseParam(r, typeField, valueField) + } + + for _, path := range fieldIndexPaths { + err := p.parseParam(r, path) if err != nil { return err } @@ -114,34 +108,98 @@ func (p Parser) Parse(r *http.Request, dest any) error { return nil } -func (p Parser) parseParam(r *http.Request, typeField reflect.StructField, v reflect.Value) error { - tag := typeField.Tag - pathParamName, okPath := p.PathParamTagResolver(tag) - queryParamName, okQuery := p.QueryParamTagResolver(tag) - if !okPath && !okQuery { - // do nothing if tagged neither for query nor param - return nil +type paramType int + +const ( + paramTypeQuery paramType = iota + paramTypePath +) + +type taggedFieldIndexPath struct { + paramType paramType + paramName string + indexPath []int + destValue reflect.Value +} + +func (p Parser) findTaggedIndexPaths(typ reflect.Type, currentNestingIndexPath []int, paths []taggedFieldIndexPath) []taggedFieldIndexPath { + for i := 0; i < typ.NumField(); i++ { + typeField := typ.Field(i) + if typeField.Anonymous { + t := typeField.Type + if t.Kind() == reflect.Pointer { + t = t.Elem() + } + if t.Kind() == reflect.Struct { + paths = p.findTaggedIndexPaths(t, append(currentNestingIndexPath, i), paths) + } + } + if !typeField.IsExported() { + continue + } + tag := typeField.Tag + pathParamName, okPath := p.resolvePath(tag) + queryParamName, okQuery := p.resolveQuery(tag) + if okPath { + newPath := make([]int, 0, len(currentNestingIndexPath)+1) + newPath = append(newPath, currentNestingIndexPath...) + newPath = append(newPath, i) + paths = append(paths, taggedFieldIndexPath{ + paramType: paramTypePath, + paramName: pathParamName, + indexPath: newPath, + }) + } + if okQuery { + newPath := make([]int, 0, len(currentNestingIndexPath)+1) + newPath = append(newPath, currentNestingIndexPath...) + newPath = append(newPath, i) + paths = append(paths, taggedFieldIndexPath{ + paramType: paramTypeQuery, + paramName: queryParamName, + indexPath: newPath, + }) + } } + return paths +} - // Zero the value, even if it would not be set by following path or query parameter. - // This will cause potential partial result from previous parser (e.g. json.Unmarshal) to be discarded on - // fields that are tagged for path or query parameter. - v.Set(reflect.Zero(typeField.Type)) +func zeroPath(v reflect.Value, path *taggedFieldIndexPath) error { + for n, i := range path.indexPath { + if v.Kind() == reflect.Pointer { + v = v.Elem() + } + // findTaggedIndexPaths prepared a path.indexPath in such a way, that respective field is always + // pointer to struct or struct -> should be always able to .Field() here + typeField := v.Type().Field(i) + v = v.Field(i) - if okPath { - err := p.parsePathParam(r, pathParamName, v) - if err != nil { - return err + if n == len(path.indexPath)-1 { + v.Set(reflect.Zero(typeField.Type)) + path.destValue = v + } else if v.Kind() == reflect.Pointer && v.IsNil() { + if !v.CanSet() { + return fmt.Errorf("cannot set embedded pointer to unexported struct: %v", v.Type().Elem()) + } + v.Set(reflect.New(v.Type().Elem())) } } + return nil +} - if okQuery { - err := p.parseQueryParam(r, queryParamName, v) +func (p Parser) parseParam(r *http.Request, path taggedFieldIndexPath) error { + switch path.paramType { + case paramTypePath: + err := p.parsePathParam(r, path.paramName, path.destValue) + if err != nil { + return err + } + case paramTypeQuery: + err := p.parseQueryParam(r, path.paramName, path.destValue) if err != nil { return err } } - return nil } @@ -246,3 +304,33 @@ func unmarshalPrimitiveValue(text string, dest reflect.Value) error { } return nil } + +// resolveTagValueWithModifier returns a parameter value in tag value containing a prefix "tagModifier=". +// Example: resolveTagValueWithModifier("query=param_name", "query") returns "param_name", true. +func (p Parser) resolveTagValueWithModifier(tagValue string, tagModifier string) (string, bool) { + splits := strings.Split(tagValue, "=") + //nolint:gomnd // 2 not really that magic number - one value before '=', one after + if len(splits) != 2 { + return "", false + } + if splits[0] == tagModifier { + return splits[1], true + } + return "", false +} + +func (p Parser) resolveTagWithModifier(fieldTag reflect.StructTag, tagModifier string) (string, bool) { + tagValue, ok := p.ParamTagResolver(fieldTag) + if !ok { + return "", false + } + return p.resolveTagValueWithModifier(tagValue, tagModifier) +} + +func (p Parser) resolvePath(fieldTag reflect.StructTag) (string, bool) { + return p.resolveTagWithModifier(fieldTag, pathTagValuePrefix) +} + +func (p Parser) resolveQuery(fieldTag reflect.StructTag) (string, bool) { + return p.resolveTagWithModifier(fieldTag, queryTagValuePrefix) +} diff --git a/http/param/param_test.go b/http/param/param_test.go index 7b2b73d..c052990 100644 --- a/http/param/param_test.go +++ b/http/param/param_test.go @@ -471,6 +471,99 @@ func TestParser_Parse_DoesNotOverwrite(t *testing.T) { assert.Equal(t, expected, result) } +type EmbeddedStruct struct { + Embedded string `param:"query=embedded"` +} + +type embeddingStruct struct { + EmbeddedStruct +} + +type embeddingPtrStruct struct { + *EmbeddedStruct +} + +type embeddedStruct struct { + Embedded string `param:"query=embedded"` +} + +type embeddingUnexported struct { + embeddedStruct +} + +type embeddingUnexportedPtr struct { + *embeddedStruct +} + +type embeddingNested struct { + embeddingUnexported +} + +func TestParser_Parse_Embedded(t *testing.T) { + p := DefaultParser() + req := httptest.NewRequest(http.MethodGet, "https://test.com/hello?embedded=input", nil) + + tests := []struct { + resultPtr any + expectedPtr any + }{ + { + resultPtr: new(embeddingStruct), + expectedPtr: &embeddingStruct{ + EmbeddedStruct{ + Embedded: "input", + }, + }, + }, + { + resultPtr: new(embeddingPtrStruct), + expectedPtr: &embeddingPtrStruct{ + EmbeddedStruct: &EmbeddedStruct{ + Embedded: "input", + }, + }, + }, + { + resultPtr: new(embeddingUnexported), + expectedPtr: &embeddingUnexported{ + embeddedStruct: embeddedStruct{ + Embedded: "input", + }, + }, + }, + { + resultPtr: new(embeddingNested), + expectedPtr: &embeddingNested{ + embeddingUnexported{ + embeddedStruct{ + Embedded: "input", + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(reflect.TypeOf(tt.resultPtr).Elem().Name(), func(t *testing.T) { + err := p.Parse(req, tt.resultPtr) + + assert.NoError(t, err) + assert.Equal(t, tt.expectedPtr, tt.resultPtr) + }) + } +} + +func TestParser_Parse_Embedded_Error(t *testing.T) { + p := DefaultParser() + req := httptest.NewRequest(http.MethodGet, "https://test.com/hello?embedded=input", nil) + + var result embeddingUnexportedPtr + err := p.Parse(req, &result) + + assert.ErrorContains(t, err, "unexported") + assert.ErrorContains(t, err, "embeddedStruct") +} + type variousTagsStruct struct { A string `key:"location=val"` B string `key:"location=val=excessive"` @@ -481,7 +574,7 @@ type variousTagsStruct struct { func TestTagWithModifierTagResolver(t *testing.T) { const correctKey = "key" - const correctLocation = "location" + const correctPrefix = "location" testCases := []struct { fieldName string @@ -516,44 +609,11 @@ func TestTagWithModifierTagResolver(t *testing.T) { } for _, tc := range testCases { t.Run(tc.fieldName, func(t *testing.T) { - tagResolver := TagWithModifierTagResolver(correctKey, correctLocation) - structField, found := reflect.TypeOf(variousTagsStruct{}).FieldByName(tc.fieldName) - require.True(t, found) - - paramName, ok := tagResolver(structField.Tag) - - assert.Equal(t, tc.expectedParam, paramName) - assert.Equal(t, tc.expectedOk, ok) - }) - } -} - -func TestFixedTagNameParamTagResolver(t *testing.T) { - const correctKey = "key" - - testCases := []struct { - fieldName string - expectedParam string - expectedOk bool - }{ - { - fieldName: "A", - expectedParam: "location=val", - expectedOk: true, - }, - { - fieldName: "D", - expectedParam: "", - expectedOk: false, - }, - } - for _, tc := range testCases { - t.Run(tc.fieldName, func(t *testing.T) { - tagResolver := FixedTagNameParamTagResolver(correctKey) + parser := Parser{ParamTagResolver: TagNameResolver(correctKey)} structField, found := reflect.TypeOf(variousTagsStruct{}).FieldByName(tc.fieldName) require.True(t, found) - paramName, ok := tagResolver(structField.Tag) + paramName, ok := parser.resolveTagWithModifier(structField.Tag, correctPrefix) assert.Equal(t, tc.expectedParam, paramName) assert.Equal(t, tc.expectedOk, ok)