From 16832c7579082ba2fb37a00a6339d006e57315aa Mon Sep 17 00:00:00 2001 From: Mike Bland Date: Wed, 30 Sep 2015 18:42:21 -0400 Subject: [PATCH] Allow for other signature algorithms than sha1 --- oauthproxy.go | 21 +++++++------ oauthproxy_test.go | 8 ++--- options.go | 61 +++++++++++++++++++++++++++++++++---- options_test.go | 44 ++++++++++++++++++-------- signature/signature.go | 45 ++++++++++++++++++++++----- signature/signature_test.go | 27 ++++++++++++---- 6 files changed, 161 insertions(+), 45 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index e4e0bd949..dc0a0b855 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -52,16 +52,17 @@ type OauthProxy struct { } type UpstreamProxy struct { - upstream string - handler http.Handler - signatureKey string + upstream string + handler http.Handler + signature *SignatureData } func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("GAP-Upstream-Address", u.upstream) - if u.signatureKey != "" { + if u.signature != nil { r.Header.Set("GAP-Auth", w.Header().Get("GAP-Auth")) - sig := signature.RequestSignature(r, u.signatureKey) + sig := signature.RequestSignature(r, u.signature.hash, + u.signature.key) r.Header.Set("GAP-Signature", sig) } u.handler.ServeHTTP(w, r) @@ -107,19 +108,19 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { } else { setProxyDirector(proxy) } - signatureKey := opts.upstreamKeys[u.Host] - if signatureKey == "" { - signatureKey = opts.SignatureKey + signatureData := opts.upstreamKeys[u.Host] + if signatureData == nil { + signatureData = opts.signatureData } serveMux.Handle(path, - &UpstreamProxy{u.Host, proxy, signatureKey}) + &UpstreamProxy{u.Host, proxy, signatureData}) case "file": if u.Fragment != "" { path = u.Fragment } log.Printf("mapping path %q => file system %q", path, u.Path) proxy := NewFileServer(path, u.Path) - serveMux.Handle(path, &UpstreamProxy{path, proxy, ""}) + serveMux.Handle(path, &UpstreamProxy{path, proxy, nil}) default: panic(fmt.Sprintf("unknown upstream protocol %s", u.Scheme)) } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index b629920ad..f0e8ba6fb 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -644,7 +644,7 @@ func TestNoRequestSignature(t *testing.T) { func TestDefaultRequestSignature(t *testing.T) { st := NewSignatureTest() defer st.Close() - st.opts.SignatureKey = "foobar" + st.opts.SignatureKey = "sha1:foobar" st.MakeRequestWithExpectedKey("GET", "", "foobar") assert.Equal(t, 200, st.rw.Code) assert.Equal(t, st.rw.Body.String(), "signatures match") @@ -653,7 +653,7 @@ func TestDefaultRequestSignature(t *testing.T) { func TestDefaultRequestSignaturePostRequest(t *testing.T) { st := NewSignatureTest() defer st.Close() - st.opts.SignatureKey = "foobar" + st.opts.SignatureKey = "sha1:foobar" payload := `{ "hello": "world!" }` st.MakeRequestWithExpectedKey("POST", payload, "foobar") assert.Equal(t, 200, st.rw.Code) @@ -663,9 +663,9 @@ func TestDefaultRequestSignaturePostRequest(t *testing.T) { func TestUpstreamSpecificRequestSignature(t *testing.T) { st := NewSignatureTest() defer st.Close() - st.opts.SignatureKey = "foobar" + st.opts.SignatureKey = "sha1:foobar" st.opts.UpstreamKeys = append(st.opts.UpstreamKeys, - st.upstream_host+"=bazquux") + st.upstream_host+"=sha1:bazquux") st.MakeRequestWithExpectedKey("GET", "", "bazquux") assert.Equal(t, 200, st.rw.Code) assert.Equal(t, st.rw.Body.String(), "signatures match") diff --git a/options.go b/options.go index be1b2dcaa..e2b58d54f 100644 --- a/options.go +++ b/options.go @@ -1,6 +1,7 @@ package main import ( + "crypto" "fmt" "net/url" "os" @@ -9,6 +10,7 @@ import ( "time" "github.com/bitly/oauth2_proxy/providers" + "github.com/bitly/oauth2_proxy/signature" ) // Configuration Options that can be set by Command Line Flag, or Config File @@ -68,7 +70,13 @@ type Options struct { proxyUrls []*url.URL CompiledRegex []*regexp.Regexp provider providers.Provider - upstreamKeys map[string]string + signatureData *SignatureData + upstreamKeys map[string]*SignatureData +} + +type SignatureData struct { + hash crypto.Hash + key string } func NewOptions() *Options { @@ -87,7 +95,7 @@ func NewOptions() *Options { PassHostHeader: true, ApprovalPrompt: "force", RequestLogging: true, - upstreamKeys: make(map[string]string), + upstreamKeys: make(map[string]*SignatureData), } } @@ -219,6 +227,12 @@ func parseProviderInfo(o *Options, msgs []string) []string { } func parseSignatureKeys(o *Options, msgs []string) []string { + var specErr string + o.signatureData, specErr = parseSignatureSpec(o.SignatureKey) + if specErr != "" { + msgs = append(msgs, specErr+": "+o.SignatureKey) + } + numKeys := len(o.UpstreamKeys) if numKeys == 0 { return msgs @@ -231,16 +245,28 @@ func parseSignatureKeys(o *Options, msgs []string) []string { invalidSpecs := make([]string, 0) invalidHosts := make([]string, 0) - o.upstreamKeys = make(map[string]string, numKeys) + duplicateHosts := make([]string, 0) for i := 0; i != numKeys; i++ { keySpec := o.UpstreamKeys[i] - if hostKey := strings.Split(keySpec, "="); len(hostKey) != 2 { + hostKey := strings.Split(keySpec, "=") + if len(hostKey) != 2 { invalidSpecs = append(invalidSpecs, keySpec) - } else if hostSet[hostKey[0]] == false { + continue + } + + host, spec := hostKey[0], hostKey[1] + if hostSet[host] == false { invalidHosts = append(invalidHosts, keySpec) + } else if o.upstreamKeys[host] != nil { + duplicateHosts = append(duplicateHosts, keySpec) + } + + if sigData, specErr := parseSignatureSpec(spec); specErr != "" { + invalidSpecs = append(invalidSpecs, specErr+": "+ + keySpec) } else { - o.upstreamKeys[hostKey[0]] = hostKey[1] + o.upstreamKeys[host] = sigData } } @@ -253,5 +279,28 @@ func parseSignatureKeys(o *Options, msgs []string) []string { "any defined upstreams:\n "+ strings.Join(invalidHosts, "\n ")) } + if len(duplicateHosts) != 0 { + msgs = append(msgs, + "specs that duplicate other host specs:\n "+ + strings.Join(duplicateHosts, "\n ")) + } return msgs } + +func parseSignatureSpec(data string) (result *SignatureData, err string) { + if data == "" { + return nil, "" + } + + components := strings.Split(data, ":") + if len(components) != 2 { + return nil, "invalid signature hash:key spec" + } + + algorithm, secretKey := components[0], components[1] + if hash, err := signature.HashAlgorithm(algorithm); err != nil { + return nil, "unsupported signature hash algorithm" + } else { + return &SignatureData{hash, secretKey}, "" + } +} diff --git a/options_test.go b/options_test.go index 7bdb2e342..cdaec9eec 100644 --- a/options_test.go +++ b/options_test.go @@ -1,6 +1,7 @@ package main import ( + "crypto" "net/url" "strings" "testing" @@ -176,15 +177,18 @@ func TestValidateUpstreamSignatureKeys(t *testing.T) { "https://bar.com/bar", "https://baz.com", } - o.SignatureKey = "default secret" - upstreamKeys := "foo.com:8000=secret0,bar.com=secret1,baz.com=secret2" - o.UpstreamKeys = strings.Split(upstreamKeys, ",") + o.SignatureKey = "sha1:default secret" + o.UpstreamKeys = []string{ + "foo.com:8000=sha1:secret0", + "bar.com=sha1:secret1", + "baz.com=sha1:secret2", + } assert.Equal(t, nil, o.Validate()) - assert.Equal(t, o.upstreamKeys, map[string]string{ - "foo.com:8000": "secret0", - "bar.com": "secret1", - "baz.com": "secret2", + assert.Equal(t, o.upstreamKeys, map[string]*SignatureData{ + "foo.com:8000": &SignatureData{crypto.SHA1, "secret0"}, + "bar.com": &SignatureData{crypto.SHA1, "secret1"}, + "baz.com": &SignatureData{crypto.SHA1, "secret2"}, }) } @@ -195,17 +199,33 @@ func TestValidateUpstreamSignatureKeysWithErrors(t *testing.T) { o.Upstreams = []string{ "https://bar.com/bar", "https://baz.com", + "https://quux.com", + "https://xyzzy.com", + } + o.SignatureKey = "unsupported:default secret" + o.UpstreamKeys = []string{ + "foo.com:8000=sha1:secret0", + "bar.com=secret1", + "baz.com:sha1:secret2", + "quux.com=sha1:secret3", + "quux.com=sha1:secret4", + "xyzzy.com=unsupported:secret5", } - o.SignatureKey = "default secret" - upstreamKeys := "foo.com:8000=secret0,bar.com=secret1,baz.com:secret2" - o.UpstreamKeys = strings.Split(upstreamKeys, ",") err := o.Validate() assert.NotEqual(t, nil, err) expected := errorMsg([]string{ + "unsupported signature hash algorithm: " + + "unsupported:default secret", "invalid upstream key specs:", - " baz.com:secret2", + " invalid signature hash:key spec: bar.com=secret1", + " baz.com:sha1:secret2", + " unsupported signature hash algorithm: " + + "xyzzy.com=unsupported:secret5", "specs with hosts that do not match any defined upstreams:", - " foo.com:8000=secret0"}) + " foo.com:8000=sha1:secret0", + "specs that duplicate other host specs:", + " quux.com=sha1:secret4", + }) assert.Equal(t, err.Error(), expected) } diff --git a/signature/signature.go b/signature/signature.go index 1af8a2664..b3a156444 100644 --- a/signature/signature.go +++ b/signature/signature.go @@ -1,14 +1,28 @@ package signature import ( + "crypto" "crypto/hmac" - "crypto/sha1" "encoding/base64" "net/http" "strconv" "strings" ) +var supportedAlgorithms map[string]crypto.Hash +var algorithmName map[crypto.Hash]string + +func init() { + supportedAlgorithms = map[string]crypto.Hash{ + "sha1": crypto.SHA1, + } + + algorithmName = make(map[crypto.Hash]string) + for name, algorithm := range supportedAlgorithms { + algorithmName[algorithm] = name + } +} + // The string to sign is based on the following request elements, inspired by: // http://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html func StringToSign(req *http.Request) string { @@ -28,8 +42,24 @@ func StringToSign(req *http.Request) string { }, "\n") } -func RequestSignature(req *http.Request, secretKey string) string { - h := hmac.New(sha1.New, []byte(secretKey)) +type unsupportedAlgorithm struct { + algorithm string +} + +func (e unsupportedAlgorithm) Error() string { + return "unsupported request signature algorithm: " + e.algorithm +} + +func HashAlgorithm(algorithm string) (result crypto.Hash, err error) { + if result = supportedAlgorithms[algorithm]; result == crypto.Hash(0) { + err = unsupportedAlgorithm{algorithm} + } + return +} + +func RequestSignature(req *http.Request, hashAlgorithm crypto.Hash, + secretKey string) string { + h := hmac.New(hashAlgorithm.New, []byte(secretKey)) h.Write([]byte(StringToSign(req))) if req.ContentLength != -1 && req.Body != nil { @@ -40,7 +70,8 @@ func RequestSignature(req *http.Request, secretKey string) string { var sig []byte sig = h.Sum(sig) - return "sha1 " + base64.URLEncoding.EncodeToString(sig) + return algorithmName[hashAlgorithm] + " " + + base64.URLEncoding.EncodeToString(sig) } type ValidationResult int @@ -71,13 +102,13 @@ func ValidateRequest(request *http.Request, key string) ( return } - algorithm := components[0] - if algorithm != "sha1" { + algorithm, err := HashAlgorithm(components[0]) + if err != nil { result = UNSUPPORTED_ALGORITHM return } - computedSignature = RequestSignature(request, key) + computedSignature = RequestSignature(request, algorithm, key) if hmac.Equal([]byte(headerSignature), []byte(computedSignature)) { result = MATCH } else { diff --git a/signature/signature_test.go b/signature/signature_test.go index 9c62c90c8..1a2ec58c6 100644 --- a/signature/signature_test.go +++ b/signature/signature_test.go @@ -2,6 +2,7 @@ package signature import ( "bufio" + "crypto" "net/http" "strconv" "strings" @@ -10,6 +11,20 @@ import ( "github.com/bmizerany/assert" ) +func TestSupportedHashAlgorithm(t *testing.T) { + algorithm, err := HashAlgorithm("sha1") + assert.Equal(t, err, nil) + assert.Equal(t, algorithm, crypto.SHA1) + assert.Equal(t, algorithm.Available(), true) +} + +func TestUnsupportedHashAlgorithm(t *testing.T) { + algorithm, err := HashAlgorithm("unsupported") + assert.NotEqual(t, err, nil) + assert.Equal(t, algorithm, crypto.Hash(0)) + assert.Equal(t, algorithm.Available(), false) +} + func newTestRequest(request ...string) (req *http.Request) { reqBuf := bufio.NewReader( strings.NewReader(strings.Join(request, "\n"))) @@ -51,7 +66,7 @@ func TestRequestSignaturePost(t *testing.T) { "mbland", "/foo/bar", }, "\n")) - assert.Equal(t, RequestSignature(req, "foobar"), + assert.Equal(t, RequestSignature(req, crypto.SHA1, "foobar"), "sha1 722UbRYfC6MnjtIxqEJMDPrW2mk=") } @@ -78,7 +93,7 @@ func TestRequestSignatureGet(t *testing.T) { "mbland", "/foo/bar", }, "\n")) - assert.Equal(t, RequestSignature(req, "foobar"), + assert.Equal(t, RequestSignature(req, crypto.SHA1, "foobar"), "sha1 JBQJcmSTteQyHZXFUA9glis9BIk=") } @@ -113,7 +128,7 @@ func TestValidateRequestInvalidFormat(t *testing.T) { func TestValidateRequestUnsupportedAlgorithm(t *testing.T) { req := newGetRequest() - validSignature := RequestSignature(req, "foobar") + validSignature := RequestSignature(req, crypto.SHA1, "foobar") components := strings.Split(validSignature, " ") signatureWithUnsupportedAlgorithm := "unsupported " + components[1] req.Header.Set("GAP-Signature", signatureWithUnsupportedAlgorithm) @@ -125,7 +140,7 @@ func TestValidateRequestUnsupportedAlgorithm(t *testing.T) { func TestValidateRequestMatch(t *testing.T) { req := newGetRequest() - expectedSignature := RequestSignature(req, "foobar") + expectedSignature := RequestSignature(req, crypto.SHA1, "foobar") req.Header.Set("GAP-Signature", expectedSignature) result, header, computed := ValidateRequest(req, "foobar") assert.Equal(t, result, MATCH) @@ -135,8 +150,8 @@ func TestValidateRequestMatch(t *testing.T) { func TestValidateRequestMismatch(t *testing.T) { req := newGetRequest() - foobarSignature := RequestSignature(req, "foobar") - barbazSignature := RequestSignature(req, "barbaz") + foobarSignature := RequestSignature(req, crypto.SHA1, "foobar") + barbazSignature := RequestSignature(req, crypto.SHA1, "barbaz") req.Header.Set("GAP-Signature", foobarSignature) result, header, computed := ValidateRequest(req, "barbaz") assert.Equal(t, result, MISMATCH)