Skip to content

Commit

Permalink
Use external library for SCRAM authentication
Browse files Browse the repository at this point in the history
Removes custom SCRAM implementation replacing it with a wrapper for the
existing xdg-go/scram library. Changes the saslNewScram interface to
take a new type *scram.Method argument replacing the  func () hash.Hash
type. Adds a scram.NewMethod function that validates and returns a
supported method.
  • Loading branch information
mhill-anynines committed Oct 8, 2018
1 parent cee0d26 commit 223c75e
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 242 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ install:
- go get gopkg.in/yaml.v2
- go get gopkg.in/tomb.v2
- go get github.com/golang/lint
- go get github.com/xdg-go/scram

before_script:
- golint ./... | grep -v 'ID' | cat
Expand Down
17 changes: 7 additions & 10 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,9 @@ package mgo

import (
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"hash"
"sync"

"github.com/globalsign/mgo/bson"
Expand Down Expand Up @@ -276,11 +273,11 @@ func (socket *mongoSocket) loginPlain(cred Credential) error {
func (socket *mongoSocket) loginSASL(cred Credential) error {
var sasl saslStepper
var err error
if cred.Mechanism == "SCRAM-SHA-1" {
// SCRAM is handled without external libraries.
sasl = saslNewScram(sha1.New, cred)
} else if cred.Mechanism == "SCRAM-SHA-256" {
sasl = saslNewScram(sha256.New, cred)
if cred.Mechanism == "SCRAM-SHA-1" || cred.Mechanism == "SCRAM-SHA-256" {
// SCRAM is handled with github.com/xdg-go/scram.
var method *scram.Method
method, err = scram.NewMethod(cred.Mechanism)
sasl = saslNewScram(method, cred)
} else if len(cred.ServiceHost) > 0 {
sasl, err = saslNew(cred, cred.ServiceHost)
} else {
Expand Down Expand Up @@ -357,10 +354,10 @@ func (socket *mongoSocket) loginSASL(cred Credential) error {
return nil
}

func saslNewScram(hash func() hash.Hash, cred Credential) *saslScram {
func saslNewScram(method *scram.Method, cred Credential) *saslScram {
credsum := md5.New()
credsum.Write([]byte(cred.Username + ":mongo:" + cred.Password))
client := scram.NewClient(hash, cred.Username, hex.EncodeToString(credsum.Sum(nil)))
client := scram.NewClient(method, cred.Username, hex.EncodeToString(credsum.Sum(nil)))
return &saslScram{cred: cred, client: client}
}

Expand Down
240 changes: 59 additions & 181 deletions internal/scram/scram.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,23 @@ package scram

import (
"bytes"
"crypto/hmac"
"crypto/rand"
"encoding/base64"
"fmt"
"hash"
"strconv"
"strings"
"errors"

xdg "github.com/xdg-go/scram"
)

// Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc).
// Client adapts a SCRAM client (SCRAM-SHA-1, SCRAM-SHA-256).
//
// A Client may be used within a SASL conversation with logic resembling:
//
// mechanism, err := scram.NewMethod("SCRAM-SHA-256")
//
// if err != nil {
// log.Fatal(err)
// }
//
// var in []byte
// var client = scram.NewClient(sha1.New, user, pass)
// var client = scram.NewClient(, user, pass)
// for client.Step(in) {
// out := client.Out()
// // send out to server
Expand All @@ -57,34 +59,62 @@ import (
// }
//
type Client struct {
newHash func() hash.Hash

user string
pass string
step int
out bytes.Buffer
err error
conv *xdg.ClientConversation
}

// Method defines the variant of SCRAM to use
type Method struct {
method string
}

const (
// ScramSha1 use the SCRAM-SHA-1 variant
ScramSha1 = "SCRAM-SHA-1"

// ScramSha256 use the SCRAM-SHA-256 variant
ScramSha256 = "SCRAM-SHA-256"
)

clientNonce []byte
serverNonce []byte
saltedPass []byte
authMsg bytes.Buffer
// NewMethod returns a Method if the input method string is supported
// otherwise it returns an error.
// Supported method strings:
// - "SCRAM-SHA-1"
// - "SCRAM-SHA-256"
func NewMethod(methodString string) (*Method, error) {
switch methodString {
case ScramSha1, ScramSha256:
return &Method{method: methodString}, nil
default:
return nil, errors.New("invalid SCRAM mechanism")
}
}

// NewClient returns a new SCRAM-* client with the provided hash algorithm.
// NewClient returns a new SCRAM client with the provided hash algorithm.
//
// For SCRAM-SHA-1, for example, use:
//
// client := scram.NewClient(sha1.New, user, pass)
// method, _ := scram.NewMethod("SCRAM-SHA-1")
//
// client := scram.NewClient(method, user, pass)
//
func NewClient(newHash func() hash.Hash, user, pass string) *Client {
func NewClient(method *Method, user, pass string) *Client {
var client *xdg.Client
var err error

switch method.method {
case ScramSha1:
client, err = xdg.SHA1.NewClient(user, pass, "")
case ScramSha256:
client, err = xdg.SHA256.NewClient(user, pass, "")
}

c := &Client{
newHash: newHash,
user: user,
pass: pass,
conv: client.NewConversation(),
err: err,
}
c.out.Grow(256)
c.authMsg.Grow(256)
return c
}

Expand All @@ -101,166 +131,14 @@ func (c *Client) Err() error {
return c.err
}

// SetNonce sets the client nonce to the provided value.
// If not set, the nonce is generated automatically out of crypto/rand on the first step.
func (c *Client) SetNonce(nonce []byte) {
c.clientNonce = nonce
}

var escaper = strings.NewReplacer("=", "=3D", ",", "=2C")

// Step processes the incoming data from the server and makes the
// next round of data for the server available via Client.Out.
// Step returns false if there are no errors and more data is
// still expected.
func (c *Client) Step(in []byte) bool {
var resp string
c.out.Reset()
if c.step > 2 || c.err != nil {
return false
}
c.step++
switch c.step {
case 1:
c.err = c.step1(in)
case 2:
c.err = c.step2(in)
case 3:
c.err = c.step3(in)
}
return c.step > 2 || c.err != nil
}

func (c *Client) step1(in []byte) error {
if len(c.clientNonce) == 0 {
const nonceLen = 6
buf := make([]byte, nonceLen+b64.EncodedLen(nonceLen))
if _, err := rand.Read(buf[:nonceLen]); err != nil {
return fmt.Errorf("cannot read random SCRAM-SHA-1 nonce from operating system: %v", err)
}
c.clientNonce = buf[nonceLen:]
b64.Encode(c.clientNonce, buf[:nonceLen])
}
c.authMsg.WriteString("n=")
escaper.WriteString(&c.authMsg, c.user)
c.authMsg.WriteString(",r=")
c.authMsg.Write(c.clientNonce)

c.out.WriteString("n,,")
c.out.Write(c.authMsg.Bytes())
return nil
}

var b64 = base64.StdEncoding

func (c *Client) step2(in []byte) error {
c.authMsg.WriteByte(',')
c.authMsg.Write(in)

fields := bytes.Split(in, []byte(","))
if len(fields) != 3 {
return fmt.Errorf("expected 3 fields in first SCRAM-SHA-1 server message, got %d: %q", len(fields), in)
}
if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 {
return fmt.Errorf("server sent an invalid SCRAM-SHA-1 nonce: %q", fields[0])
}
if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 {
return fmt.Errorf("server sent an invalid SCRAM-SHA-1 salt: %q", fields[1])
}
if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 {
return fmt.Errorf("server sent an invalid SCRAM-SHA-1 iteration count: %q", fields[2])
}

c.serverNonce = fields[0][2:]
if !bytes.HasPrefix(c.serverNonce, c.clientNonce) {
return fmt.Errorf("server SCRAM-SHA-1 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce)
}

salt := make([]byte, b64.DecodedLen(len(fields[1][2:])))
n, err := b64.Decode(salt, fields[1][2:])
if err != nil {
return fmt.Errorf("cannot decode SCRAM-SHA-1 salt sent by server: %q", fields[1])
}
salt = salt[:n]
iterCount, err := strconv.Atoi(string(fields[2][2:]))
if err != nil {
return fmt.Errorf("server sent an invalid SCRAM-SHA-1 iteration count: %q", fields[2])
}
c.saltPassword(salt, iterCount)

c.authMsg.WriteString(",c=biws,r=")
c.authMsg.Write(c.serverNonce)

c.out.WriteString("c=biws,r=")
c.out.Write(c.serverNonce)
c.out.WriteString(",p=")
c.out.Write(c.clientProof())
return nil
}

func (c *Client) step3(in []byte) error {
var isv, ise bool
var fields = bytes.Split(in, []byte(","))
if len(fields) == 1 {
isv = bytes.HasPrefix(fields[0], []byte("v="))
ise = bytes.HasPrefix(fields[0], []byte("e="))
}
if ise {
return fmt.Errorf("SCRAM-SHA-1 authentication error: %s", fields[0][2:])
} else if !isv {
return fmt.Errorf("unsupported SCRAM-SHA-1 final message from server: %q", in)
}
if !bytes.Equal(c.serverSignature(), fields[0][2:]) {
return fmt.Errorf("cannot authenticate SCRAM-SHA-1 server signature: %q", fields[0][2:])
}
return nil
}

func (c *Client) saltPassword(salt []byte, iterCount int) {
mac := hmac.New(c.newHash, []byte(c.pass))
mac.Write(salt)
mac.Write([]byte{0, 0, 0, 1})
ui := mac.Sum(nil)
hi := make([]byte, len(ui))
copy(hi, ui)
for i := 1; i < iterCount; i++ {
mac.Reset()
mac.Write(ui)
mac.Sum(ui[:0])
for j, b := range ui {
hi[j] ^= b
}
}
c.saltedPass = hi
}

func (c *Client) clientProof() []byte {
mac := hmac.New(c.newHash, c.saltedPass)
mac.Write([]byte("Client Key"))
clientKey := mac.Sum(nil)
hash := c.newHash()
hash.Write(clientKey)
storedKey := hash.Sum(nil)
mac = hmac.New(c.newHash, storedKey)
mac.Write(c.authMsg.Bytes())
clientProof := mac.Sum(nil)
for i, b := range clientKey {
clientProof[i] ^= b
}
clientProof64 := make([]byte, b64.EncodedLen(len(clientProof)))
b64.Encode(clientProof64, clientProof)
return clientProof64
}

func (c *Client) serverSignature() []byte {
mac := hmac.New(c.newHash, c.saltedPass)
mac.Write([]byte("Server Key"))
serverKey := mac.Sum(nil)

mac = hmac.New(c.newHash, serverKey)
mac.Write(c.authMsg.Bytes())
serverSignature := mac.Sum(nil)

encoded := make([]byte, b64.EncodedLen(len(serverSignature)))
b64.Encode(encoded, serverSignature)
return encoded
resp, c.err = c.conv.Step(string(in))
_, c.err = c.out.Write([]byte(resp))
return c.conv.Done() || c.err != nil
}
Loading

0 comments on commit 223c75e

Please sign in to comment.