Skip to content

Commit

Permalink
feat: pass requests through context (#596)
Browse files Browse the repository at this point in the history
Closes #537
  • Loading branch information
mitar committed May 19, 2021
1 parent 40e9171 commit 2f96bb8
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 21 deletions.
3 changes: 3 additions & 0 deletions access_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ import (
func (f *Fosite) NewAccessRequest(ctx context.Context, r *http.Request, session Session) (AccessRequester, error) {
accessRequest := NewAccessRequest(session)

ctx = context.WithValue(ctx, RequestContextKey, r)
ctx = context.WithValue(ctx, AccessRequestContextKey, accessRequest)

if r.Method != "POST" {
return accessRequest, errorsx.WithStack(ErrInvalidRequest.WithHintf("HTTP method is '%s', expected 'POST'.", r.Method))
} else if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart {
Expand Down
14 changes: 9 additions & 5 deletions access_request_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ func TestNewAccessRequest(t *testing.T) {
hasher := internal.NewMockHasher(ctrl)
defer ctrl.Finish()

ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil))

client := &DefaultClient{}
fosite := &Fosite{Store: store, Hasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}
for k, c := range []struct {
Expand Down Expand Up @@ -136,7 +138,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New(""))
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New(""))
},
handlers: TokenEndpointHandlers{handler},
},
Expand All @@ -153,7 +155,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(ErrServerError)
},
handlers: TokenEndpointHandlers{handler},
Expand All @@ -170,7 +172,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
},
handlers: TokenEndpointHandlers{handler},
Expand Down Expand Up @@ -369,6 +371,8 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) {
hasher := internal.NewMockHasher(ctrl)
defer ctrl.Finish()

ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil))

client := &DefaultClient{}
fosite := &Fosite{Store: store, Hasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}
for k, c := range []struct {
Expand All @@ -391,7 +395,7 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("hash err"))
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("hash err"))
handlerWithoutClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
},
method: "POST",
Expand All @@ -409,7 +413,7 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handlerWithoutClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
handlerWithClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
},
Expand Down
4 changes: 4 additions & 0 deletions access_response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ func (f *Fosite) NewAccessResponse(ctx context.Context, requester AccessRequeste
var tk TokenEndpointHandler

response := NewAccessResponse()

ctx = context.WithValue(ctx, AccessRequestContextKey, requester)
ctx = context.WithValue(ctx, AccessResponseContextKey, response)

for _, tk = range f.TokenEndpointHandlers {
if err = tk.PopulateTokenEndpointResponse(ctx, requester, response); err == nil {
// do nothing
Expand Down
2 changes: 1 addition & 1 deletion access_response_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestNewAccessResponse(t *testing.T) {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
f.TokenEndpointHandlers = c.handlers
c.mock()
ar, err := f.NewAccessResponse(nil, nil)
ar, err := f.NewAccessResponse(context.TODO(), nil)

if c.expectErr != nil {
assert.EqualError(t, err, c.expectErr.Error())
Expand Down
3 changes: 3 additions & 0 deletions authorize_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ func (f *Fosite) validateResponseMode(r *http.Request, request *AuthorizeRequest
func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (AuthorizeRequester, error) {
request := NewAuthorizeRequest()

ctx = context.WithValue(ctx, RequestContextKey, r)
ctx = context.WithValue(ctx, AuthorizeRequestContextKey, request)

if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart {
return request, errorsx.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithWrap(err).WithDebug(err.Error()))
}
Expand Down
3 changes: 3 additions & 0 deletions authorize_response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ func (f *Fosite) NewAuthorizeResponse(ctx context.Context, ar AuthorizeRequester
Parameters: url.Values{},
}

ctx = context.WithValue(ctx, AuthorizeRequestContextKey, ar)
ctx = context.WithValue(ctx, AuthorizeResponseContextKey, resp)

ar.SetSession(session)
for _, h := range f.AuthorizeEndpointHandlers {
if err := h.HandleAuthorizeEndpointRequest(ctx, ar, resp); err != nil {
Expand Down
10 changes: 10 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,13 @@ import "context"
func NewContext() context.Context {
return context.Background()
}

type ContextKey string

const (
RequestContextKey = ContextKey("request")
AccessRequestContextKey = ContextKey("accessRequest")
AccessResponseContextKey = ContextKey("accessResponse")
AuthorizeRequestContextKey = ContextKey("authorizeRequest")
AuthorizeResponseContextKey = ContextKey("authorizeResponse")
)
2 changes: 2 additions & 0 deletions introspection_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ import (
//
// token=mF_9.B5f-4.1JqM&token_type_hint=access_token
func (f *Fosite) NewIntrospectionRequest(ctx context.Context, r *http.Request, session Session) (IntrospectionResponder, error) {
ctx = context.WithValue(ctx, RequestContextKey, r)

if r.Method != "POST" {
return &IntrospectionResponse{Active: false}, errorsx.WithStack(ErrInvalidRequest.WithHintf("HTTP method is '%s' but expected 'POST'.", r.Method))
} else if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart {
Expand Down
24 changes: 14 additions & 10 deletions introspection_request_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ func TestIntrospectionResponseTokenUse(t *testing.T) {
validator := internal.NewMockTokenIntrospector(ctrl)
defer ctrl.Finish()

ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil))

f := compose.ComposeAllEnabled(new(compose.Config), storage.NewExampleStore(), []byte{}, nil).(*Fosite)
httpreq := &http.Request{
Method: "POST",
Expand All @@ -65,8 +67,8 @@ func TestIntrospectionResponseTokenUse(t *testing.T) {
description: "introspecting access token",
setup: func() {
f.TokenIntrospectionHandlers = TokenIntrospectionHandlers{validator}
validator.EXPECT().IntrospectToken(context.TODO(), "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(AccessToken, nil)
validator.EXPECT().IntrospectToken(ctx, "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(AccessToken, nil)
},
expectedATT: BearerAccessToken,
expectedTU: AccessToken,
Expand All @@ -75,8 +77,8 @@ func TestIntrospectionResponseTokenUse(t *testing.T) {
description: "introspecting refresh token",
setup: func() {
f.TokenIntrospectionHandlers = TokenIntrospectionHandlers{validator}
validator.EXPECT().IntrospectToken(context.TODO(), "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(RefreshToken, nil)
validator.EXPECT().IntrospectToken(ctx, "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(RefreshToken, nil)
},
expectedATT: "",
expectedTU: RefreshToken,
Expand Down Expand Up @@ -106,6 +108,8 @@ func TestNewIntrospectionRequest(t *testing.T) {
validator := internal.NewMockTokenIntrospector(ctrl)
defer ctrl.Finish()

ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil))

f := compose.ComposeAllEnabled(new(compose.Config), storage.NewExampleStore(), []byte{}, nil).(*Fosite)
httpreq := &http.Request{
Method: "POST",
Expand Down Expand Up @@ -139,8 +143,8 @@ func TestNewIntrospectionRequest(t *testing.T) {
"token": []string{"introspect-token"},
},
}
validator.EXPECT().IntrospectToken(context.TODO(), "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), newErr)
validator.EXPECT().IntrospectToken(ctx, "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), newErr)
},
isActive: false,
expectErr: ErrInactiveToken,
Expand All @@ -158,8 +162,8 @@ func TestNewIntrospectionRequest(t *testing.T) {
"token": []string{"introspect-token"},
},
}
validator.EXPECT().IntrospectToken(context.TODO(), "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(ctx, "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
},
isActive: true,
},
Expand All @@ -177,7 +181,7 @@ func TestNewIntrospectionRequest(t *testing.T) {
"token": []string{"introspect-token"},
},
}
validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
},
isActive: true,
},
Expand All @@ -195,7 +199,7 @@ func TestNewIntrospectionRequest(t *testing.T) {
"token": []string{"introspect-token"},
},
}
validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
},
isActive: true,
},
Expand Down
2 changes: 2 additions & 0 deletions revoke_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ import (
// An invalid token type hint value is ignored by the authorization
// server and does not influence the revocation response.
func (f *Fosite) NewRevocationRequest(ctx context.Context, r *http.Request) error {
ctx = context.WithValue(ctx, RequestContextKey, r)

if r.Method != "POST" {
return errorsx.WithStack(ErrInvalidRequest.WithHintf("HTTP method is '%s' but expected 'POST'.", r.Method))
} else if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart {
Expand Down
12 changes: 7 additions & 5 deletions revoke_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ func TestNewRevocationRequest(t *testing.T) {
hasher := internal.NewMockHasher(ctrl)
defer ctrl.Finish()

ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil))

client := &DefaultClient{}
fosite := &Fosite{Store: store, Hasher: hasher}
for k, c := range []struct {
Expand Down Expand Up @@ -102,7 +104,7 @@ func TestNewRevocationRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Secret = []byte("foo")
client.Public = false
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New(""))
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New(""))
},
},
{
Expand All @@ -118,7 +120,7 @@ func TestNewRevocationRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Secret = []byte("foo")
client.Public = false
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
},
handlers: RevocationHandlers{handler},
Expand All @@ -137,7 +139,7 @@ func TestNewRevocationRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Secret = []byte("foo")
client.Public = false
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
},
handlers: RevocationHandlers{handler},
Expand Down Expand Up @@ -173,7 +175,7 @@ func TestNewRevocationRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Secret = []byte("foo")
client.Public = false
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
},
handlers: RevocationHandlers{handler},
Expand All @@ -192,7 +194,7 @@ func TestNewRevocationRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Secret = []byte("foo")
client.Public = false
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
},
handlers: RevocationHandlers{handler},
Expand Down

0 comments on commit 2f96bb8

Please sign in to comment.