Skip to content

Commit

Permalink
check token for authz based on authoritative list of group memberships
Browse files Browse the repository at this point in the history
Signed-off-by: Austen Lacy <[email protected]>
  • Loading branch information
austenLacy committed Aug 29, 2023
1 parent 4d1cf29 commit ffd6a7e
Show file tree
Hide file tree
Showing 21 changed files with 149 additions and 43 deletions.
132 changes: 125 additions & 7 deletions go/acl/shopify_jwt_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,46 @@ package acl
import (
"bytes"
"context"
b64 "encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"os"
"strconv"
"strings"

keyfunc "github.com/MicahParks/keyfunc/v2"
jwt "github.com/golang-jwt/jwt/v5"
)

const (
SHOPIFY_JWT = "shopify_jwt"
SHOPIFY_COOKIE_NAME_ENV = "JWT_COOKIE_NAME"
SHOPIFY_JWKS_URL_ENV = "JWKS_URL"
SHOPIFY_JWT = "shopify_jwt"
SHOPIFY_JWT_HEADER_ENV = "SHOPIFY_JWT_HEADER"
SHOPIFY_JWKS_URL_ENV = "SHOPIFY_JWKS_URL"
SHOPIFY_USER_ID_HEADER_ENV = "SHOPIFY_USER_ID_HEADER"
SHOPIFY_AUTHZ_URL_ENV = "SHOPIFY_AUTHZ_URL"
SHOPIFY_AUTHZ_GROUPS_ENV = "SHOPIFY_AUTHZ_GROUPS"
SHOPIFY_AUTHZ_USERNAME_ENV = "SHOPIFY_AUTHZ_USERNAME"
SHOPIFY_AUTHZ_PASSWORD_ENV = "SHOPIFY_AUTHZ_PASSWORD"
)

var errDenyShopifyJwt = errors.New("not allowed: shopify_jwt security_policy enforced")

type membership struct {
Group string
Member bool
}

type shopifyAuthzData struct {
Memberships []membership
}

type shopifyAuthzResponse struct {
Data shopifyAuthzData
}

func jwksRequestFactory(ctx context.Context, url string) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, bytes.NewReader(nil))

Expand Down Expand Up @@ -75,6 +97,76 @@ func validateJWT(tokenString string, jwksURL string) (bool, error) {
return false, nil
}

func buildAuth(username string, password string) string {
return fmt.Sprintf("Basic %s", b64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", username, password))))
}

func authorizeUser(userId int) (bool, error) {
url := os.Getenv(SHOPIFY_AUTHZ_URL_ENV)

var groups string

for _, group := range strings.Split(os.Getenv(SHOPIFY_AUTHZ_GROUPS_ENV), ",") {
groups += fmt.Sprintf("\"%s\",", group)
}

query := map[string]string{
"query": fmt.Sprintf(`
{
memberships(
userEmployeeId: %d
groups: [%s]
) {
member
group
}
}
`, userId, groups),
}

queryJson, err := json.Marshal(query)

if err != nil {
return false, err
}

req, err := http.NewRequest("POST", url, bytes.NewBuffer(queryJson))

if err != nil {
return false, err
}

req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", buildAuth(os.Getenv(SHOPIFY_AUTHZ_USERNAME_ENV), os.Getenv(SHOPIFY_AUTHZ_PASSWORD_ENV)))

client := &http.Client{}
res, err := client.Do(req)

if err != nil {
return false, err
}

defer res.Body.Close()

if res.StatusCode != http.StatusOK {
return false, fmt.Errorf("failed to authorize user. status: %s", res.Status)
}

var jsonRes shopifyAuthzResponse
err = json.NewDecoder(res.Body).Decode(&jsonRes)
if err != nil {
return false, err
}

for _, membership := range jsonRes.Data.Memberships {
if membership.Member {
return true, nil
}
}

return false, fmt.Errorf("user is not a member of any authorized groups")
}

// CheckAccessActor disallows actor access not verified by shopifyJwt
func (shopifyJwt) CheckAccessActor(actor, role string) error {
switch role {
Expand All @@ -89,17 +181,43 @@ func (shopifyJwt) CheckAccessActor(actor, role string) error {
func (shopifyJwt) CheckAccessHTTP(req *http.Request, role string) error {
switch role {
case SHOPIFY_JWT:
jwtCookie, err := req.Cookie(os.Getenv("SHOPIFY_COOKIE_NAME_ENV"))
jwtToken := req.Header.Get(os.Getenv(SHOPIFY_JWT_HEADER_ENV))

if len(jwtToken) < 1 {
log.Println("failed to get jwt token from header")
return errDenyShopifyJwt
}

_, err := validateJWT(jwtToken, os.Getenv(SHOPIFY_JWKS_URL_ENV))

if err != nil {
log.Printf("failed to get jwt token from cookie: %s", err)
log.Printf("invalid JWT token provided: %s", err)
return err
}

_, err = validateJWT(jwtCookie.Value, os.Getenv(SHOPIFY_JWKS_URL_ENV))
userId := req.Header.Get(os.Getenv(SHOPIFY_USER_ID_HEADER_ENV))

if len(userId) < 1 {
log.Println("failed to get user id from header")
return errDenyShopifyJwt
}

userIdInt, err := strconv.Atoi(userId)

if err != nil {
log.Printf("invalid JWT token provided: %s", err)
log.Printf("failed to convert user id to int: %s", err)
return err
}

authorized, err := authorizeUser(userIdInt)

if err != nil {
log.Printf("failed to authorize user ID: %s, %v", userId, err)
return err
}

if !authorized {
log.Printf("user ID %s is not authorized", userId)
return err
}

Expand Down
2 changes: 1 addition & 1 deletion go/streamlog/streamlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func (logger *StreamLogger) Name() string {
// It is safe to register multiple URLs for the same StreamLogger.
func (logger *StreamLogger) ServeLogs(url string, logf LogFormatter) {
http.HandleFunc(url, func(w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
}
Expand Down
6 changes: 3 additions & 3 deletions go/vt/servenv/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func (sp *statusPage) addStatusSection(banner string, f func() string) {
}

func (sp *statusPage) statusHandler(w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
}
Expand Down Expand Up @@ -250,7 +250,7 @@ func (sp *statusPage) reparse(sections []section) (*template.Template, error) {
// Toggle the block profile rate to/from 100%, unless specific rate is passed in
func registerDebugBlockProfileRate() {
http.HandleFunc("/debug/blockprofilerate", func(w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
}
Expand Down Expand Up @@ -280,7 +280,7 @@ func registerDebugBlockProfileRate() {
// Toggle the mutex profiling fraction to/from 100%, unless specific fraction is passed in
func registerDebugMutexProfileFraction() {
http.HandleFunc("/debug/mutexprofilefraction", func(w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtctld/debug_health.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
// RegisterDebugHealthHandler register a debug health http endpoint for a vtcld server
func RegisterDebugHealthHandler(ts *topo.Server) {
http.HandleFunc("/debug/health", func(w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.MONITORING); err != nil {
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
}
Expand Down
4 changes: 0 additions & 4 deletions go/vt/vtgate/debugenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,6 @@ type envValue struct {
}

func debugEnvHandler(vtg *VTGate, w http.ResponseWriter, r *http.Request) {
// if err := acl.CheckAccessHTTP(r, acl.ADMIN); err != nil {
// acl.SendError(w, err)
// return
// }
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,7 @@ func (e *Executor) debugCacheEntries() (items []cacheItem) {

// ServeHTTP shows the current plans in the query cache.
func (e *Executor) ServeHTTP(response http.ResponseWriter, request *http.Request) {
if err := acl.CheckAccessHTTP(request, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(request, acl.SHOPIFY_JWT); err != nil {
acl.SendError(response, err)
return
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/querylogz.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ var (
// querylogzHandler serves a human readable snapshot of the
// current query log.
func querylogzHandler(ch chan any, w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/queryz.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (s *queryzSorter) Swap(i, j int) { s.rows[i], s.rows[j] = s.rows[j], s
func (s *queryzSorter) Less(i, j int) bool { return s.less(s.rows[i], s.rows[j]) }

func queryzHandler(e *Executor, w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/vtgate.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ func (vtg *VTGate) registerDebugEnvHandler() {

func (vtg *VTGate) registerDebugHealthHandler() {
http.HandleFunc("/debug/health", func(w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.MONITORING); err != nil {
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
}
Expand Down
4 changes: 0 additions & 4 deletions go/vt/vttablet/tabletserver/debugenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,6 @@ func addVar[T any](vars []envValue, name string, f func() T) []envValue {
}

func debugEnvHandler(tsv *TabletServer, w http.ResponseWriter, r *http.Request) {
// if err := acl.CheckAccessHTTP(r, acl.ADMIN); err != nil {
// acl.SendError(w, err)
// return
// }
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vttablet/tabletserver/livequeryz.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ var (
)

func livequeryzHandler(queryLists []*QueryList, w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
}
Expand Down
14 changes: 5 additions & 9 deletions go/vt/vttablet/tabletserver/query_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ type perQueryStats struct {
}

func (qe *QueryEngine) handleHTTPQueryPlans(response http.ResponseWriter, request *http.Request) {
if err := acl.CheckAccessHTTP(request, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(request, acl.SHOPIFY_JWT); err != nil {
acl.SendError(response, err)
return
}
Expand All @@ -537,7 +537,7 @@ func (qe *QueryEngine) handleHTTPQueryPlans(response http.ResponseWriter, reques
}

func (qe *QueryEngine) handleHTTPQueryStats(response http.ResponseWriter, request *http.Request) {
if err := acl.CheckAccessHTTP(request, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(request, acl.SHOPIFY_JWT); err != nil {
acl.SendError(response, err)
return
}
Expand All @@ -563,7 +563,7 @@ func (qe *QueryEngine) handleHTTPQueryStats(response http.ResponseWriter, reques
}

func (qe *QueryEngine) handleHTTPQueryRules(response http.ResponseWriter, request *http.Request) {
if err := acl.CheckAccessHTTP(request, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(request, acl.SHOPIFY_JWT); err != nil {
acl.SendError(response, err)
return
}
Expand All @@ -579,7 +579,7 @@ func (qe *QueryEngine) handleHTTPQueryRules(response http.ResponseWriter, reques
}

func (qe *QueryEngine) handleHTTPAclJSON(response http.ResponseWriter, request *http.Request) {
if err := acl.CheckAccessHTTP(request, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(request, acl.SHOPIFY_JWT); err != nil {
acl.SendError(response, err)
return
}
Expand All @@ -601,11 +601,7 @@ func (qe *QueryEngine) handleHTTPAclJSON(response http.ResponseWriter, request *

// ServeHTTP lists the most recent, cached queries and their count.
func (qe *QueryEngine) handleHTTPConsolidations(response http.ResponseWriter, request *http.Request) {
if err := acl.CheckAccessHTTP(request, acl.DEBUGGING); err != nil {
acl.SendError(response, err)
return
}
if err := acl.CheckAccessHTTP(request, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(request, acl.SHOPIFY_JWT); err != nil {
acl.SendError(response, err)
return
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vttablet/tabletserver/querylogz.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func init() {
// querylogzHandler serves a human readable snapshot of the
// current query log.
func querylogzHandler(ch chan any, w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vttablet/tabletserver/queryz.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (s *queryzSorter) Swap(i, j int) { s.rows[i], s.rows[j] = s.rows[j], s
func (s *queryzSorter) Less(i, j int) bool { return s.less(s.rows[i], s.rows[j]) }

func queryzHandler(qe *QueryEngine, w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vttablet/tabletserver/schema/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ func (se *Engine) GetConnection(ctx context.Context) (*connpool.DBConn, error) {
}

func (se *Engine) handleDebugSchema(response http.ResponseWriter, request *http.Request) {
if err := acl.CheckAccessHTTP(request, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(request, acl.SHOPIFY_JWT); err != nil {
acl.SendError(response, err)
return
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vttablet/tabletserver/schema/schemaz.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (sorter *schemazSorter) Less(i, j int) bool {
}

func schemazHandler(tables map[string]*Table, w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vttablet/tabletserver/tabletserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1713,7 +1713,7 @@ func (tsv *TabletServer) healthzHandler(w http.ResponseWriter, r *http.Request)
// Returns ok if a query can go all the way to database and back
func (tsv *TabletServer) registerDebugHealthHandler() {
tsv.exporter.HandleFunc("/debug/health", func(w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.MONITORING); err != nil {
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vttablet/tabletserver/twopcz.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ var (
)

func twopczHandler(txe *TxExecutor, w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(r, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vttablet/tabletserver/txlogz.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func init() {
// timeout: the txlogz will keep dumping transactions until timeout
// limit: txlogz will keep dumping transactions until it hits the limit
func txlogzHandler(w http.ResponseWriter, req *http.Request) {
if err := acl.CheckAccessHTTP(req, acl.DEBUGGING); err != nil {
if err := acl.CheckAccessHTTP(req, acl.SHOPIFY_JWT); err != nil {
acl.SendError(w, err)
return
}
Expand Down
Loading

0 comments on commit ffd6a7e

Please sign in to comment.