diff --git a/config.go b/config.go index 7f68c03b1..d1b0eee67 100644 --- a/config.go +++ b/config.go @@ -128,6 +128,10 @@ type Config struct { // SessionStore is the container to store session for resumption. SessionStore SessionStore + // PeerCertDisablesSessionResumption prevents session resumption if a client certificate + // is provided, regardless of the state of the session store + PeerCertDisablesSessionResumption bool + // List of application protocols the peer supports, for ALPN SupportedProtocols []string } diff --git a/conn.go b/conn.go index aff5a95fc..6c4d692a9 100644 --- a/conn.go +++ b/conn.go @@ -155,26 +155,27 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient } hsCfg := &handshakeConfig{ - localPSKCallback: config.PSK, - localPSKIdentityHint: config.PSKIdentityHint, - localCipherSuites: cipherSuites, - localSignatureSchemes: signatureSchemes, - extendedMasterSecret: config.ExtendedMasterSecret, - localSRTPProtectionProfiles: config.SRTPProtectionProfiles, - serverName: serverName, - supportedProtocols: config.SupportedProtocols, - clientAuth: config.ClientAuth, - localCertificates: config.Certificates, - insecureSkipVerify: config.InsecureSkipVerify, - verifyPeerCertificate: config.VerifyPeerCertificate, - rootCAs: config.RootCAs, - clientCAs: config.ClientCAs, - customCipherSuites: config.CustomCipherSuites, - retransmitInterval: workerInterval, - log: logger, - initialEpoch: 0, - keyLogWriter: config.KeyLogWriter, - sessionStore: config.SessionStore, + localPSKCallback: config.PSK, + localPSKIdentityHint: config.PSKIdentityHint, + localCipherSuites: cipherSuites, + localSignatureSchemes: signatureSchemes, + extendedMasterSecret: config.ExtendedMasterSecret, + localSRTPProtectionProfiles: config.SRTPProtectionProfiles, + serverName: serverName, + supportedProtocols: config.SupportedProtocols, + clientAuth: config.ClientAuth, + localCertificates: config.Certificates, + insecureSkipVerify: config.InsecureSkipVerify, + verifyPeerCertificate: config.VerifyPeerCertificate, + rootCAs: config.RootCAs, + clientCAs: config.ClientCAs, + customCipherSuites: config.CustomCipherSuites, + retransmitInterval: workerInterval, + log: logger, + initialEpoch: 0, + keyLogWriter: config.KeyLogWriter, + sessionStore: config.SessionStore, + peerCertDisablesSessionResumption: config.PeerCertDisablesSessionResumption, } // rfc5246#section-7.4.3 diff --git a/conn_test.go b/conn_test.go index 9810971fe..385b20766 100644 --- a/conn_test.go +++ b/conn_test.go @@ -11,6 +11,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/hex" + "encoding/pem" "errors" "fmt" "io" @@ -2528,6 +2529,269 @@ func TestSessionResume(t *testing.T) { } _ = res.c.Close() }) + + t.Run("resumed client cert", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + type result struct { + c *Conn + err error + } + clientRes := make(chan result, 1) + + commonCert, _ := selfsign.GenerateSelfSignedWithDNS("example.com") + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: commonCert.Certificate[0]})) + + ss := &memSessStore{} + + id, _ := hex.DecodeString("9b9fc92255634d9fb109febed42166717bb8ded8c738ba71bc7f2a0d9dae0306") + secret, _ := hex.DecodeString("2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7") + + s := Session{ID: id, Secret: secret} + + ca, cb := dpipe.Pipe() + + _ = ss.Set(id, s) + _ = ss.Set([]byte(ca.RemoteAddr().String()+"_"+commonCert.Leaf.Subject.CommonName), s) + + go func() { + config := &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + ServerName: commonCert.Leaf.Subject.CommonName, + SessionStore: ss, + RootCAs: certPool, + Certificates: nil, // Client shouldn't need to send a cert to resume a session + MTU: 100, + } + c, err := ClientWithContext(ctx, ca, config) + clientRes <- result{c, err} + }() + + config := &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + SessionStore: ss, + ClientCAs: certPool, + Certificates: []tls.Certificate{commonCert}, + MTU: 100, + ClientAuth: RequireAndVerifyClientCert, + } + server, err := testServer(ctx, cb, config, true) + if err != nil { + t.Fatalf("TestSessionResume: Server failed(%v)", err) + } + + actualSessionID := server.ConnectionState().SessionID + actualMasterSecret := server.ConnectionState().masterSecret + if !bytes.Equal(actualSessionID, id) { + t.Errorf("TestSessionResumetion: SessionID Mismatch: expected(%v) actual(%v)", id, actualSessionID) + } + if !bytes.Equal(actualMasterSecret, secret) { + t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", secret, actualMasterSecret) + } + + defer func() { + _ = server.Close() + }() + + res := <-clientRes + if res.err != nil { + t.Fatal(res.err) + } + _ = res.c.Close() + }) + + t.Run("new session client cert", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + type result struct { + c *Conn + err error + } + clientRes := make(chan result, 1) + + commonCert, _ := selfsign.GenerateSelfSignedWithDNS("example.com") + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: commonCert.Certificate[0]})) + + s1 := &memSessStore{} + s2 := &memSessStore{} + + ca, cb := dpipe.Pipe() + go func() { + config := &Config{ + ServerName: commonCert.Leaf.Subject.CommonName, + SessionStore: s1, + RootCAs: certPool, + Certificates: []tls.Certificate{commonCert}, + } + c, err := ClientWithContext(ctx, ca, config) + clientRes <- result{c, err} + }() + + config := &Config{ + SessionStore: s2, + ClientAuth: RequireAndVerifyClientCert, + ClientCAs: certPool, + Certificates: []tls.Certificate{commonCert}, + } + server, err := testServer(ctx, cb, config, false) + if err != nil { + t.Fatalf("TestSessionResumetion: Server failed(%v)", err) + } + + actualSessionID := server.ConnectionState().SessionID + actualMasterSecret := server.ConnectionState().masterSecret + ss, _ := s2.Get(actualSessionID) + if !bytes.Equal(actualMasterSecret, ss.Secret) { + t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected/actual:\n(%v)\n(%v)", ss.Secret, actualMasterSecret) + } + + if ss.Expiry.Unix() != commonCert.Leaf.NotAfter.Unix() { + t.Errorf("TestSessionResumption: expected server session store to contain certificate expiry") + } + + defer func() { + _ = server.Close() + }() + + res := <-clientRes + if res.err != nil { + t.Fatal(res.err) + } + cs, _ := s1.Get([]byte(ca.RemoteAddr().String() + "_" + commonCert.Leaf.Subject.CommonName)) + if !bytes.Equal(actualMasterSecret, cs.Secret) { + t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected/actual\n(%v)\n(%v)", cs.Secret, actualMasterSecret) + } + + if cs.Expiry.Unix() != commonCert.Leaf.NotAfter.Unix() { + t.Errorf("TestSessionResumption: expected client session store to contain certificate expiry") + } + + _ = res.c.Close() + }) + + t.Run("expire client cert session", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + type result struct { + c *Conn + err error + } + clientRes := make(chan result, 1) + + commonCert, _ := selfsign.GenerateSelfSignedWithDNS("example.com") + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: commonCert.Certificate[0]})) + + ss := &memSessStore{} + + id, _ := hex.DecodeString("9b9fc92255634d9fb109febed42166717bb8ded8c738ba71bc7f2a0d9dae0306") + secret, _ := hex.DecodeString("2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7") + + oldClientSessionTime := time.Now().Add(time.Hour) + clientSession := Session{ + ID: id, + Secret: secret, + Expiry: oldClientSessionTime, + } + + expiredServerSession := Session{ + ID: id, + Secret: secret, + Expiry: time.Now().Add(-time.Hour), // server should treat this as expired session and force a new cert verification + } + + ca, cb := dpipe.Pipe() + + _ = ss.Set(id, expiredServerSession) + _ = ss.Set([]byte(ca.RemoteAddr().String()+"_"+commonCert.Leaf.Subject.CommonName), clientSession) + + go func() { + config := &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + ServerName: commonCert.Leaf.Subject.CommonName, + SessionStore: ss, + RootCAs: certPool, + Certificates: []tls.Certificate{commonCert}, + MTU: 1200, // MTU must be able to fit cert chain in one packet + } + c, err := ClientWithContext(ctx, ca, config) + clientRes <- result{c, err} + }() + + config := &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + SessionStore: ss, + ClientCAs: certPool, + Certificates: []tls.Certificate{commonCert}, + MTU: 1200, + ClientAuth: RequireAndVerifyClientCert, + } + server, err := testServer(ctx, cb, config, false) + if err != nil { + t.Fatalf("TestSessionResume: Server failed(%v)", err) + } + + actualSessionID := server.ConnectionState().SessionID + actualMasterSecret := server.ConnectionState().masterSecret + if bytes.Equal(actualSessionID, id) { + t.Errorf("TestSessionResumption: SessionID Mismatch: expected new session ID(%v) actual(%v)", id, actualSessionID) + } + + if bytes.Equal(actualMasterSecret, secret) { + t.Errorf("TestSessionResumption: masterSecret Mismatch: expected new master secret (%v) actual(%v)", secret, actualMasterSecret) + } + + defer func() { + _ = server.Close() + }() + + res := <-clientRes + if res.err != nil { + t.Fatal(res.err) + } + + _, ok := ss.Map.Load(hex.EncodeToString(expiredServerSession.ID)) + if ok { + t.Errorf("expected server to have deleted session") + } + + cSess, ok := ss.Map.Load(hex.EncodeToString([]byte(ca.RemoteAddr().String() + "_" + commonCert.Leaf.Subject.CommonName))) + if !ok { + t.Errorf("expected client store to have cached new session ID") + } + + newClientSession := cSess.(Session) + if bytes.Equal(secret, newClientSession.Secret) { + t.Errorf("expected : expected client session store to contain new master secret (%v) actual(%v)", secret, newClientSession.Secret) + } + + if newClientSession.Expiry.Unix() == oldClientSessionTime.Unix() { + t.Errorf("expected new client session to have updated") + } + + if newClientSession.Expiry.Unix() != commonCert.Leaf.NotAfter.Unix() { + t.Errorf("expected new client session to expire with client cert") + } + + sSess, ok := ss.Map.Load(hex.EncodeToString(newClientSession.ID)) + if !ok { + t.Errorf("expected server store to have cached new client session ID") + } + newServerSession := sSess.(Session) + + if !bytes.Equal(newServerSession.Secret, newClientSession.Secret) { + t.Errorf("expected : expected session store to contain new shared secret (%v) actual(%v)", newServerSession.Secret, newClientSession.Secret) + } + _ = res.c.Close() + }) } type memSessStore struct { @@ -2549,12 +2813,17 @@ func (ms *memSessStore) Get(key []byte) (Session, error) { return Session{}, nil } - s, ok := v.(Session) - if !ok { + session := v.(Session) + if session.Expiry.IsZero() { + return session, nil + } + + if time.Now().After(session.Expiry) { + _ = ms.Del(key) return Session{}, nil } + return session, nil - return s, nil } func (ms *memSessStore) Del(key []byte) error { diff --git a/crypto.go b/crypto.go index 768ee470e..7cf3065a5 100644 --- a/crypto.go +++ b/crypto.go @@ -184,28 +184,30 @@ func loadCerts(rawCertificates [][]byte) ([]*x509.Certificate, error) { return certs, nil } -func verifyClientCert(rawCertificates [][]byte, roots *x509.CertPool) (chains [][]*x509.Certificate, err error) { +func verifyClientCert(rawCertificates [][]byte, roots *x509.CertPool) (chains [][]*x509.Certificate, expiry time.Time, err error) { certificate, err := loadCerts(rawCertificates) if err != nil { - return nil, err + return nil, time.Time{}, err } intermediateCAPool := x509.NewCertPool() for _, cert := range certificate[1:] { intermediateCAPool.AddCert(cert) } + opts := x509.VerifyOptions{ Roots: roots, CurrentTime: time.Now(), Intermediates: intermediateCAPool, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, } - return certificate[0].Verify(opts) + chains, err = certificate[0].Verify(opts) + return chains, certificate[0].NotAfter, err } -func verifyServerCert(rawCertificates [][]byte, roots *x509.CertPool, serverName string) (chains [][]*x509.Certificate, err error) { +func verifyServerCert(rawCertificates [][]byte, roots *x509.CertPool, serverName string) (chains [][]*x509.Certificate, expiry time.Time, err error) { certificate, err := loadCerts(rawCertificates) if err != nil { - return nil, err + return nil, time.Time{}, err } intermediateCAPool := x509.NewCertPool() for _, cert := range certificate[1:] { @@ -217,5 +219,6 @@ func verifyServerCert(rawCertificates [][]byte, roots *x509.CertPool, serverName DNSName: serverName, Intermediates: intermediateCAPool, } - return certificate[0].Verify(opts) + chains, err = certificate[0].Verify(opts) + return chains, certificate[0].NotAfter, err } diff --git a/flight4handler.go b/flight4handler.go index 9e90f1993..0401d4af6 100644 --- a/flight4handler.go +++ b/flight4handler.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "crypto/x509" + "time" "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" "github.com/pion/dtls/v2/pkg/crypto/elliptic" @@ -40,7 +41,9 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh // And we have to check whether this certificate expired, revoked or changed. // // https://curl.se/docs/CVE-2016-5419.html - state.SessionID = nil + if cfg.peerCertDisablesSessionResumption { + state.SessionID = nil + } } if h, hasCertVerify := msgs[handshake.TypeCertificateVerify].(*handshake.MessageCertificateVerify); hasCertVerify { @@ -75,13 +78,15 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } var chains [][]*x509.Certificate + var expiry time.Time var err error var verified bool if cfg.clientAuth >= VerifyClientCertIfGiven { - if chains, err = verifyClientCert(state.PeerCertificates, cfg.clientCAs); err != nil { + if chains, expiry, err = verifyClientCert(state.PeerCertificates, cfg.clientCAs); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } verified = true + state.VerifiedCertExpiry = expiry } if cfg.verifyPeerCertificate != nil { if err := cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil { @@ -148,6 +153,7 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh s := Session{ ID: state.SessionID, Secret: state.masterSecret, + Expiry: state.VerifiedCertExpiry, } cfg.log.Tracef("[handshake] save new session: %x", s.ID) if err := cfg.sessionStore.Set(state.SessionID, s); err != nil { diff --git a/flight5handler.go b/flight5handler.go index 86435a532..36937f9b8 100644 --- a/flight5handler.go +++ b/flight5handler.go @@ -5,6 +5,7 @@ import ( "context" "crypto" "crypto/x509" + "time" "github.com/pion/dtls/v2/pkg/crypto/prf" "github.com/pion/dtls/v2/pkg/crypto/signaturehash" @@ -52,6 +53,7 @@ func flight5Parse(ctx context.Context, c flightConn, state *State, cache *handsh s := Session{ ID: state.SessionID, Secret: state.masterSecret, + Expiry: state.VerifiedCertExpiry, } cfg.log.Tracef("[handshake] save new session: %x", s.ID) if err := cfg.sessionStore.Set(c.sessionKey(), s); err != nil { @@ -317,10 +319,12 @@ func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeCon return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } var chains [][]*x509.Certificate + var expiry time.Time if !cfg.insecureSkipVerify { - if chains, err = verifyServerCert(state.PeerCertificates, cfg.rootCAs, cfg.serverName); err != nil { + if chains, expiry, err = verifyServerCert(state.PeerCertificates, cfg.rootCAs, cfg.serverName); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } + state.VerifiedCertExpiry = expiry } if cfg.verifyPeerCertificate != nil { if err = cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil { diff --git a/handshaker.go b/handshaker.go index 1c7b9ffa2..68a6339a8 100644 --- a/handshaker.go +++ b/handshaker.go @@ -88,24 +88,25 @@ type handshakeFSM struct { } type handshakeConfig struct { - localPSKCallback PSKCallback - localPSKIdentityHint []byte - localCipherSuites []CipherSuite // Available CipherSuites - localSignatureSchemes []signaturehash.Algorithm // Available signature schemes - extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension - localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support - serverName string - supportedProtocols []string - clientAuth ClientAuthType // If we are a client should we request a client certificate - localCertificates []tls.Certificate - nameToCertificate map[string]*tls.Certificate - insecureSkipVerify bool - verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error - sessionStore SessionStore - rootCAs *x509.CertPool - clientCAs *x509.CertPool - retransmitInterval time.Duration - customCipherSuites func() []CipherSuite + localPSKCallback PSKCallback + localPSKIdentityHint []byte + localCipherSuites []CipherSuite // Available CipherSuites + localSignatureSchemes []signaturehash.Algorithm // Available signature schemes + extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension + localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support + serverName string + supportedProtocols []string + clientAuth ClientAuthType // If we are a client should we request a client certificate + localCertificates []tls.Certificate + nameToCertificate map[string]*tls.Certificate + insecureSkipVerify bool + verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + sessionStore SessionStore + peerCertDisablesSessionResumption bool + rootCAs *x509.CertPool + clientCAs *x509.CertPool + retransmitInterval time.Duration + customCipherSuites func() []CipherSuite onFlightState func(flightVal, handshakeState) log logging.LeveledLogger diff --git a/pkg/crypto/selfsign/selfsign.go b/pkg/crypto/selfsign/selfsign.go index 581e0a296..dceb90e3f 100644 --- a/pkg/crypto/selfsign/selfsign.go +++ b/pkg/crypto/selfsign/selfsign.go @@ -10,6 +10,7 @@ import ( "crypto/rsa" "crypto/tls" "crypto/x509" + "crypto/x509/pkix" "encoding/hex" "errors" "math/big" @@ -85,6 +86,9 @@ func WithDNS(key crypto.PrivateKey, cn string, sans ...string) (tls.Certificate, Version: 2, IsCA: true, DNSNames: names, + Subject: pkix.Name{ + CommonName: cn, + }, } raw, err := x509.CreateCertificate(rand.Reader, &template, &template, pubKey, key) diff --git a/session.go b/session.go index f52120cd8..f8befb708 100644 --- a/session.go +++ b/session.go @@ -1,11 +1,15 @@ package dtls +import "time" + // Session store data needed in resumption type Session struct { // ID store session id ID []byte // Secret store session master secret Secret []byte + // Optional expiry based on a verified certificate NotAfter date, otherwise empty + Expiry time.Time } // SessionStore defines methods needed for session resumption. diff --git a/state.go b/state.go index 90afb89ca..ca8d02f6c 100644 --- a/state.go +++ b/state.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/gob" "sync/atomic" + "time" "github.com/pion/dtls/v2/pkg/crypto/elliptic" "github.com/pion/dtls/v2/pkg/crypto/prf" @@ -23,6 +24,7 @@ type State struct { PeerCertificates [][]byte IdentityHint []byte SessionID []byte + VerifiedCertExpiry time.Time isClient bool