Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: WEOS-1343 Create middleware for handling JWT in the incoming Authorization header #99

Merged
merged 6 commits into from
Feb 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ env:
SLACK_ICON: https://github.com/wepala.png?size=48
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
SLACK_FOOTER: copyright 2022 Wepala
OAUTH_TEST_KEY: ${{ secrets.OAUTH_TEST_KEY }}

jobs:
build-api:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ env:
SLACK_ICON: https://github.com/wepala.png?size=48
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
SLACK_FOOTER: copyright 2022 Wepala
OAUTH_TEST_KEY: ${{ secrets.OAUTH_TEST_KEY }}

jobs:
build-api:
Expand Down
2 changes: 2 additions & 0 deletions context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ const FILTERS ContextKey = "_filters"
const SORTS ContextKey = "_sorts"
const PAYLOAD ContextKey = "_payload"
const SEQUENCE_NO string = "sequence_no"
const AUTHORIZATION string = "Authorization"
const USER_ID_EXTENSION ContextKey = "X-USER-ID"

//Path initializers are run per path and can be used to configure routes that are not defined in the open api spec
const METHODS_FOUND ContextKey = "_methods_found"
Expand Down
1 change: 1 addition & 0 deletions controllers/rest/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ func (p *RESTAPI) Initialize(ctxt context.Context) error {
p.RegisterController("CreateBatchController", CreateBatchController)
//register standard middleware
p.RegisterMiddleware("Context", Context)
p.RegisterMiddleware("AuthorizationMiddleware", AuthorizationMiddleware)
p.RegisterMiddleware("CreateMiddleware", CreateMiddleware)
p.RegisterMiddleware("CreateBatchMiddleware", CreateBatchMiddleware)
p.RegisterMiddleware("UpdateMiddleware", UpdateMiddleware)
Expand Down
120 changes: 120 additions & 0 deletions controllers/rest/controller_standard.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/coreos/go-oidc/v3/oidc"
"io/ioutil"
"net/http"
"strconv"
"strings"
"time"

"github.com/wepala/weos/projections"

Expand Down Expand Up @@ -664,3 +667,120 @@ func HealthCheck(api *RESTAPI, projection projections.Projection, commandDispatc
}

}

//AuthorizationMiddleware handling JWT in incoming Authorization header
func AuthorizationMiddleware(api *RESTAPI, projection projections.Projection, commandDispatcher model.CommandDispatcher, eventSource model.EventRepository, entityFactory model.EntityFactory, path *openapi3.PathItem, operation *openapi3.Operation) echo.MiddlewareFunc {
var openIdConnectUrl string
securityCheck := true
var verifiers []*oidc.IDTokenVerifier
algs := []string{"RS256", "RS384", "RS512", "HS256"}
if operation.Security != nil && len(*operation.Security) == 0 {
securityCheck = false
}
for _, schemes := range api.Swagger.Components.SecuritySchemes {
//get the open id connect url
if openIdUrl, ok := schemes.Value.ExtensionProps.Extensions[OPENIDCONNECTURLEXTENSION]; ok {
err := json.Unmarshal(openIdUrl.(json.RawMessage), &openIdConnectUrl)
if err != nil {
api.EchoInstance().Logger.Errorf("unable to unmarshal open id connect url '%s'", err)
} else {
//check if it is a valid open id connect url
if !strings.Contains(openIdConnectUrl, ".well-known/openid-configuration") {
api.EchoInstance().Logger.Warnf("invalid open id connect url: %s", openIdConnectUrl)
} else {
//get the Jwk url from open id connect url
jwksUrl, err := GetJwkUrl(openIdConnectUrl)
if err != nil {
api.EchoInstance().Logger.Errorf("unexpected error getting the jwks url: %s", err)
} else {
//by default skipExpiryCheck is false meaning it will not run an expiry check
skipExpiryCheck := false
//get skipexpirycheck that is an extension in the openapi spec
if expireCheck, ok := schemes.Value.ExtensionProps.Extensions[SKIPEXPIRYCHECKEXTENSION]; ok {
err := json.Unmarshal(expireCheck.(json.RawMessage), &skipExpiryCheck)
if err != nil {
api.EchoInstance().Logger.Errorf("unable to unmarshal skip expiry '%s'", err)
}
}
//create key set and verifier
keySet := oidc.NewRemoteKeySet(context.Background(), jwksUrl)
tokenVerifier := oidc.NewVerifier(openIdConnectUrl, keySet, &oidc.Config{
ClientID: "",
SupportedSigningAlgs: algs,
SkipClientIDCheck: true,
SkipExpiryCheck: skipExpiryCheck,
SkipIssuerCheck: true,
Now: time.Now,
})
verifiers = append(verifiers, tokenVerifier)
}
}

}
}

}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(ctxt echo.Context) error {
var err error
if !securityCheck {
return next(ctxt)
}
if len(verifiers) == 0 {
api.e.Logger.Errorf("unexpected error no verifiers were set")
return NewControllerError("unexpected error no verifiers were set", nil, http.StatusBadRequest)
}
newContext := ctxt.Request().Context()
token, ok := newContext.Value(weoscontext.AUTHORIZATION).(string)
if !ok || token == "" {
api.e.Logger.Errorf("no JWT token was found")
return NewControllerError("no JWT token was found", nil, http.StatusUnauthorized)
}
jwtToken := strings.Replace(token, "Bearer ", "", -1)
var idToken *oidc.IDToken
for _, tokenVerifier := range verifiers {
idToken, err = tokenVerifier.Verify(newContext, jwtToken)
if err != nil || idToken == nil {
api.e.Logger.Errorf(err.Error())
return NewControllerError("unexpected error verifying token", err, http.StatusUnauthorized)
}
}

newContext = context.WithValue(newContext, weoscontext.USER_ID_EXTENSION, idToken.Subject)
request := ctxt.Request().WithContext(newContext)
ctxt.SetRequest(request)
return next(ctxt)

}
}
}

//GetJwkUrl fetches the jwk url from the open id connect url
func GetJwkUrl(openIdUrl string) (string, error) {
//fetches the response from the connect id url
resp, err := http.Get(openIdUrl)
if err != nil || resp == nil {
return "", fmt.Errorf("unexpected error fetching open id connect url: %s", err)
}
defer resp.Body.Close()
// reads the body
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("unable to read response body: %v", err)
}
//check the response status
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("expected open id connect url response code to be %d got %d ", http.StatusOK, resp.StatusCode)
}
// unmarshall the body to a struct we can use to find the jwk uri
var info map[string]interface{}
err = json.Unmarshal(body, &info)
if err != nil {
return "", fmt.Errorf("unexpected error unmarshalling open id connect url response %s", err)
}
if info["jwks_uri"] == nil || info["jwks_uri"].(string) == "" {
return "", fmt.Errorf("no jwks uri found")
}
return info["jwks_uri"].(string), nil
}
92 changes: 75 additions & 17 deletions controllers/rest/controller_standard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func TestStandardControllers_Create(t *testing.T) {
Property: mockPayload,
}

projections := &ProjectionMock{
projections := &GormProjectionMock{
GetContentEntityFunc: func(ctx context.Context, entityFactory model.EntityFactory, weosID string) (*model.ContentEntity, error) {
if ctx == nil {
t.Errorf("expected to find context but got nil")
Expand Down Expand Up @@ -281,7 +281,7 @@ func TestStandardControllers_CreateBatch(t *testing.T) {
},
}

projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return nil, nil
},
Expand Down Expand Up @@ -457,7 +457,7 @@ func TestStandardControllers_Update(t *testing.T) {
mockEntity.SequenceNo = int64(1)
mockEntity.Property = mockBlog

projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return nil, nil
},
Expand Down Expand Up @@ -575,7 +575,7 @@ func TestStandardControllers_View(t *testing.T) {
}}

t.Run("Testing the generic view endpoint", func(t *testing.T) {
projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return map[string]interface{}{
"id": "1",
Expand Down Expand Up @@ -628,7 +628,7 @@ func TestStandardControllers_View(t *testing.T) {
}
})
t.Run("Testing view with entity id", func(t *testing.T) {
projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
if entityFactory == nil {
t.Errorf("expected to find entity factory got nil")
Expand Down Expand Up @@ -696,7 +696,7 @@ func TestStandardControllers_View(t *testing.T) {
}
})
t.Run("invalid entity id should return 404", func(t *testing.T) {
projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return map[string]interface{}{
"id": "1",
Expand Down Expand Up @@ -771,7 +771,7 @@ func TestStandardControllers_View(t *testing.T) {
}
})
t.Run("invalid numeric entity id should return 404", func(t *testing.T) {
projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return map[string]interface{}{
"id": "1",
Expand Down Expand Up @@ -846,7 +846,7 @@ func TestStandardControllers_View(t *testing.T) {
}
})
t.Run("view with sequence no", func(t *testing.T) {
projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return map[string]interface{}{
"id": "1",
Expand Down Expand Up @@ -926,7 +926,7 @@ func TestStandardControllers_View(t *testing.T) {
}
})
t.Run("view with invalid sequence no", func(t *testing.T) {
projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return map[string]interface{}{
"id": "1",
Expand Down Expand Up @@ -1019,7 +1019,7 @@ func TestStandardControllers_List(t *testing.T) {
array := []map[string]interface{}{}
array = append(array, mockBlog, mockBlog1)

mockProjection := &ProjectionMock{
mockProjection := &GormProjectionMock{
GetContentEntitiesFunc: func(ctx context.Context, entityFactory model.EntityFactory, page, limit int, query string, sortOptions map[string]string, filterOptions map[string]interface{}) ([]map[string]interface{}, int64, error) {
return array, 2, nil
},
Expand Down Expand Up @@ -1133,7 +1133,7 @@ func TestStandardControllers_ListFilters(t *testing.T) {
array := []map[string]interface{}{}
array = append(array, mockBlog, mockBlog1)

mockProjection := &ProjectionMock{
mockProjection := &GormProjectionMock{
GetContentEntitiesFunc: func(ctx context.Context, entityFactory model.EntityFactory, page, limit int, query string, sortOptions map[string]string, filterOptions map[string]interface{}) ([]map[string]interface{}, int64, error) {
if entityFactory == nil {
t.Errorf("no entity factory found")
Expand Down Expand Up @@ -1382,7 +1382,7 @@ func TestStandardControllers_FormUrlEncoded_Create(t *testing.T) {
Property: mockPayload,
}

projections := &ProjectionMock{
projections := &GormProjectionMock{
GetContentEntityFunc: func(ctx context.Context, entityFactory model.EntityFactory, weosID string) (*model.ContentEntity, error) {
return mockContentEntity, nil
},
Expand Down Expand Up @@ -1540,7 +1540,7 @@ func TestStandardControllers_FormData_Create(t *testing.T) {
Property: mockPayload,
}

projections := &ProjectionMock{
projections := &GormProjectionMock{
GetContentEntityFunc: func(ctx context.Context, entityFactory model.EntityFactory, weosID string) (*model.ContentEntity, error) {
return mockContentEntity, nil
},
Expand Down Expand Up @@ -1690,7 +1690,7 @@ func TestStandardControllers_DeleteEtag(t *testing.T) {
mockEntity.SequenceNo = int64(1)
mockEntity.Property = mockBlog

projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return nil, nil
},
Expand Down Expand Up @@ -1799,7 +1799,7 @@ func TestStandardControllers_DeleteID(t *testing.T) {
},
}

projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return mockInterface, nil
},
Expand Down Expand Up @@ -1852,7 +1852,7 @@ func TestStandardControllers_DeleteID(t *testing.T) {
},
}

projection1 := &ProjectionMock{
projection1 := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return mockInterface1, nil
},
Expand Down Expand Up @@ -1907,7 +1907,7 @@ func TestStandardControllers_DeleteID(t *testing.T) {

err1 := fmt.Errorf("this is an error")

projection1 := &ProjectionMock{
projection1 := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return nil, err1
},
Expand Down Expand Up @@ -1948,3 +1948,61 @@ func TestStandardControllers_DeleteID(t *testing.T) {
}
})
}

func TestStandardControllers_AuthenticateMiddleware(t *testing.T) {
//instantiate api
api, err := rest.New("./fixtures/blog-security.yaml")
if err != nil {
t.Fatalf("unexpected error loading api '%s'", err)
}
err = api.Initialize(context.TODO())
if err != nil {
t.Fatalf("un expected error initializing api '%s'", err)
}
e := api.EchoInstance()

t.Run("no jwt token added when required", func(t *testing.T) {
description := "testing 1st blog description"
mockBlog := &TestBlog{Description: &description}
reqBytes, err := json.Marshal(mockBlog)
if err != nil {
t.Fatalf("error setting up request %s", err)
}
body := bytes.NewReader(reqBytes)
resp := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/blogs", body)
req.Header.Set("Content-Type", "application/json")
e.ServeHTTP(resp, req)
if resp.Result().StatusCode != http.StatusUnauthorized {
t.Errorf("expected the response code to be %d, got %d", http.StatusUnauthorized, resp.Result().StatusCode)
}
})
t.Run("security parameter array is empty", func(t *testing.T) {
resp := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/health", nil)
e.ServeHTTP(resp, req)
if resp.Result().StatusCode != http.StatusOK {
t.Errorf("expected the response code to be %d, got %d", http.StatusOK, resp.Result().StatusCode)
}
})
t.Run("jwt token added", func(t *testing.T) {
description := "testing 1st blog description"
url := "www.example.com"
title := "example"
mockBlog := &TestBlog{Title: &title, Url: &url, Description: &description}
reqBytes, err := json.Marshal(mockBlog)
if err != nil {
t.Fatalf("error setting up request %s", err)
}
token := os.Getenv("OAUTH_TEST_KEY")
body := bytes.NewReader(reqBytes)
resp := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/blogs", body)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
e.ServeHTTP(resp, req)
if resp.Result().StatusCode != http.StatusCreated {
t.Errorf("expected the response code to be %d, got %d", http.StatusCreated, resp.Result().StatusCode)
}
})
}
Loading