Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for session resumption when using client certificates #447

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ type Config struct {
// SessionStore is the container to store session for resumption.
SessionStore SessionStore

// PeerCertDisablesSessionResumption prevents session resumption if a client certificate
// is provided, regardless of the state of the session store
PeerCertDisablesSessionResumption bool

// List of application protocols the peer supports, for ALPN
SupportedProtocols []string
}
Expand Down
41 changes: 21 additions & 20 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,26 +155,27 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
}

hsCfg := &handshakeConfig{
localPSKCallback: config.PSK,
localPSKIdentityHint: config.PSKIdentityHint,
localCipherSuites: cipherSuites,
localSignatureSchemes: signatureSchemes,
extendedMasterSecret: config.ExtendedMasterSecret,
localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
serverName: serverName,
supportedProtocols: config.SupportedProtocols,
clientAuth: config.ClientAuth,
localCertificates: config.Certificates,
insecureSkipVerify: config.InsecureSkipVerify,
verifyPeerCertificate: config.VerifyPeerCertificate,
rootCAs: config.RootCAs,
clientCAs: config.ClientCAs,
customCipherSuites: config.CustomCipherSuites,
retransmitInterval: workerInterval,
log: logger,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
sessionStore: config.SessionStore,
localPSKCallback: config.PSK,
localPSKIdentityHint: config.PSKIdentityHint,
localCipherSuites: cipherSuites,
localSignatureSchemes: signatureSchemes,
extendedMasterSecret: config.ExtendedMasterSecret,
localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
serverName: serverName,
supportedProtocols: config.SupportedProtocols,
clientAuth: config.ClientAuth,
localCertificates: config.Certificates,
insecureSkipVerify: config.InsecureSkipVerify,
verifyPeerCertificate: config.VerifyPeerCertificate,
rootCAs: config.RootCAs,
clientCAs: config.ClientCAs,
customCipherSuites: config.CustomCipherSuites,
retransmitInterval: workerInterval,
log: logger,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
sessionStore: config.SessionStore,
peerCertDisablesSessionResumption: config.PeerCertDisablesSessionResumption,
}

// rfc5246#section-7.4.3
Expand Down
275 changes: 272 additions & 3 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -2528,6 +2529,269 @@ func TestSessionResume(t *testing.T) {
}
_ = res.c.Close()
})

t.Run("resumed client cert", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

type result struct {
c *Conn
err error
}
clientRes := make(chan result, 1)

commonCert, _ := selfsign.GenerateSelfSignedWithDNS("example.com")

certPool := x509.NewCertPool()
certPool.AppendCertsFromPEM(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: commonCert.Certificate[0]}))

ss := &memSessStore{}

id, _ := hex.DecodeString("9b9fc92255634d9fb109febed42166717bb8ded8c738ba71bc7f2a0d9dae0306")
secret, _ := hex.DecodeString("2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7")

s := Session{ID: id, Secret: secret}

ca, cb := dpipe.Pipe()

_ = ss.Set(id, s)
_ = ss.Set([]byte(ca.RemoteAddr().String()+"_"+commonCert.Leaf.Subject.CommonName), s)

go func() {
config := &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
ServerName: commonCert.Leaf.Subject.CommonName,
SessionStore: ss,
RootCAs: certPool,
Certificates: nil, // Client shouldn't need to send a cert to resume a session
MTU: 100,
}
c, err := ClientWithContext(ctx, ca, config)
clientRes <- result{c, err}
}()

config := &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
SessionStore: ss,
ClientCAs: certPool,
Certificates: []tls.Certificate{commonCert},
MTU: 100,
ClientAuth: RequireAndVerifyClientCert,
}
server, err := testServer(ctx, cb, config, true)
if err != nil {
t.Fatalf("TestSessionResume: Server failed(%v)", err)
}

actualSessionID := server.ConnectionState().SessionID
actualMasterSecret := server.ConnectionState().masterSecret
if !bytes.Equal(actualSessionID, id) {
t.Errorf("TestSessionResumetion: SessionID Mismatch: expected(%v) actual(%v)", id, actualSessionID)
}
if !bytes.Equal(actualMasterSecret, secret) {
t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", secret, actualMasterSecret)
}

defer func() {
_ = server.Close()
}()

res := <-clientRes
if res.err != nil {
t.Fatal(res.err)
}
_ = res.c.Close()
})

t.Run("new session client cert", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

type result struct {
c *Conn
err error
}
clientRes := make(chan result, 1)

commonCert, _ := selfsign.GenerateSelfSignedWithDNS("example.com")

certPool := x509.NewCertPool()
certPool.AppendCertsFromPEM(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: commonCert.Certificate[0]}))

s1 := &memSessStore{}
s2 := &memSessStore{}

ca, cb := dpipe.Pipe()
go func() {
config := &Config{
ServerName: commonCert.Leaf.Subject.CommonName,
SessionStore: s1,
RootCAs: certPool,
Certificates: []tls.Certificate{commonCert},
}
c, err := ClientWithContext(ctx, ca, config)
clientRes <- result{c, err}
}()

config := &Config{
SessionStore: s2,
ClientAuth: RequireAndVerifyClientCert,
ClientCAs: certPool,
Certificates: []tls.Certificate{commonCert},
}
server, err := testServer(ctx, cb, config, false)
if err != nil {
t.Fatalf("TestSessionResumetion: Server failed(%v)", err)
}

actualSessionID := server.ConnectionState().SessionID
actualMasterSecret := server.ConnectionState().masterSecret
ss, _ := s2.Get(actualSessionID)
if !bytes.Equal(actualMasterSecret, ss.Secret) {
t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected/actual:\n(%v)\n(%v)", ss.Secret, actualMasterSecret)
}

if ss.Expiry.Unix() != commonCert.Leaf.NotAfter.Unix() {
t.Errorf("TestSessionResumption: expected server session store to contain certificate expiry")
}

defer func() {
_ = server.Close()
}()

res := <-clientRes
if res.err != nil {
t.Fatal(res.err)
}
cs, _ := s1.Get([]byte(ca.RemoteAddr().String() + "_" + commonCert.Leaf.Subject.CommonName))
if !bytes.Equal(actualMasterSecret, cs.Secret) {
t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected/actual\n(%v)\n(%v)", cs.Secret, actualMasterSecret)
}

if cs.Expiry.Unix() != commonCert.Leaf.NotAfter.Unix() {
t.Errorf("TestSessionResumption: expected client session store to contain certificate expiry")
}

_ = res.c.Close()
})

t.Run("expire client cert session", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

type result struct {
c *Conn
err error
}
clientRes := make(chan result, 1)

commonCert, _ := selfsign.GenerateSelfSignedWithDNS("example.com")

certPool := x509.NewCertPool()
certPool.AppendCertsFromPEM(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: commonCert.Certificate[0]}))

ss := &memSessStore{}

id, _ := hex.DecodeString("9b9fc92255634d9fb109febed42166717bb8ded8c738ba71bc7f2a0d9dae0306")
secret, _ := hex.DecodeString("2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7")

oldClientSessionTime := time.Now().Add(time.Hour)
clientSession := Session{
ID: id,
Secret: secret,
Expiry: oldClientSessionTime,
}

expiredServerSession := Session{
ID: id,
Secret: secret,
Expiry: time.Now().Add(-time.Hour), // server should treat this as expired session and force a new cert verification
}

ca, cb := dpipe.Pipe()

_ = ss.Set(id, expiredServerSession)
_ = ss.Set([]byte(ca.RemoteAddr().String()+"_"+commonCert.Leaf.Subject.CommonName), clientSession)

go func() {
config := &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
ServerName: commonCert.Leaf.Subject.CommonName,
SessionStore: ss,
RootCAs: certPool,
Certificates: []tls.Certificate{commonCert},
MTU: 1200, // MTU must be able to fit cert chain in one packet
}
c, err := ClientWithContext(ctx, ca, config)
clientRes <- result{c, err}
}()

config := &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
SessionStore: ss,
ClientCAs: certPool,
Certificates: []tls.Certificate{commonCert},
MTU: 1200,
ClientAuth: RequireAndVerifyClientCert,
}
server, err := testServer(ctx, cb, config, false)
if err != nil {
t.Fatalf("TestSessionResume: Server failed(%v)", err)
}

actualSessionID := server.ConnectionState().SessionID
actualMasterSecret := server.ConnectionState().masterSecret
if bytes.Equal(actualSessionID, id) {
t.Errorf("TestSessionResumption: SessionID Mismatch: expected new session ID(%v) actual(%v)", id, actualSessionID)
}

if bytes.Equal(actualMasterSecret, secret) {
t.Errorf("TestSessionResumption: masterSecret Mismatch: expected new master secret (%v) actual(%v)", secret, actualMasterSecret)
}

defer func() {
_ = server.Close()
}()

res := <-clientRes
if res.err != nil {
t.Fatal(res.err)
}

_, ok := ss.Map.Load(hex.EncodeToString(expiredServerSession.ID))
if ok {
t.Errorf("expected server to have deleted session")
}

cSess, ok := ss.Map.Load(hex.EncodeToString([]byte(ca.RemoteAddr().String() + "_" + commonCert.Leaf.Subject.CommonName)))
if !ok {
t.Errorf("expected client store to have cached new session ID")
}

newClientSession := cSess.(Session)
if bytes.Equal(secret, newClientSession.Secret) {
t.Errorf("expected : expected client session store to contain new master secret (%v) actual(%v)", secret, newClientSession.Secret)
}

if newClientSession.Expiry.Unix() == oldClientSessionTime.Unix() {
t.Errorf("expected new client session to have updated")
}

if newClientSession.Expiry.Unix() != commonCert.Leaf.NotAfter.Unix() {
t.Errorf("expected new client session to expire with client cert")
}

sSess, ok := ss.Map.Load(hex.EncodeToString(newClientSession.ID))
if !ok {
t.Errorf("expected server store to have cached new client session ID")
}
newServerSession := sSess.(Session)

if !bytes.Equal(newServerSession.Secret, newClientSession.Secret) {
t.Errorf("expected : expected session store to contain new shared secret (%v) actual(%v)", newServerSession.Secret, newClientSession.Secret)
}
_ = res.c.Close()
})
}

type memSessStore struct {
Expand All @@ -2549,12 +2813,17 @@ func (ms *memSessStore) Get(key []byte) (Session, error) {
return Session{}, nil
}

s, ok := v.(Session)
if !ok {
session := v.(Session)
if session.Expiry.IsZero() {
return session, nil
}

if time.Now().After(session.Expiry) {
_ = ms.Del(key)
return Session{}, nil
}
return session, nil

return s, nil
}

func (ms *memSessStore) Del(key []byte) error {
Expand Down
Loading