Skip to content

Commit

Permalink
feat: authentication using CSP
Browse files Browse the repository at this point in the history
Signed-off-by: Ernst Riemer <[email protected]>
Signed-off-by: Luke Winikates <[email protected]>
  • Loading branch information
ernst-riemer authored and LukeWinikates committed Aug 24, 2023
1 parent fb1a1cf commit d88633b
Show file tree
Hide file tree
Showing 38 changed files with 1,339 additions and 581 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea
vendor/
.DS_Store
.DS_Store
.envrc
51 changes: 51 additions & 0 deletions internal/auth/csp/api_token_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package csp

import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
)

type APITokenClient struct {
BaseURL string
APIToken string
}

func (c *APITokenClient) GetAccessToken() (*AuthorizeResponse, error) {
var oauthPath = "/csp/gateway/am/api/auth/api-tokens/authorize"
client := &http.Client{}

requestBody := url.Values{"grant_type": {"api_token"}, "refresh_token": {c.APIToken}}.Encode()
req, err := http.NewRequest("POST", c.BaseURL+oauthPath, strings.NewReader(requestBody))

if err != nil {
return nil, err
}

req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Accept", "application/json")

resp, err := client.Do(req)

if err != nil {
return nil, err
}

defer resp.Body.Close()

if resp.StatusCode > 399 {
return nil, fmt.Errorf("authentication failed: %d", resp.StatusCode)
}

body, err := io.ReadAll(resp.Body)
var cspResponse AuthorizeResponse
err = json.Unmarshal(body, &cspResponse)

if err != nil {
return nil, err
}
return &cspResponse, nil
}
57 changes: 57 additions & 0 deletions internal/auth/csp/client_credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package csp

import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
)

type ClientCredentialsClient struct {
BaseURL string
ClientID string
ClientSecret string
}

func (c *ClientCredentialsClient) authHeaderValue() string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(c.ClientID+":"+c.ClientSecret))
}

func (c *ClientCredentialsClient) GetAccessToken() (*AuthorizeResponse, error) {
var oauthPath = "/csp/gateway/am/api/auth/authorize"
client := &http.Client{}

requestBody := url.Values{"grant_type": {"client_credentials"}}.Encode()
req, err := http.NewRequest("POST", c.BaseURL+oauthPath, strings.NewReader(requestBody))

if err != nil {
return nil, err
}

req.Header.Add("Authorization", c.authHeaderValue())
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

resp, err := client.Do(req)

if err != nil {
return nil, err
}

defer resp.Body.Close()

if resp.StatusCode > 399 {
return nil, fmt.Errorf("authentication failed: %d", resp.StatusCode)
}

body, err := io.ReadAll(resp.Body)
var cspResponse AuthorizeResponse
err = json.Unmarshal(body, &cspResponse)

if err != nil {
return nil, err
}
return &cspResponse, nil
}
88 changes: 88 additions & 0 deletions internal/auth/csp/fake_csp_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package csp

import (
"encoding/base64"
"encoding/json"
"net/http"
"strings"
)

func FakeCSPHandler(apiTokens []string) http.Handler {
basicAuthCredentials := "Basic " + base64.StdEncoding.EncodeToString([]byte("a:b"))
firstRun := true

mux := http.NewServeMux()
mux.HandleFunc("/csp/gateway/am/api/auth/authorize", func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.Header.Get("Authorization"), basicAuthCredentials) {
var sup AuthorizeResponse

if firstRun {
sup = AuthorizeResponse{
ExpiresIn: 1,
AccessToken: "abc",
Scope: "aoa:directDataIngestion",
}
firstRun = false
} else {
sup = AuthorizeResponse{
ExpiresIn: 1,
AccessToken: "def",
Scope: "aoa:directDataIngestion",
}
}

w.WriteHeader(http.StatusOK)
marshal, _ := json.Marshal(sup)
w.Write(marshal)
return
}
w.WriteHeader(http.StatusUnauthorized)
})
mux.HandleFunc("/csp/gateway/am/api/auth/api-tokens/authorize", func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusNotAcceptable)
return
}
if !(r.Form.Has("grant_type") && r.Form.Get("grant_type") == "api_token") {
w.WriteHeader(http.StatusUnauthorized)
return
}
if !(r.Form.Has("refresh_token")) {
w.WriteHeader(http.StatusUnauthorized)
return
}
var tokenMatch = false
for _, token := range apiTokens {
if r.Form.Get("refresh_token") == token {
tokenMatch = true
break
}
}

if !(tokenMatch) {
w.WriteHeader(http.StatusUnauthorized)
return
}
var sup AuthorizeResponse
if firstRun {
sup = AuthorizeResponse{
ExpiresIn: 1,
AccessToken: "abc",
Scope: "aoa:directDataIngestion",
}
firstRun = false
} else {
sup = AuthorizeResponse{
ExpiresIn: 1,
AccessToken: "def",
Scope: "aoa:directDataIngestion",
}
}

w.WriteHeader(http.StatusOK)
marshal, _ := json.Marshal(sup)
w.Write(marshal)
return
})
return mux
}
17 changes: 17 additions & 0 deletions internal/auth/csp/scope.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package csp

import "strings"

func HasDirectIngestScope(scope string) bool {
if len(scope) == 0 {
return false
}

for _, s := range strings.Split(scope, " ") {
if strings.Contains(s, "aoa:directDataIngestion") || strings.Contains(s, "aoa/*") || strings.Contains(s, "aoa:*") {
return true
}
}

return false
}
17 changes: 17 additions & 0 deletions internal/auth/csp/scope_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package csp

import (
"github.com/stretchr/testify/assert"
"testing"
)

func TestDirectIngestScopes(t *testing.T) {
assert.False(t, HasDirectIngestScope(""))
assert.False(t, HasDirectIngestScope("no direct ingest scopes"))

var scopeString = "external/51d98d2c-3ae1-11ee-be56-0242ac120002/*/aoa:directDataIngestion external/51d98d2c-3ae1-11ee-be56-0242ac120002/aoa:directDataIngestion csp:org_member"

assert.True(t, HasDirectIngestScope(scopeString))
assert.True(t, HasDirectIngestScope("some aoa:*"))
assert.True(t, HasDirectIngestScope("some aoa/*"))
}
14 changes: 14 additions & 0 deletions internal/auth/csp/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package csp

type Client interface {
GetAccessToken() (*AuthorizeResponse, error)
}

type AuthorizeResponse struct {
IdToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
129 changes: 129 additions & 0 deletions internal/auth/csp_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package auth

import (
"fmt"
"github.com/wavefronthq/wavefront-sdk-go/internal/auth/csp"
"log"
"net/http"
"sync"
"time"
)

type tokenResult struct {
accessToken string
err error
}

type CSPService struct {
client csp.Client
mutex sync.Mutex
tokenResult *tokenResult
refreshTicker *time.Ticker
done chan bool
defaultRefreshInterval time.Duration
}

// NewCSPServerToServerService returns a Service instance that gets access tokens via CSP client credentials
func NewCSPServerToServerService(CSPBaseUrl string, ClientId string, ClientSecret string) Service {
return newService(&csp.ClientCredentialsClient{
BaseURL: CSPBaseUrl,
ClientID: ClientId,
ClientSecret: ClientSecret,
})
}

func NewCSPTokenService(CSPBaseUrl, apiToken string) Service {
return newService(&csp.APITokenClient{
BaseURL: CSPBaseUrl,
APIToken: apiToken,
})
}

func newService(client csp.Client) Service {
return &CSPService{
client: client,
defaultRefreshInterval: 60 * time.Second,
}
}

func (s *CSPService) IsDirect() bool {
return true
}

func (s *CSPService) Authorize(r *http.Request) error {
s.mutex.Lock()
defer s.mutex.Unlock()

if s.tokenResult == nil {
s.RefreshAccessToken()
}

if s.tokenResult.err != nil {
return &Err{
error: s.tokenResult.err,
}
}

r.Header.Set("Authorization", "Bearer "+s.tokenResult.accessToken)
return nil
}

func (s *CSPService) RefreshAccessToken() {
cspResponse, err := s.client.GetAccessToken()

if err != nil {
s.tokenResult = &tokenResult{
accessToken: "",
err: err,
}
s.scheduleNextTokenRefresh(s.defaultRefreshInterval)
return
}

if !csp.HasDirectIngestScope(cspResponse.Scope) {
s.tokenResult = &tokenResult{
accessToken: "",
err: fmt.Errorf("response did not include required scope: 'aoa:directDataIngestion'"),
}
s.scheduleNextTokenRefresh(s.defaultRefreshInterval)
return
}

s.scheduleNextTokenRefresh(time.Duration(cspResponse.ExpiresIn) * time.Second)
s.tokenResult = &tokenResult{
accessToken: cspResponse.AccessToken,
err: nil,
}
}

func (s *CSPService) scheduleNextTokenRefresh(expiresIn time.Duration) {
tickerInterval := calculateNewTickerInterval(expiresIn, s.defaultRefreshInterval)

if s.refreshTicker == nil {
s.refreshTicker = time.NewTicker(tickerInterval)
s.done = make(chan bool)
go func() {
for {
select {
case <-s.done:
return
case tick := <-s.refreshTicker.C:
s.mutex.Lock()
log.Printf("Re-fetching CSP credentials at: %v \n", tick)
s.RefreshAccessToken()
s.mutex.Unlock()
}
}
}()
} else {
s.refreshTicker.Reset(tickerInterval)
}
}

func (s *CSPService) Close() {
log.Println("Shutting down the CSPService")
if s.refreshTicker == nil {
return
}
s.done <- true
}
Loading

0 comments on commit d88633b

Please sign in to comment.