diff --git a/go.mod b/go.mod index dfc5976..d920f63 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/pires/go-proxyproto v0.6.0 github.com/poy/onpar v1.1.2 // indirect github.com/ziutek/mymysql v1.5.4 // indirect + golang.org/x/sys v0.3.0 // indirect ) go 1.13 diff --git a/go.sum b/go.sum index b11291e..ff4ed61 100644 --- a/go.sum +++ b/go.sum @@ -232,6 +232,8 @@ golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= +golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= diff --git a/tonic/handler.go b/tonic/handler.go index 1009055..15c542f 100644 --- a/tonic/handler.go +++ b/tonic/handler.go @@ -91,6 +91,11 @@ func Handler(h interface{}, status int, options ...func(*Route)) gin.HandlerFunc handleError(c, err) return } + // Bind context-keys + if err := bind(c, input, ContextTag, extractContext); err != nil { + handleError(c, err) + return + } // validating query and path inputs if they have a validate tag initValidator() args = append(args, input) diff --git a/tonic/tonic.go b/tonic/tonic.go index 8bc9290..f166bf0 100644 --- a/tonic/tonic.go +++ b/tonic/tonic.go @@ -25,6 +25,7 @@ const ( QueryTag = "query" PathTag = "path" HeaderTag = "header" + ContextTag = "context" EnumTag = "enum" RequiredTag = "required" DefaultTag = "default" @@ -345,6 +346,26 @@ func extractHeader(c *gin.Context, tag string) (string, []string, error) { return name, []string{header}, nil } +// extractContext is an extractor that operates on the gin.Context +// of a request. +func extractContext(c *gin.Context, tag string) (string, []string, error) { + name, required, defaultVal, err := parseTagKey(tag) + if err != nil { + return "", nil, err + } + context := c.GetString(name) + + // XXX: deprecated, use of "default" tag is preferred + if context == "" && defaultVal != "" { + return name, []string{defaultVal}, nil + } + // XXX: deprecated, use of "validate" tag is preferred + if required && context == "" { + return "", nil, fmt.Errorf("missing header parameter: %s", name) + } + return name, []string{context}, nil +} + // Public signature does not expose "required" and "default" because // they are deprecated in favor of the "validate" and "default" tags func parseTagKey(tag string) (string, bool, string, error) { diff --git a/tonic/tonic_test.go b/tonic/tonic_test.go index 694c159..53d14d5 100644 --- a/tonic/tonic_test.go +++ b/tonic/tonic_test.go @@ -28,6 +28,23 @@ func TestMain(m *testing.M) { tonic.SetErrorHook(errorHook) g := gin.Default() + + // for context test + g.Use(func(c *gin.Context) { + if c.FullPath() == "/context" { + if val, ok := c.GetQuery("param"); ok { + c.Set("param", val) + } + if val, ok := c.GetQuery("param-optional"); ok { + c.Set("param-optional", val) + } + if val, ok := c.GetQuery("param-optional-validated"); ok { + c.Set("param-optional-validated", val) + } + } + c.Next() + }) + g.GET("/simple", tonic.Handler(simpleHandler, 200)) g.GET("/scalar", tonic.Handler(scalarHandler, 200)) g.GET("/error", tonic.Handler(errorHandler, 200)) @@ -35,6 +52,7 @@ func TestMain(m *testing.M) { g.GET("/query", tonic.Handler(queryHandler, 200)) g.GET("/query-old", tonic.Handler(queryHandlerOld, 200)) g.POST("/body", tonic.Handler(bodyHandler, 200)) + g.GET("/context", tonic.Handler(contextHandler, 200)) r = g @@ -130,6 +148,17 @@ func TestBody(t *testing.T) { tester.Run() } +func TestContext(t *testing.T) { + tester := iffy.NewTester(t, r) + + tester.AddCall("context", "GET", "/context?param=foo", ``).Checkers(iffy.ExpectStatus(200), expectString("param", "foo")) + tester.AddCall("context", "GET", "/context", ``).Checkers(iffy.ExpectStatus(400)) + tester.AddCall("context", "GET", "/context?param=foo¶m-optional=bar", ``).Checkers(iffy.ExpectStatus(200), expectString("param-optional", "bar")) + tester.AddCall("context", "GET", "/context?param=foo¶m-optional-validated=foo", ``).Checkers(iffy.ExpectStatus(200), expectString("param-optional-validated", "foo")) + + tester.Run() +} + func errorHandler(c *gin.Context) error { return errors.New("error") } @@ -199,6 +228,16 @@ func bodyHandler(c *gin.Context, in *bodyIn) (*bodyIn, error) { return in, nil } +type ContextIn struct { + Param string `context:"param" json:"param" validate:"required"` + ParamOptional string `context:"param-optional" json:"param-optional"` + ValidatedParamOptional string `context:"param-optional-validated" json:"param-optional-validated" validate:"eq=|eq=foo|gt=10"` +} + +func contextHandler(c *gin.Context, in *ContextIn) (*ContextIn, error) { + return in, nil +} + func expectEmptyBody(r *http.Response, body string, obj interface{}) error { if len(body) != 0 { return fmt.Errorf("Body '%s' should be empty", body)