Skip to content

Commit

Permalink
Merge pull request #99 from wepala/WEOS-1349
Browse files Browse the repository at this point in the history
feature: WEOS-1343 Create middleware for handling JWT in the incoming Authorization header
  • Loading branch information
atoniaw committed Feb 17, 2022
2 parents 1a5b302 + 6202269 commit 6bc3d91
Show file tree
Hide file tree
Showing 18 changed files with 1,122 additions and 137 deletions.
1 change: 1 addition & 0 deletions .github/workflows/dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ env:
SLACK_ICON: https://github.com/wepala.png?size=48
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
SLACK_FOOTER: copyright 2022 Wepala
OAUTH_TEST_KEY: ${{ secrets.OAUTH_TEST_KEY }}

jobs:
build-api:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ env:
SLACK_ICON: https://github.com/wepala.png?size=48
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
SLACK_FOOTER: copyright 2022 Wepala
OAUTH_TEST_KEY: ${{ secrets.OAUTH_TEST_KEY }}

jobs:
build-api:
Expand Down
2 changes: 2 additions & 0 deletions context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ const FILTERS ContextKey = "_filters"
const SORTS ContextKey = "_sorts"
const PAYLOAD ContextKey = "_payload"
const SEQUENCE_NO string = "sequence_no"
const AUTHORIZATION string = "Authorization"
const USER_ID_EXTENSION ContextKey = "X-USER-ID"

//Path initializers are run per path and can be used to configure routes that are not defined in the open api spec
const METHODS_FOUND ContextKey = "_methods_found"
Expand Down
1 change: 1 addition & 0 deletions controllers/rest/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ func (p *RESTAPI) Initialize(ctxt context.Context) error {
p.RegisterController("CreateBatchController", CreateBatchController)
//register standard middleware
p.RegisterMiddleware("Context", Context)
p.RegisterMiddleware("AuthorizationMiddleware", AuthorizationMiddleware)
p.RegisterMiddleware("CreateMiddleware", CreateMiddleware)
p.RegisterMiddleware("CreateBatchMiddleware", CreateBatchMiddleware)
p.RegisterMiddleware("UpdateMiddleware", UpdateMiddleware)
Expand Down
120 changes: 120 additions & 0 deletions controllers/rest/controller_standard.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/coreos/go-oidc/v3/oidc"
"io/ioutil"
"net/http"
"strconv"
"strings"
"time"

"github.com/wepala/weos/projections"

Expand Down Expand Up @@ -664,3 +667,120 @@ func HealthCheck(api *RESTAPI, projection projections.Projection, commandDispatc
}

}

//AuthorizationMiddleware handling JWT in incoming Authorization header
func AuthorizationMiddleware(api *RESTAPI, projection projections.Projection, commandDispatcher model.CommandDispatcher, eventSource model.EventRepository, entityFactory model.EntityFactory, path *openapi3.PathItem, operation *openapi3.Operation) echo.MiddlewareFunc {
var openIdConnectUrl string
securityCheck := true
var verifiers []*oidc.IDTokenVerifier
algs := []string{"RS256", "RS384", "RS512", "HS256"}
if operation.Security != nil && len(*operation.Security) == 0 {
securityCheck = false
}
for _, schemes := range api.Swagger.Components.SecuritySchemes {
//get the open id connect url
if openIdUrl, ok := schemes.Value.ExtensionProps.Extensions[OPENIDCONNECTURLEXTENSION]; ok {
err := json.Unmarshal(openIdUrl.(json.RawMessage), &openIdConnectUrl)
if err != nil {
api.EchoInstance().Logger.Errorf("unable to unmarshal open id connect url '%s'", err)
} else {
//check if it is a valid open id connect url
if !strings.Contains(openIdConnectUrl, ".well-known/openid-configuration") {
api.EchoInstance().Logger.Warnf("invalid open id connect url: %s", openIdConnectUrl)
} else {
//get the Jwk url from open id connect url
jwksUrl, err := GetJwkUrl(openIdConnectUrl)
if err != nil {
api.EchoInstance().Logger.Errorf("unexpected error getting the jwks url: %s", err)
} else {
//by default skipExpiryCheck is false meaning it will not run an expiry check
skipExpiryCheck := false
//get skipexpirycheck that is an extension in the openapi spec
if expireCheck, ok := schemes.Value.ExtensionProps.Extensions[SKIPEXPIRYCHECKEXTENSION]; ok {
err := json.Unmarshal(expireCheck.(json.RawMessage), &skipExpiryCheck)
if err != nil {
api.EchoInstance().Logger.Errorf("unable to unmarshal skip expiry '%s'", err)
}
}
//create key set and verifier
keySet := oidc.NewRemoteKeySet(context.Background(), jwksUrl)
tokenVerifier := oidc.NewVerifier(openIdConnectUrl, keySet, &oidc.Config{
ClientID: "",
SupportedSigningAlgs: algs,
SkipClientIDCheck: true,
SkipExpiryCheck: skipExpiryCheck,
SkipIssuerCheck: true,
Now: time.Now,
})
verifiers = append(verifiers, tokenVerifier)
}
}

}
}

}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(ctxt echo.Context) error {
var err error
if !securityCheck {
return next(ctxt)
}
if len(verifiers) == 0 {
api.e.Logger.Errorf("unexpected error no verifiers were set")
return NewControllerError("unexpected error no verifiers were set", nil, http.StatusBadRequest)
}
newContext := ctxt.Request().Context()
token, ok := newContext.Value(weoscontext.AUTHORIZATION).(string)
if !ok || token == "" {
api.e.Logger.Errorf("no JWT token was found")
return NewControllerError("no JWT token was found", nil, http.StatusUnauthorized)
}
jwtToken := strings.Replace(token, "Bearer ", "", -1)
var idToken *oidc.IDToken
for _, tokenVerifier := range verifiers {
idToken, err = tokenVerifier.Verify(newContext, jwtToken)
if err != nil || idToken == nil {
api.e.Logger.Errorf(err.Error())
return NewControllerError("unexpected error verifying token", err, http.StatusUnauthorized)
}
}

newContext = context.WithValue(newContext, weoscontext.USER_ID_EXTENSION, idToken.Subject)
request := ctxt.Request().WithContext(newContext)
ctxt.SetRequest(request)
return next(ctxt)

}
}
}

//GetJwkUrl fetches the jwk url from the open id connect url
func GetJwkUrl(openIdUrl string) (string, error) {
//fetches the response from the connect id url
resp, err := http.Get(openIdUrl)
if err != nil || resp == nil {
return "", fmt.Errorf("unexpected error fetching open id connect url: %s", err)
}
defer resp.Body.Close()
// reads the body
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("unable to read response body: %v", err)
}
//check the response status
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("expected open id connect url response code to be %d got %d ", http.StatusOK, resp.StatusCode)
}
// unmarshall the body to a struct we can use to find the jwk uri
var info map[string]interface{}
err = json.Unmarshal(body, &info)
if err != nil {
return "", fmt.Errorf("unexpected error unmarshalling open id connect url response %s", err)
}
if info["jwks_uri"] == nil || info["jwks_uri"].(string) == "" {
return "", fmt.Errorf("no jwks uri found")
}
return info["jwks_uri"].(string), nil
}
92 changes: 75 additions & 17 deletions controllers/rest/controller_standard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func TestStandardControllers_Create(t *testing.T) {
Property: mockPayload,
}

projections := &ProjectionMock{
projections := &GormProjectionMock{
GetContentEntityFunc: func(ctx context.Context, entityFactory model.EntityFactory, weosID string) (*model.ContentEntity, error) {
if ctx == nil {
t.Errorf("expected to find context but got nil")
Expand Down Expand Up @@ -281,7 +281,7 @@ func TestStandardControllers_CreateBatch(t *testing.T) {
},
}

projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return nil, nil
},
Expand Down Expand Up @@ -457,7 +457,7 @@ func TestStandardControllers_Update(t *testing.T) {
mockEntity.SequenceNo = int64(1)
mockEntity.Property = mockBlog

projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return nil, nil
},
Expand Down Expand Up @@ -575,7 +575,7 @@ func TestStandardControllers_View(t *testing.T) {
}}

t.Run("Testing the generic view endpoint", func(t *testing.T) {
projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return map[string]interface{}{
"id": "1",
Expand Down Expand Up @@ -628,7 +628,7 @@ func TestStandardControllers_View(t *testing.T) {
}
})
t.Run("Testing view with entity id", func(t *testing.T) {
projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
if entityFactory == nil {
t.Errorf("expected to find entity factory got nil")
Expand Down Expand Up @@ -696,7 +696,7 @@ func TestStandardControllers_View(t *testing.T) {
}
})
t.Run("invalid entity id should return 404", func(t *testing.T) {
projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return map[string]interface{}{
"id": "1",
Expand Down Expand Up @@ -771,7 +771,7 @@ func TestStandardControllers_View(t *testing.T) {
}
})
t.Run("invalid numeric entity id should return 404", func(t *testing.T) {
projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return map[string]interface{}{
"id": "1",
Expand Down Expand Up @@ -846,7 +846,7 @@ func TestStandardControllers_View(t *testing.T) {
}
})
t.Run("view with sequence no", func(t *testing.T) {
projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return map[string]interface{}{
"id": "1",
Expand Down Expand Up @@ -926,7 +926,7 @@ func TestStandardControllers_View(t *testing.T) {
}
})
t.Run("view with invalid sequence no", func(t *testing.T) {
projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return map[string]interface{}{
"id": "1",
Expand Down Expand Up @@ -1019,7 +1019,7 @@ func TestStandardControllers_List(t *testing.T) {
array := []map[string]interface{}{}
array = append(array, mockBlog, mockBlog1)

mockProjection := &ProjectionMock{
mockProjection := &GormProjectionMock{
GetContentEntitiesFunc: func(ctx context.Context, entityFactory model.EntityFactory, page, limit int, query string, sortOptions map[string]string, filterOptions map[string]interface{}) ([]map[string]interface{}, int64, error) {
return array, 2, nil
},
Expand Down Expand Up @@ -1133,7 +1133,7 @@ func TestStandardControllers_ListFilters(t *testing.T) {
array := []map[string]interface{}{}
array = append(array, mockBlog, mockBlog1)

mockProjection := &ProjectionMock{
mockProjection := &GormProjectionMock{
GetContentEntitiesFunc: func(ctx context.Context, entityFactory model.EntityFactory, page, limit int, query string, sortOptions map[string]string, filterOptions map[string]interface{}) ([]map[string]interface{}, int64, error) {
if entityFactory == nil {
t.Errorf("no entity factory found")
Expand Down Expand Up @@ -1382,7 +1382,7 @@ func TestStandardControllers_FormUrlEncoded_Create(t *testing.T) {
Property: mockPayload,
}

projections := &ProjectionMock{
projections := &GormProjectionMock{
GetContentEntityFunc: func(ctx context.Context, entityFactory model.EntityFactory, weosID string) (*model.ContentEntity, error) {
return mockContentEntity, nil
},
Expand Down Expand Up @@ -1540,7 +1540,7 @@ func TestStandardControllers_FormData_Create(t *testing.T) {
Property: mockPayload,
}

projections := &ProjectionMock{
projections := &GormProjectionMock{
GetContentEntityFunc: func(ctx context.Context, entityFactory model.EntityFactory, weosID string) (*model.ContentEntity, error) {
return mockContentEntity, nil
},
Expand Down Expand Up @@ -1690,7 +1690,7 @@ func TestStandardControllers_DeleteEtag(t *testing.T) {
mockEntity.SequenceNo = int64(1)
mockEntity.Property = mockBlog

projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return nil, nil
},
Expand Down Expand Up @@ -1799,7 +1799,7 @@ func TestStandardControllers_DeleteID(t *testing.T) {
},
}

projection := &ProjectionMock{
projection := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return mockInterface, nil
},
Expand Down Expand Up @@ -1852,7 +1852,7 @@ func TestStandardControllers_DeleteID(t *testing.T) {
},
}

projection1 := &ProjectionMock{
projection1 := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return mockInterface1, nil
},
Expand Down Expand Up @@ -1907,7 +1907,7 @@ func TestStandardControllers_DeleteID(t *testing.T) {

err1 := fmt.Errorf("this is an error")

projection1 := &ProjectionMock{
projection1 := &GormProjectionMock{
GetByKeyFunc: func(ctxt context.Context, entityFactory model.EntityFactory, identifiers map[string]interface{}) (map[string]interface{}, error) {
return nil, err1
},
Expand Down Expand Up @@ -1948,3 +1948,61 @@ func TestStandardControllers_DeleteID(t *testing.T) {
}
})
}

func TestStandardControllers_AuthenticateMiddleware(t *testing.T) {
//instantiate api
api, err := rest.New("./fixtures/blog-security.yaml")
if err != nil {
t.Fatalf("unexpected error loading api '%s'", err)
}
err = api.Initialize(context.TODO())
if err != nil {
t.Fatalf("un expected error initializing api '%s'", err)
}
e := api.EchoInstance()

t.Run("no jwt token added when required", func(t *testing.T) {
description := "testing 1st blog description"
mockBlog := &TestBlog{Description: &description}
reqBytes, err := json.Marshal(mockBlog)
if err != nil {
t.Fatalf("error setting up request %s", err)
}
body := bytes.NewReader(reqBytes)
resp := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/blogs", body)
req.Header.Set("Content-Type", "application/json")
e.ServeHTTP(resp, req)
if resp.Result().StatusCode != http.StatusUnauthorized {
t.Errorf("expected the response code to be %d, got %d", http.StatusUnauthorized, resp.Result().StatusCode)
}
})
t.Run("security parameter array is empty", func(t *testing.T) {
resp := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/health", nil)
e.ServeHTTP(resp, req)
if resp.Result().StatusCode != http.StatusOK {
t.Errorf("expected the response code to be %d, got %d", http.StatusOK, resp.Result().StatusCode)
}
})
t.Run("jwt token added", func(t *testing.T) {
description := "testing 1st blog description"
url := "www.example.com"
title := "example"
mockBlog := &TestBlog{Title: &title, Url: &url, Description: &description}
reqBytes, err := json.Marshal(mockBlog)
if err != nil {
t.Fatalf("error setting up request %s", err)
}
token := os.Getenv("OAUTH_TEST_KEY")
body := bytes.NewReader(reqBytes)
resp := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/blogs", body)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
e.ServeHTTP(resp, req)
if resp.Result().StatusCode != http.StatusCreated {
t.Errorf("expected the response code to be %d, got %d", http.StatusCreated, resp.Result().StatusCode)
}
})
}
Loading

0 comments on commit 6bc3d91

Please sign in to comment.