Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Pass requests through context #596

Merged
merged 1 commit into from
May 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions access_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ import (
func (f *Fosite) NewAccessRequest(ctx context.Context, r *http.Request, session Session) (AccessRequester, error) {
accessRequest := NewAccessRequest(session)

ctx = context.WithValue(ctx, RequestContextKey, r)
ctx = context.WithValue(ctx, AccessRequestContextKey, accessRequest)

if r.Method != "POST" {
return accessRequest, errorsx.WithStack(ErrInvalidRequest.WithHintf("HTTP method is '%s', expected 'POST'.", r.Method))
} else if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart {
Expand Down
14 changes: 9 additions & 5 deletions access_request_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ func TestNewAccessRequest(t *testing.T) {
hasher := internal.NewMockHasher(ctrl)
defer ctrl.Finish()

ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil))

client := &DefaultClient{}
fosite := &Fosite{Store: store, Hasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}
for k, c := range []struct {
Expand Down Expand Up @@ -136,7 +138,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New(""))
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New(""))
},
handlers: TokenEndpointHandlers{handler},
},
Expand All @@ -153,7 +155,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(ErrServerError)
},
handlers: TokenEndpointHandlers{handler},
Expand All @@ -170,7 +172,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
},
handlers: TokenEndpointHandlers{handler},
Expand Down Expand Up @@ -369,6 +371,8 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) {
hasher := internal.NewMockHasher(ctrl)
defer ctrl.Finish()

ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil))

client := &DefaultClient{}
fosite := &Fosite{Store: store, Hasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}
for k, c := range []struct {
Expand All @@ -391,7 +395,7 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("hash err"))
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("hash err"))
handlerWithoutClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
},
method: "POST",
Expand All @@ -409,7 +413,7 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handlerWithoutClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
handlerWithClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
},
Expand Down
4 changes: 4 additions & 0 deletions access_response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ func (f *Fosite) NewAccessResponse(ctx context.Context, requester AccessRequeste
var tk TokenEndpointHandler

response := NewAccessResponse()

ctx = context.WithValue(ctx, AccessRequestContextKey, requester)
ctx = context.WithValue(ctx, AccessResponseContextKey, response)

for _, tk = range f.TokenEndpointHandlers {
if err = tk.PopulateTokenEndpointResponse(ctx, requester, response); err == nil {
// do nothing
Expand Down
2 changes: 1 addition & 1 deletion access_response_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestNewAccessResponse(t *testing.T) {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
f.TokenEndpointHandlers = c.handlers
c.mock()
ar, err := f.NewAccessResponse(nil, nil)
ar, err := f.NewAccessResponse(context.TODO(), nil)

if c.expectErr != nil {
assert.EqualError(t, err, c.expectErr.Error())
Expand Down
3 changes: 3 additions & 0 deletions authorize_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ func (f *Fosite) validateResponseMode(r *http.Request, request *AuthorizeRequest
func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (AuthorizeRequester, error) {
request := NewAuthorizeRequest()

ctx = context.WithValue(ctx, RequestContextKey, r)
ctx = context.WithValue(ctx, AuthorizeRequestContextKey, request)

if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart {
return request, errorsx.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithWrap(err).WithDebug(err.Error()))
}
Expand Down
3 changes: 3 additions & 0 deletions authorize_response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ func (f *Fosite) NewAuthorizeResponse(ctx context.Context, ar AuthorizeRequester
Parameters: url.Values{},
}

ctx = context.WithValue(ctx, AuthorizeRequestContextKey, ar)
ctx = context.WithValue(ctx, AuthorizeResponseContextKey, resp)

ar.SetSession(session)
for _, h := range f.AuthorizeEndpointHandlers {
if err := h.HandleAuthorizeEndpointRequest(ctx, ar, resp); err != nil {
Expand Down
10 changes: 10 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,13 @@ import "context"
func NewContext() context.Context {
return context.Background()
}

type ContextKey string

const (
RequestContextKey = ContextKey("request")
AccessRequestContextKey = ContextKey("accessRequest")
AccessResponseContextKey = ContextKey("accessResponse")
AuthorizeRequestContextKey = ContextKey("authorizeRequest")
AuthorizeResponseContextKey = ContextKey("authorizeResponse")
)
2 changes: 2 additions & 0 deletions introspection_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ import (
//
// token=mF_9.B5f-4.1JqM&token_type_hint=access_token
func (f *Fosite) NewIntrospectionRequest(ctx context.Context, r *http.Request, session Session) (IntrospectionResponder, error) {
ctx = context.WithValue(ctx, RequestContextKey, r)

if r.Method != "POST" {
return &IntrospectionResponse{Active: false}, errorsx.WithStack(ErrInvalidRequest.WithHintf("HTTP method is '%s' but expected 'POST'.", r.Method))
} else if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart {
Expand Down
24 changes: 14 additions & 10 deletions introspection_request_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ func TestIntrospectionResponseTokenUse(t *testing.T) {
validator := internal.NewMockTokenIntrospector(ctrl)
defer ctrl.Finish()

ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil))

f := compose.ComposeAllEnabled(new(compose.Config), storage.NewExampleStore(), []byte{}, nil).(*Fosite)
httpreq := &http.Request{
Method: "POST",
Expand All @@ -65,8 +67,8 @@ func TestIntrospectionResponseTokenUse(t *testing.T) {
description: "introspecting access token",
setup: func() {
f.TokenIntrospectionHandlers = TokenIntrospectionHandlers{validator}
validator.EXPECT().IntrospectToken(context.TODO(), "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(AccessToken, nil)
validator.EXPECT().IntrospectToken(ctx, "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(AccessToken, nil)
},
expectedATT: BearerAccessToken,
expectedTU: AccessToken,
Expand All @@ -75,8 +77,8 @@ func TestIntrospectionResponseTokenUse(t *testing.T) {
description: "introspecting refresh token",
setup: func() {
f.TokenIntrospectionHandlers = TokenIntrospectionHandlers{validator}
validator.EXPECT().IntrospectToken(context.TODO(), "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(RefreshToken, nil)
validator.EXPECT().IntrospectToken(ctx, "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(RefreshToken, nil)
},
expectedATT: "",
expectedTU: RefreshToken,
Expand Down Expand Up @@ -106,6 +108,8 @@ func TestNewIntrospectionRequest(t *testing.T) {
validator := internal.NewMockTokenIntrospector(ctrl)
defer ctrl.Finish()

ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil))

f := compose.ComposeAllEnabled(new(compose.Config), storage.NewExampleStore(), []byte{}, nil).(*Fosite)
httpreq := &http.Request{
Method: "POST",
Expand Down Expand Up @@ -139,8 +143,8 @@ func TestNewIntrospectionRequest(t *testing.T) {
"token": []string{"introspect-token"},
},
}
validator.EXPECT().IntrospectToken(context.TODO(), "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), newErr)
validator.EXPECT().IntrospectToken(ctx, "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), newErr)
},
isActive: false,
expectErr: ErrInactiveToken,
Expand All @@ -158,8 +162,8 @@ func TestNewIntrospectionRequest(t *testing.T) {
"token": []string{"introspect-token"},
},
}
validator.EXPECT().IntrospectToken(context.TODO(), "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(ctx, "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
},
isActive: true,
},
Expand All @@ -177,7 +181,7 @@ func TestNewIntrospectionRequest(t *testing.T) {
"token": []string{"introspect-token"},
},
}
validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
},
isActive: true,
},
Expand All @@ -195,7 +199,7 @@ func TestNewIntrospectionRequest(t *testing.T) {
"token": []string{"introspect-token"},
},
}
validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil)
},
isActive: true,
},
Expand Down
2 changes: 2 additions & 0 deletions revoke_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ import (
// An invalid token type hint value is ignored by the authorization
// server and does not influence the revocation response.
func (f *Fosite) NewRevocationRequest(ctx context.Context, r *http.Request) error {
ctx = context.WithValue(ctx, RequestContextKey, r)

if r.Method != "POST" {
return errorsx.WithStack(ErrInvalidRequest.WithHintf("HTTP method is '%s' but expected 'POST'.", r.Method))
} else if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart {
Expand Down
12 changes: 7 additions & 5 deletions revoke_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ func TestNewRevocationRequest(t *testing.T) {
hasher := internal.NewMockHasher(ctrl)
defer ctrl.Finish()

ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil))

client := &DefaultClient{}
fosite := &Fosite{Store: store, Hasher: hasher}
for k, c := range []struct {
Expand Down Expand Up @@ -102,7 +104,7 @@ func TestNewRevocationRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Secret = []byte("foo")
client.Public = false
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New(""))
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New(""))
},
},
{
Expand All @@ -118,7 +120,7 @@ func TestNewRevocationRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Secret = []byte("foo")
client.Public = false
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
},
handlers: RevocationHandlers{handler},
Expand All @@ -137,7 +139,7 @@ func TestNewRevocationRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Secret = []byte("foo")
client.Public = false
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
},
handlers: RevocationHandlers{handler},
Expand Down Expand Up @@ -173,7 +175,7 @@ func TestNewRevocationRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Secret = []byte("foo")
client.Public = false
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
},
handlers: RevocationHandlers{handler},
Expand All @@ -192,7 +194,7 @@ func TestNewRevocationRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Secret = []byte("foo")
client.Public = false
hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
},
handlers: RevocationHandlers{handler},
Expand Down