Skip to content

Commit

Permalink
[feature] Use http.Request.Context() in >= Go 1.7.
Browse files Browse the repository at this point in the history
  • Loading branch information
elithrar committed Jun 2, 2016
2 parents 16dc2f5 + 5b56d12 commit 50eb875
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 10 deletions.
25 changes: 25 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// +build go1.7

package csrf

import (
"context"
"net/http"

"github.com/pkg/errors"
)

func contextGet(r *http.Request, key string) (interface{}, error) {
val := r.Context().Value(key)
if val == nil {
return nil, errors.Errorf("no value exists in the context for key %q", key)
}

return val, nil
}

func contextSave(r *http.Request, key string, val interface{}) *http.Request {
ctx := r.Context()
ctx = context.WithValue(ctx, key, val)
return r.WithContext(ctx)
}
24 changes: 24 additions & 0 deletions context_legacy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// +build !go1.7

package csrf

import (
"net/http"

"github.com/gorilla/context"

"github.com/pkg/errors"
)

func contextGet(r *http.Request, key string) (interface{}, error) {
if val, ok := context.GetOk(r, key); ok {
return val, nil
}

return nil, errors.Errorf("no value exists in the context for key %q", key)
}

func contextSave(r *http.Request, key string, val interface{}) *http.Request {
context.Set(r, key, val)
return r
}
6 changes: 3 additions & 3 deletions csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func Protect(authKey []byte, opts ...Option) func(http.Handler) http.Handler {
// Implements http.Handler for the csrf type.
func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Skip the check if directed to. This should always be a bool.
if val, ok := context.GetOk(r, skipCheckKey); ok {
if val, err := contextGet(r, skipCheckKey); err == nil {
if skip, ok := val.(bool); ok {
if skip {
cs.h.ServeHTTP(w, r)
Expand Down Expand Up @@ -209,9 +209,9 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

// Save the masked token to the request context
context.Set(r, tokenKey, mask(realToken, r))
r = contextSave(r, tokenKey, mask(realToken, r))
// Save the field name to the request context
context.Set(r, formKey, cs.opts.FieldName)
r = contextSave(r, formKey, cs.opts.FieldName)

// HTTP methods not defined as idempotent ("safe") under RFC7231 require
// inspection.
Expand Down
11 changes: 5 additions & 6 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
// a JSON response body. An empty token will be returned if the middleware
// has not been applied (which will fail subsequent validation).
func Token(r *http.Request) string {
if val, ok := context.GetOk(r, tokenKey); ok {
if val, err := contextGet(r, tokenKey); err == nil {
if maskedToken, ok := val.(string); ok {
return maskedToken
}
Expand All @@ -29,7 +29,7 @@ func Token(r *http.Request) string {
// This is useful when you want to log the cause of the error or report it to
// client.
func FailureReason(r *http.Request) error {
if val, ok := context.GetOk(r, errorKey); ok {
if val, err := contextGet(r, errorKey); err == nil {
if err, ok := val.(error); ok {
return err
}
Expand All @@ -44,8 +44,8 @@ func FailureReason(r *http.Request) error {
// Note: You should not set this without otherwise securing the request from
// CSRF attacks. The primary use-case for this function is to turn off CSRF
// checks for non-browser clients using authorization tokens against your API.
func UnsafeSkipCheck(r *http.Request) {
context.Set(r, skipCheckKey, true)
func UnsafeSkipCheck(r *http.Request) *http.Request {
return contextSave(r, skipCheckKey, true)
}

// TemplateField is a template helper for html/template that provides an <input> field
Expand All @@ -60,8 +60,7 @@ func UnsafeSkipCheck(r *http.Request) {
// <input type="hidden" name="gorilla.csrf.Token" value="<token>">
//
func TemplateField(r *http.Request) template.HTML {
name, ok := context.GetOk(r, formKey)
if ok {
if name, err := contextGet(r, formKey); err == nil {
fragment := fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
name, Token(r))

Expand Down
2 changes: 1 addition & 1 deletion helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ func TestUnsafeSkipCSRFCheck(t *testing.T) {
s := http.NewServeMux()
skipCheck := func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
UnsafeSkipCheck(r)
r = UnsafeSkipCheck(r)
h.ServeHTTP(w, r)
}

Expand Down

0 comments on commit 50eb875

Please sign in to comment.