diff --git a/recws.go b/recws.go index 62ba6c4..60eeb91 100644 --- a/recws.go +++ b/recws.go @@ -3,6 +3,7 @@ package recws import ( + "crypto/tls" "errors" "log" "math/rand" @@ -36,6 +37,8 @@ type RecConn struct { // Proxy specifies the proxy function for the dialer // defaults to ProxyFromEnvironment Proxy func(*http.Request) (*url.URL, error) + // Client TLS config to use on reconnect + TLSClientConfig *tls.Config // SubscribeHandler fires after the connection successfully establish. SubscribeHandler func() error // KeepAliveTimeout is an interval for sending ping/pong messages @@ -243,13 +246,14 @@ func (rc *RecConn) setDefaultProxy() { } } -func (rc *RecConn) setDefaultDialer(handshakeTimeout time.Duration) { +func (rc *RecConn) setDefaultDialer(tlsClientConfig *tls.Config, handshakeTimeout time.Duration) { rc.mu.Lock() defer rc.mu.Unlock() rc.dialer = &websocket.Dialer{ HandshakeTimeout: handshakeTimeout, Proxy: rc.Proxy, + TLSClientConfig: tlsClientConfig, } } @@ -260,6 +264,20 @@ func (rc *RecConn) getHandshakeTimeout() time.Duration { return rc.HandshakeTimeout } +func (rc *RecConn) getTLSClientConfig() *tls.Config { + rc.mu.RLock() + defer rc.mu.RUnlock() + + return rc.TLSClientConfig +} + +func (rc *RecConn) SetTLSClientConfig(tlsClientConfig *tls.Config) { + rc.mu.Lock() + defer rc.mu.Unlock() + + rc.TLSClientConfig = tlsClientConfig +} + // Dial creates a new client connection. // The URL url specifies the host and request URI. Use requestHeader to specify // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies @@ -280,7 +298,7 @@ func (rc *RecConn) Dial(urlStr string, reqHeader http.Header) { rc.setDefaultRecIntvlFactor() rc.setDefaultHandshakeTimeout() rc.setDefaultProxy() - rc.setDefaultDialer(rc.getHandshakeTimeout()) + rc.setDefaultDialer(rc.getTLSClientConfig(), rc.getHandshakeTimeout()) // Connect go rc.connect()