From a17073a5c9d31bcc7507e07d7be4301ae5f2d250 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 20 Sep 2024 19:50:18 +0200 Subject: [PATCH 1/3] Simplify peer server TLS --- lib/proxy/peer/client_test.go | 12 +++-- lib/proxy/peer/credentials.go | 39 --------------- lib/proxy/peer/helpers_test.go | 36 +++++++------- lib/proxy/peer/server.go | 87 ++++++++++++++++++++++++---------- lib/service/service.go | 43 ++++++++--------- lib/utils/tls.go | 7 +++ 6 files changed, 114 insertions(+), 110 deletions(-) diff --git a/lib/proxy/peer/client_test.go b/lib/proxy/peer/client_test.go index 05f87c6610c3..d4e1cd7c5df5 100644 --- a/lib/proxy/peer/client_test.go +++ b/lib/proxy/peer/client_test.go @@ -138,10 +138,11 @@ func TestCAChange(t *testing.T) { currentServerCA := newAtomicCA(serverCA) client := setupClient(t, clientCA, currentServerCA, types.RoleProxy) - server, _ := setupServer(t, "s1", serverCA, clientCA, types.RoleProxy) + server, ts := setupServer(t, "s1", serverCA, clientCA, types.RoleProxy) + t.Cleanup(func() { server.Close() }) // dial server and send a test data frame - conn, err := client.connect("s1", server.config.Listener.Addr().String()) + conn, err := client.connect("s1", ts.GetPeerAddr()) require.NoError(t, err) require.NotNil(t, conn) ctx, cancel := context.WithCancel(context.Background()) @@ -153,11 +154,12 @@ func TestCAChange(t *testing.T) { // rotate server ca require.NoError(t, server.Close()) newServerCA := newSelfSignedCA(t) - server, _ = setupServer(t, "s1", newServerCA, clientCA, types.RoleProxy) + server2, ts := setupServer(t, "s1", newServerCA, clientCA, types.RoleProxy) + t.Cleanup(func() { server2.Close() }) // new connection should fail because client tls config still references old // RootCAs. - conn, err = client.connect("s1", server.config.Listener.Addr().String()) + conn, err = client.connect("s1", ts.GetPeerAddr()) require.NoError(t, err) require.NotNil(t, conn) stream, err = clientapi.NewProxyServiceClient(conn.ClientConn).DialNode(ctx) @@ -168,7 +170,7 @@ func TestCAChange(t *testing.T) { // RootCAs. currentServerCA.Store(newServerCA) - conn, err = client.connect("s1", server.config.Listener.Addr().String()) + conn, err = client.connect("s1", ts.GetPeerAddr()) require.NoError(t, err) require.NotNil(t, conn) stream, err = clientapi.NewProxyServiceClient(conn.ClientConn).DialNode(ctx) diff --git a/lib/proxy/peer/credentials.go b/lib/proxy/peer/credentials.go index e6dfc29160cb..4767102704eb 100644 --- a/lib/proxy/peer/credentials.go +++ b/lib/proxy/peer/credentials.go @@ -30,45 +30,6 @@ import ( "github.com/gravitational/teleport/lib/tlsca" ) -// serverCredentials wraps a [crendentials.TransportCredentials] that -// extends the ServerHandshake to ensure the credentials contain the proxy system role. -type serverCredentials struct { - credentials.TransportCredentials -} - -// newServerCredentials creates new serverCredentials from the given [crendentials.TransportCredentials]. -func newServerCredentials(creds credentials.TransportCredentials) *serverCredentials { - return &serverCredentials{ - TransportCredentials: creds, - } -} - -// ServerHandshake performs the TLS handshake and then verifies that the client -// attempting to connect is a Proxy. -func (c *serverCredentials) ServerHandshake(conn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) { - conn, authInfo, err := c.TransportCredentials.ServerHandshake(conn) - if err != nil { - return nil, nil, trace.Wrap(err) - } - - defer func() { - if err != nil { - conn.Close() - } - }() - - identity, err := getIdentity(authInfo) - if err != nil { - return nil, nil, trace.Wrap(err) - } - - if err := checkProxyRole(identity); err != nil { - return nil, nil, trace.Wrap(err) - } - - return conn, authInfo, nil -} - // clientCredentials wraps a [crendentials.TransportCredentials] that // extends the ClientHandshake to ensure the credentials contain the proxy system role // and that connections are established to the expected peer. diff --git a/lib/proxy/peer/helpers_test.go b/lib/proxy/peer/helpers_test.go index a5217966ee96..eb7eaccfe268 100644 --- a/lib/proxy/peer/helpers_test.go +++ b/lib/proxy/peer/helpers_test.go @@ -219,22 +219,21 @@ func setupServer(t *testing.T, name string, serverCA, clientCA *tlsca.CertAuthor Username: name + ".test", Groups: []string{string(role)}, }) - tlsConf := &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - } - tlsConf.ClientCAs = x509.NewCertPool() - tlsConf.ClientCAs.AddCert(clientCA.Cert) - - listener, err := net.Listen("tcp", "localhost:0") - require.NoError(t, err) + clientCAs := x509.NewCertPool() + clientCAs.AddCert(clientCA.Cert) config := ServerConfig{ - Listener: listener, - TLSConfig: tlsConf, ClusterDialer: &mockClusterDialer{}, - service: &mockProxyService{}, - ClusterName: "test", + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return &tlsCert, nil + }, + GetClientCAs: func(*tls.ClientHelloInfo) (*x509.CertPool, error) { + return clientCAs, nil + }, + + service: &mockProxyService{}, } + for _, option := range options { option(&config) } @@ -242,17 +241,20 @@ func setupServer(t *testing.T, name string, serverCA, clientCA *tlsca.CertAuthor server, err := NewServer(config) require.NoError(t, err) - ts, err := types.NewServer( - name, types.KindProxy, - types.ServerSpecV2{PeerAddr: listener.Addr().String()}, - ) + listener, err := net.Listen("tcp", "localhost:0") require.NoError(t, err) - go server.Serve() + go server.Serve(listener) t.Cleanup(func() { require.NoError(t, server.Close()) }) + ts, err := types.NewServer( + name, types.KindProxy, + types.ServerSpecV2{PeerAddr: listener.Addr().String()}, + ) + require.NoError(t, err) + return server, ts } diff --git a/lib/proxy/peer/server.go b/lib/proxy/peer/server.go index bb9015954ebe..25d1220dd841 100644 --- a/lib/proxy/peer/server.go +++ b/lib/proxy/peer/server.go @@ -20,9 +20,11 @@ package peer import ( "crypto/tls" + "crypto/x509" "errors" "math" "net" + "slices" "time" "github.com/gravitational/trace" @@ -34,7 +36,9 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/metadata" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/grpc/interceptors" + "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -45,11 +49,12 @@ const ( // ServerConfig configures a Server instance. type ServerConfig struct { - Listener net.Listener - TLSConfig *tls.Config - ClusterDialer ClusterDialer Log logrus.FieldLogger - ClusterName string + ClusterDialer ClusterDialer + + CipherSuites []uint16 + GetCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error) + GetClientCAs func(*tls.ClientHelloInfo) (*x509.CertPool, error) // service is a custom ProxyServiceServer // configurable for testing purposes. @@ -59,29 +64,23 @@ type ServerConfig struct { // checkAndSetDefaults checks and sets default values func (c *ServerConfig) checkAndSetDefaults() error { if c.Log == nil { - c.Log = logrus.New() + c.Log = logrus.StandardLogger() } c.Log = c.Log.WithField( teleport.ComponentKey, teleport.Component(teleport.ComponentProxy, "peer"), ) - if c.Listener == nil { - return trace.BadParameter("missing listener") - } - if c.ClusterDialer == nil { return trace.BadParameter("missing cluster dialer server") } - if c.ClusterName == "" { - return trace.BadParameter("missing cluster name") + if c.GetCertificate == nil { + return trace.BadParameter("missing GetCertificate") } - - if c.TLSConfig == nil { - return trace.BadParameter("missing tls config") + if c.GetClientCAs == nil { + return trace.BadParameter("missing GetClientCAs") } - c.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert if c.service == nil { c.service = &proxyService{ @@ -95,13 +94,14 @@ func (c *ServerConfig) checkAndSetDefaults() error { // Server is a proxy service server using grpc and tls. type Server struct { - config ServerConfig - server *grpc.Server + log logrus.FieldLogger + clusterDialer ClusterDialer + server *grpc.Server } // NewServer creates a new proxy server instance. -func NewServer(config ServerConfig) (*Server, error) { - err := config.checkAndSetDefaults() +func NewServer(cfg ServerConfig) (*Server, error) { + err := cfg.checkAndSetDefaults() if err != nil { return nil, trace.Wrap(err) } @@ -113,8 +113,27 @@ func NewServer(config ServerConfig) (*Server, error) { reporter := newReporter(metrics) + tlsConfig := utils.TLSConfig(cfg.CipherSuites) + tlsConfig.NextProtos = []string{"h2"} + tlsConfig.GetCertificate = cfg.GetCertificate + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + tlsConfig.VerifyPeerCertificate = verifyPeerCertificateIsProxy + + getClientCAs := cfg.GetClientCAs + tlsConfig.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) { + clientCAs, err := getClientCAs(chi) + if err != nil { + return nil, trace.Wrap(err) + } + + utils.RefreshTLSConfigTickets(tlsConfig) + c := tlsConfig.Clone() + c.ClientCAs = clientCAs + return c, nil + } + server := grpc.NewServer( - grpc.Creds(newServerCredentials(credentials.NewTLS(config.TLSConfig))), + grpc.Creds(credentials.NewTLS(tlsConfig)), grpc.StatsHandler(newStatsHandler(reporter)), grpc.ChainStreamInterceptor(metadata.StreamServerInterceptor, interceptors.GRPCServerStreamErrorInterceptor), grpc.KeepaliveParams(keepalive.ServerParameters{ @@ -136,17 +155,35 @@ func NewServer(config ServerConfig) (*Server, error) { grpc.MaxConcurrentStreams(math.MaxUint32), ) - proto.RegisterProxyServiceServer(server, config.service) + proto.RegisterProxyServiceServer(server, cfg.service) return &Server{ - config: config, - server: server, + log: cfg.Log, + clusterDialer: cfg.ClusterDialer, + server: server, }, nil } +func verifyPeerCertificateIsProxy(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + if len(verifiedChains) < 1 { + return trace.AccessDenied("missing client certificate (this is a bug)") + } + + clientCert := verifiedChains[0][0] + clientIdentity, err := tlsca.FromSubject(clientCert.Subject, clientCert.NotAfter) + if err != nil { + return trace.Wrap(err) + } + + if !slices.Contains(clientIdentity.Groups, string(types.RoleProxy)) { + return trace.AccessDenied("expected Proxy client credentials") + } + return nil +} + // Serve starts the proxy server. -func (s *Server) Serve() error { - if err := s.server.Serve(s.config.Listener); err != nil { +func (s *Server) Serve(l net.Listener) error { + if err := s.server.Serve(l); err != nil { if errors.Is(err, grpc.ErrServerStopped) || utils.IsUseOfClosedNetworkError(err) { return nil diff --git a/lib/service/service.go b/lib/service/service.go index 54ec59037bf1..b2b141a64b75 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4676,7 +4676,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } var peerAddrString string - var proxyServer *peer.Server + var peerServer *peer.Server if !process.Config.Proxy.DisableReverseTunnel && listeners.proxyPeer != nil { peerAddr, err := process.Config.Proxy.PublicPeerAddr() if err != nil { @@ -4684,25 +4684,20 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } peerAddrString = peerAddr.String() - // TODO(espadolini): once connectors are live updated we can get rid of - // this and just refer to the host CA pool in the connector instead - peerServerTLSConfig := serverTLSConfig.Clone() - peerServerTLSConfig.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) { - pool, _, err := authclient.ClientCertPool(chi.Context(), accessPoint, clusterName, types.HostCA) - if err != nil { - return nil, trace.Wrap(err) - } - tlsConfig := peerServerTLSConfig.Clone() - tlsConfig.ClientCAs = pool - return tlsConfig, nil - } - - proxyServer, err = peer.NewServer(peer.ServerConfig{ - Listener: listeners.proxyPeer, - TLSConfig: peerServerTLSConfig, + peerServer, err = peer.NewServer(peer.ServerConfig{ + Log: process.log, ClusterDialer: clusterdial.NewClusterDialer(tsrv), - Log: process.log.WithField(teleport.ComponentKey, teleport.Component(teleport.ComponentReverseTunnelServer, process.id)), - ClusterName: clusterName, + CipherSuites: cfg.CipherSuites, + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return conn.serverGetCertificate() + }, + GetClientCAs: func(chi *tls.ClientHelloInfo) (*x509.CertPool, error) { + pool, _, err := authclient.ClientCertPool(chi.Context(), accessPoint, clusterName, types.HostCA) + if err != nil { + return nil, trace.Wrap(err) + } + return pool, nil + }, }) if err != nil { return trace.Wrap(err) @@ -4715,7 +4710,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } logger.InfoContext(process.ExitContext(), "Starting peer proxy service", "listen_address", logutils.StringerAttr(listeners.proxyPeer.Addr())) - err := proxyServer.Serve() + err := peerServer.Serve(listeners.proxyPeer) if err != nil { return trace.Wrap(err) } @@ -5258,8 +5253,8 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { warnOnErr(process.ExitContext(), tsrv.Close(), logger) } warnOnErr(process.ExitContext(), rcWatcher.Close(), logger) - if proxyServer != nil { - warnOnErr(process.ExitContext(), proxyServer.Close(), logger) + if peerServer != nil { + warnOnErr(process.ExitContext(), peerServer.Close(), logger) } if webServer != nil { warnOnErr(process.ExitContext(), webServer.Close(), logger) @@ -5309,8 +5304,8 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { warnOnErr(ctx, tsrv.Shutdown(ctx), logger) } warnOnErr(ctx, rcWatcher.Close(), logger) - if proxyServer != nil { - warnOnErr(ctx, proxyServer.Shutdown(), logger) + if peerServer != nil { + warnOnErr(ctx, peerServer.Shutdown(), logger) } if peerClient != nil { peerClient.Shutdown(ctx) diff --git a/lib/utils/tls.go b/lib/utils/tls.go index b78f86b67e0e..8f619b63faf8 100644 --- a/lib/utils/tls.go +++ b/lib/utils/tls.go @@ -190,3 +190,10 @@ func DefaultCipherSuites() []uint16 { tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, } } + +// RefreshTLSConfigTickets should be called right before cloning a [tls.Config] +// for a one-off use to not break TLS session resumption, as a workaround for +// https://github.com/golang/go/issues/60506 . +func RefreshTLSConfigTickets(c *tls.Config) { + _, _ = c.DecryptTicket(nil, tls.ConnectionState{}) +} From e81db873157d78f65339f15b2aad639695e7e709 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 4 Oct 2024 18:53:04 +0200 Subject: [PATCH 2/3] Make the peer clientConn generic --- lib/proxy/peer/client.go | 313 ++++++++++++++++++++------------- lib/proxy/peer/client_test.go | 47 ++--- lib/proxy/peer/helpers_test.go | 4 +- lib/proxy/peer/server_test.go | 8 +- 4 files changed, 223 insertions(+), 149 deletions(-) diff --git a/lib/proxy/peer/client.go b/lib/proxy/peer/client.go index 6ce7958de0f2..594795fddb5b 100644 --- a/lib/proxy/peer/client.go +++ b/lib/proxy/peer/client.go @@ -21,7 +21,7 @@ package peer import ( "context" "crypto/tls" - "math/rand" + "math/rand/v2" "net" "sync" "time" @@ -89,11 +89,11 @@ type ClientConfig struct { } // connShuffler shuffles the order of client connections. -type connShuffler func([]*clientConn) +type connShuffler func([]clientConn) // randomConnShuffler returns a conn shuffler that randomizes the order of connections. func randomConnShuffler() connShuffler { - return func(conns []*clientConn) { + return func(conns []clientConn) { rand.Shuffle(len(conns), func(i, j int) { conns[i], conns[j] = conns[j], conns[i] }) @@ -102,7 +102,7 @@ func randomConnShuffler() connShuffler { // noopConnShutffler returns a conn shuffler that keeps the original connection ordering. func noopConnShuffler() connShuffler { - return func([]*clientConn) {} + return func([]clientConn) {} } // checkAndSetDefaults checks and sets default values @@ -158,9 +158,36 @@ func (c *ClientConfig) checkAndSetDefaults() error { return nil } -// clientConn hold info about a dialed grpc connection -type clientConn struct { - *grpc.ClientConn +// clientConn manages client connections to a specific peer proxy (with a fixed +// host ID and address). +type clientConn interface { + // peerID returns the host ID of the peer proxy. + peerID() string + // peerAddr returns the address of the peer proxy. + peerAddr() string + + // dial opens a connection of a given tunnel type to a node with the given + // ID through the peer proxy managed by the clientConn. + dial( + nodeID string, + src net.Addr, + dst net.Addr, + tunnelType types.TunnelType, + ) (net.Conn, error) + + // close closes all connections and releases any background resources + // immediately. + close() error + + // shutdown waits until all connections are closed or the context is done, + // then acts like close. + shutdown(context.Context) +} + +// grpcClientConn manages client connections to a specific peer proxy over gRPC. +type grpcClientConn struct { + cc *grpc.ClientConn + metrics *clientMetrics id string addr string @@ -173,7 +200,17 @@ type clientConn struct { count int } -func (c *clientConn) maybeAcquire() (release func()) { +var _ clientConn = (*grpcClientConn)(nil) + +// peerID implements [clientConn]. +func (c *grpcClientConn) peerID() string { return c.id } + +// peerAddr implements [clientConn]. +func (c *grpcClientConn) peerAddr() string { return c.addr } + +// maybeAcquire returns a non-nil release func if the grpcClientConn is +// currently allowed to open connections; i.e., if it hasn't fully shut down. +func (c *grpcClientConn) maybeAcquire() (release func()) { c.mu.Lock() defer c.mu.Unlock() @@ -192,10 +229,9 @@ func (c *clientConn) maybeAcquire() (release func()) { }) } -// Shutdown closes the clientConn after all connections through it are closed, -// or after the context is done. -func (c *clientConn) Shutdown(ctx context.Context) { - defer c.Close() +// shutdown implements [clientConn]. +func (c *grpcClientConn) shutdown(ctx context.Context) { + defer c.cc.Close() c.mu.Lock() defer c.mu.Unlock() @@ -214,14 +250,87 @@ func (c *clientConn) Shutdown(ctx context.Context) { } } -// Client is a peer proxy service client using grpc and tls. +// close implements [clientConn]. +func (c *grpcClientConn) close() error { + return c.cc.Close() +} + +// dial implements [clientConn]. +func (c *grpcClientConn) dial( + nodeID string, + src net.Addr, + dst net.Addr, + tunnelType types.TunnelType, +) (net.Conn, error) { + release := c.maybeAcquire() + if release == nil { + c.metrics.reportTunnelError(errorProxyPeerTunnelRPC) + return nil, trace.ConnectionProblem(nil, "error starting stream: connection is shutting down") + } + + ctx, cancel := context.WithCancel(context.Background()) + context.AfterFunc(ctx, release) + + stream, err := clientapi.NewProxyServiceClient(c.cc).DialNode(ctx) + if err != nil { + cancel() + c.metrics.reportTunnelError(errorProxyPeerTunnelRPC) + return nil, trace.ConnectionProblem(err, "error starting stream: %v", err) + } + + err = stream.Send(&clientapi.Frame{ + Message: &clientapi.Frame_DialRequest{ + DialRequest: &clientapi.DialRequest{ + NodeID: nodeID, + TunnelType: tunnelType, + Source: &clientapi.NetAddr{ + Addr: src.String(), + Network: src.Network(), + }, + Destination: &clientapi.NetAddr{ + Addr: dst.String(), + Network: dst.Network(), + }, + }, + }, + }) + if err != nil { + cancel() + return nil, trace.ConnectionProblem(err, "error sending dial frame: %v", err) + } + msg, err := stream.Recv() + if err != nil { + cancel() + return nil, trace.ConnectionProblem(err, "error receiving dial response: %v", err) + } + if msg.GetConnectionEstablished() == nil { + cancel() + return nil, trace.ConnectionProblem(nil, "received malformed connection established frame") + } + + source := &frameStream{ + stream: stream, + cancel: cancel, + } + + streamRW, err := streamutils.NewReadWriter(source) + if err != nil { + _ = source.Close() + return nil, trace.Wrap(err) + } + + return streamutils.NewConn(streamRW, src, dst), nil +} + +// Client manages connections to known peer proxies and allows to open +// connections to agents through them. type Client struct { sync.RWMutex ctx context.Context cancel context.CancelFunc config ClientConfig - conns map[string]*clientConn + conns map[string]clientConn metrics *clientMetrics reporter *reporter } @@ -246,7 +355,7 @@ func NewClient(config ClientConfig) (*Client, error) { config: config, ctx: closeContext, cancel: cancel, - conns: make(map[string]*clientConn), + conns: make(map[string]clientConn), metrics: metrics, reporter: reporter, } @@ -274,17 +383,20 @@ func (c *Client) monitor() { c.RLock() c.reporter.resetConnections() for _, conn := range c.conns { - switch conn.GetState() { - case connectivity.Idle: - c.reporter.incConnection(c.config.ID, conn.id, connectivity.Idle.String()) - case connectivity.Connecting: - c.reporter.incConnection(c.config.ID, conn.id, connectivity.Connecting.String()) - case connectivity.Ready: - c.reporter.incConnection(c.config.ID, conn.id, connectivity.Ready.String()) - case connectivity.TransientFailure: - c.reporter.incConnection(c.config.ID, conn.id, connectivity.TransientFailure.String()) - case connectivity.Shutdown: - c.reporter.incConnection(c.config.ID, conn.id, connectivity.Shutdown.String()) + switch conn := conn.(type) { + case *grpcClientConn: + switch conn.cc.GetState() { + case connectivity.Idle: + c.reporter.incConnection(c.config.ID, conn.id, connectivity.Idle.String()) + case connectivity.Connecting: + c.reporter.incConnection(c.config.ID, conn.id, connectivity.Connecting.String()) + case connectivity.Ready: + c.reporter.incConnection(c.config.ID, conn.id, connectivity.Ready.String()) + case connectivity.TransientFailure: + c.reporter.incConnection(c.config.ID, conn.id, connectivity.TransientFailure.String()) + case connectivity.Shutdown: + c.reporter.incConnection(c.config.ID, conn.id, connectivity.Shutdown.String()) + } } } c.RUnlock() @@ -335,7 +447,7 @@ func (c *Client) updateConnections(proxies []types.Server) error { } var toDelete []string - toKeep := make(map[string]*clientConn) + toKeep := make(map[string]clientConn) for id, conn := range c.conns { proxy, ok := toDial[id] @@ -346,7 +458,7 @@ func (c *Client) updateConnections(proxies []types.Server) error { } // peer address changed - if conn.addr != proxy.GetPeerAddr() { + if conn.peerAddr() != proxy.GetPeerAddr() { toDelete = append(toDelete, id) continue } @@ -384,7 +496,7 @@ func (c *Client) updateConnections(proxies []types.Server) error { for _, id := range toDelete { if conn, ok := c.conns[id]; ok { - go conn.Shutdown(c.ctx) + go conn.shutdown(c.ctx) } } c.conns = toKeep @@ -392,39 +504,6 @@ func (c *Client) updateConnections(proxies []types.Server) error { return trace.NewAggregate(errs...) } -// DialNode dials a node through a peer proxy. -func (c *Client) DialNode( - proxyIDs []string, - nodeID string, - src net.Addr, - dst net.Addr, - tunnelType types.TunnelType, -) (net.Conn, error) { - stream, _, err := c.dial(proxyIDs, &clientapi.DialRequest{ - NodeID: nodeID, - TunnelType: tunnelType, - Source: &clientapi.NetAddr{ - Addr: src.String(), - Network: src.Network(), - }, - Destination: &clientapi.NetAddr{ - Addr: dst.String(), - Network: dst.Network(), - }, - }) - if err != nil { - return nil, trace.ConnectionProblem(err, "error dialing peer proxies %s: %v", proxyIDs, err) - } - - streamRW, err := streamutils.NewReadWriter(stream) - if err != nil { - _ = stream.Close() - return nil, trace.Wrap(err) - } - - return streamutils.NewConn(streamRW, src, dst), nil -} - // stream is the common subset of the [clientapi.ProxyService_DialNodeClient] and // [clientapi.ProxyService_DialNodeServer] interfaces. type stream interface { @@ -470,9 +549,9 @@ func (c *Client) Shutdown(ctx context.Context) { var wg sync.WaitGroup for _, conn := range c.conns { wg.Add(1) - go func(conn *clientConn) { + go func(conn clientConn) { defer wg.Done() - conn.Shutdown(ctx) + conn.shutdown(ctx) }(conn) } wg.Wait() @@ -486,7 +565,7 @@ func (c *Client) Stop() error { var errs []error for _, conn := range c.conns { - if err := conn.Close(); err != nil { + if err := conn.close(); err != nil { errs = append(errs, err) } } @@ -500,67 +579,56 @@ func (c *Client) GetConnectionsCount() int { return len(c.conns) } -// dial opens a new stream to one of the supplied proxy ids. -// it tries to find an existing grpc.ClientConn or initializes a new rpc -// to one of the proxies otherwise. -// The boolean returned in the second argument is intended for testing purposes, -// to indicates whether the connection was cached or newly established. -func (c *Client) dial(proxyIDs []string, dialRequest *clientapi.DialRequest) (frameStream, bool, error) { - conns, existing, err := c.getConnections(proxyIDs) +// DialNode dials a node through a peer proxy. +func (c *Client) DialNode( + proxyIDs []string, + nodeID string, + src net.Addr, + dst net.Addr, + tunnelType types.TunnelType, +) (net.Conn, error) { + conn, _, err := c.dial( + proxyIDs, + nodeID, + src, + dst, + tunnelType, + ) if err != nil { - return frameStream{}, existing, trace.Wrap(err) + return nil, trace.Wrap(err) } - var errs []error - for _, conn := range conns { - release := conn.maybeAcquire() - if release == nil { - c.metrics.reportTunnelError(errorProxyPeerTunnelRPC) - errs = append(errs, trace.ConnectionProblem(nil, "error starting stream: connection is shutting down")) - continue - } - - ctx, cancel := context.WithCancel(context.Background()) - context.AfterFunc(ctx, release) + return conn, nil +} - stream, err := clientapi.NewProxyServiceClient(conn.ClientConn).DialNode(ctx) - if err != nil { - cancel() - c.metrics.reportTunnelError(errorProxyPeerTunnelRPC) - c.config.Log.Debugf("Error opening tunnel rpc to proxy %+v at %+v", conn.id, conn.addr) - errs = append(errs, trace.ConnectionProblem(err, "error starting stream: %v", err)) - continue - } +// dial opens a new connection through one of the given proxy ids. It tries to +// find an existing [clientConn] or initializes new clientConns to the given +// proxies otherwise. The boolean returned in the second argument is intended +// for testing purposes, to indicates whether the connection used an existing +// clientConn or a newly established one. +func (c *Client) dial( + proxyIDs []string, + nodeID string, + src net.Addr, + dst net.Addr, + tunnelType types.TunnelType, +) (net.Conn, bool, error) { + conns, existing, err := c.getConnections(proxyIDs) + if err != nil { + return nil, false, trace.Wrap(err) + } - err = stream.Send(&clientapi.Frame{ - Message: &clientapi.Frame_DialRequest{ - DialRequest: dialRequest, - }, - }) - if err != nil { - cancel() - errs = append(errs, trace.ConnectionProblem(err, "error sending dial frame: %v", err)) - continue - } - msg, err := stream.Recv() + var errs []error + for _, clientConn := range conns { + conn, err := clientConn.dial(nodeID, src, dst, tunnelType) if err != nil { - cancel() - errs = append(errs, trace.ConnectionProblem(err, "error receiving dial response: %v", err)) + errs = append(errs, trace.Wrap(err)) continue } - if msg.GetConnectionEstablished() == nil { - cancel() - errs = append(errs, trace.ConnectionProblem(nil, "received malformed connection established frame")) - continue - } - - return frameStream{ - stream: stream, - cancel: cancel, - }, existing, nil + return conn, existing, nil } - return frameStream{}, existing, trace.NewAggregate(errs...) + return nil, existing, trace.NewAggregate(errs...) } // getConnections returns connections to the supplied proxy ids. @@ -568,13 +636,13 @@ func (c *Client) dial(proxyIDs []string, dialRequest *clientapi.DialRequest) (fr // otherwise. // The boolean returned in the second argument is intended for testing purposes, // to indicates whether the connection was cached or newly established. -func (c *Client) getConnections(proxyIDs []string) ([]*clientConn, bool, error) { +func (c *Client) getConnections(proxyIDs []string) ([]clientConn, bool, error) { if len(proxyIDs) == 0 { return nil, false, trace.BadParameter("failed to dial: no proxy ids given") } ids := make(map[string]struct{}) - var conns []*clientConn + var conns []clientConn // look for existing matching connections. c.RLock() @@ -631,7 +699,7 @@ func (c *Client) getConnections(proxyIDs []string) ([]*clientConn, bool, error) defer c.Unlock() for _, conn := range conns { - c.conns[conn.id] = conn + c.conns[conn.peerID()] = conn } c.config.connShuffler(conns) @@ -639,7 +707,7 @@ func (c *Client) getConnections(proxyIDs []string) ([]*clientConn, bool, error) } // connect dials a new connection to proxyAddr. -func (c *Client) connect(peerID string, peerAddr string) (*clientConn, error) { +func (c *Client) connect(peerID string, peerAddr string) (clientConn, error) { tlsConfig := utils.TLSConfig(c.config.TLSCipherSuites) tlsConfig.ServerName = apiutils.EncodeClusterName(c.config.ClusterName) tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { @@ -670,8 +738,9 @@ func (c *Client) connect(peerID string, peerAddr string) (*clientConn, error) { return nil, trace.Wrap(err, "Error dialing proxy %q", peerID) } - return &clientConn{ - ClientConn: conn, + return &grpcClientConn{ + cc: conn, + metrics: c.metrics, id: peerID, addr: peerAddr, diff --git a/lib/proxy/peer/client_test.go b/lib/proxy/peer/client_test.go index d4e1cd7c5df5..92a64904b7c0 100644 --- a/lib/proxy/peer/client_test.go +++ b/lib/proxy/peer/client_test.go @@ -30,6 +30,7 @@ import ( "github.com/gravitational/teleport/api/client/proto" clientapi "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/utils" ) // TestClientConn checks the client's connection caching capabilities @@ -46,33 +47,33 @@ func TestClientConn(t *testing.T) { require.Len(t, client.conns, 2) // dial first server and send a test data frame - stream, cached, err := client.dial([]string{"s1"}, &proto.DialRequest{}) + stream, cached, err := client.dial([]string{"s1"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) require.True(t, cached) - require.NotNil(t, stream.stream) + require.NotNil(t, stream) stream.Close() // dial second server - stream, cached, err = client.dial([]string{"s2"}, &proto.DialRequest{}) + stream, cached, err = client.dial([]string{"s2"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) require.True(t, cached) - require.NotNil(t, stream.stream) + require.NotNil(t, stream) stream.Close() // redial second server - stream, cached, err = client.dial([]string{"s2"}, &proto.DialRequest{}) + stream, cached, err = client.dial([]string{"s2"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) require.True(t, cached) - require.NotNil(t, stream.stream) + require.NotNil(t, stream) stream.Close() // close second server // and attempt to redial it server2.Shutdown() - stream, cached, err = client.dial([]string{"s2"}, &proto.DialRequest{}) + stream, cached, err = client.dial([]string{"s2"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.Error(t, err) require.True(t, cached) - require.Nil(t, stream.stream) + require.Nil(t, stream) } // TestClientUpdate checks the client's watcher update behavior @@ -90,12 +91,12 @@ func TestClientUpdate(t *testing.T) { require.Contains(t, client.conns, "s1") require.Contains(t, client.conns, "s2") - s1, _, err := client.dial([]string{"s1"}, &proto.DialRequest{}) + s1, _, err := client.dial([]string{"s1"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) - require.NotNil(t, s1.stream) - s2, _, err := client.dial([]string{"s2"}, &proto.DialRequest{}) + require.NotNil(t, s1) + s2, _, err := client.dial([]string{"s2"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) - require.NotNil(t, s2.stream) + require.NotNil(t, s2) // watcher finds one of the two servers err = client.updateConnections([]types.Server{def1}) @@ -114,7 +115,7 @@ func TestClientUpdate(t *testing.T) { require.Len(t, client.conns, 2) require.Contains(t, client.conns, "s1") sendMsg(t, s1) // stream is still going strong - _, _, err = client.dial([]string{"s2"}, &proto.DialRequest{}) + _, _, err = client.dial([]string{"s2"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.Error(t, err) // can't dial server2, obviously // peer address change @@ -124,7 +125,7 @@ func TestClientUpdate(t *testing.T) { require.Len(t, client.conns, 1) require.Contains(t, client.conns, "s1") sendMsg(t, s1) // stream is not forcefully closed. ClientConn waits for a graceful shutdown before it closes. - s3, _, err := client.dial([]string{"s1"}, &proto.DialRequest{}) + s3, _, err := client.dial([]string{"s1"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) require.NotNil(t, s3) @@ -145,9 +146,10 @@ func TestCAChange(t *testing.T) { conn, err := client.connect("s1", ts.GetPeerAddr()) require.NoError(t, err) require.NotNil(t, conn) + require.IsType(t, (*grpcClientConn)(nil), conn) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - stream, err := clientapi.NewProxyServiceClient(conn.ClientConn).DialNode(ctx) + stream, err := clientapi.NewProxyServiceClient(conn.(*grpcClientConn).cc).DialNode(ctx) require.NoError(t, err) require.NotNil(t, stream) @@ -162,7 +164,8 @@ func TestCAChange(t *testing.T) { conn, err = client.connect("s1", ts.GetPeerAddr()) require.NoError(t, err) require.NotNil(t, conn) - stream, err = clientapi.NewProxyServiceClient(conn.ClientConn).DialNode(ctx) + require.IsType(t, (*grpcClientConn)(nil), conn) + stream, err = clientapi.NewProxyServiceClient(conn.(*grpcClientConn).cc).DialNode(ctx) require.Error(t, err) require.Nil(t, stream) @@ -173,7 +176,8 @@ func TestCAChange(t *testing.T) { conn, err = client.connect("s1", ts.GetPeerAddr()) require.NoError(t, err) require.NotNil(t, conn) - stream, err = clientapi.NewProxyServiceClient(conn.ClientConn).DialNode(ctx) + require.IsType(t, (*grpcClientConn)(nil), conn) + stream, err = clientapi.NewProxyServiceClient(conn.(*grpcClientConn).cc).DialNode(ctx) require.NoError(t, err) require.NotNil(t, stream) } @@ -196,17 +200,18 @@ func TestBackupClient(t *testing.T) { err := client.updateConnections([]types.Server{def1, def2}) require.NoError(t, err) - waitForConns(t, client.conns, time.Second*2) + waitForGRPCConns(t, client.conns, time.Second*2) - _, _, err = client.dial([]string{def1.GetName(), def2.GetName()}, &proto.DialRequest{}) + _, _, err = client.dial([]string{def1.GetName(), def2.GetName()}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) require.True(t, dialCalled) } -func waitForConns(t *testing.T, conns map[string]*clientConn, d time.Duration) { +func waitForGRPCConns(t *testing.T, conns map[string]clientConn, d time.Duration) { require.Eventually(t, func() bool { for _, conn := range conns { - if conn.GetState() != connectivity.Ready { + // panic if we hit a non-grpc client conn + if conn.(*grpcClientConn).cc.GetState() != connectivity.Ready { return false } } diff --git a/lib/proxy/peer/helpers_test.go b/lib/proxy/peer/helpers_test.go index eb7eaccfe268..647ffbcea027 100644 --- a/lib/proxy/peer/helpers_test.go +++ b/lib/proxy/peer/helpers_test.go @@ -258,7 +258,7 @@ func setupServer(t *testing.T, name string, serverCA, clientCA *tlsca.CertAuthor return server, ts } -func sendMsg(t *testing.T, stream frameStream) { - err := stream.Send([]byte("ping")) +func sendMsg(t *testing.T, stream net.Conn) { + _, err := stream.Write([]byte("ping")) require.NoError(t, err) } diff --git a/lib/proxy/peer/server_test.go b/lib/proxy/peer/server_test.go index 3ff1765272d5..3c7a4c5cb4eb 100644 --- a/lib/proxy/peer/server_test.go +++ b/lib/proxy/peer/server_test.go @@ -23,8 +23,8 @@ import ( "github.com/stretchr/testify/require" - "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/utils" ) // TestServerTLS ensures that only trusted certificates with the proxy role @@ -38,7 +38,7 @@ func TestServerTLS(t *testing.T) { _, serverDef1 := setupServer(t, "s1", ca1, ca1, types.RoleProxy) err := client1.updateConnections([]types.Server{serverDef1}) require.NoError(t, err) - stream, _, err := client1.dial([]string{"s1"}, &proto.DialRequest{}) + stream, _, err := client1.dial([]string{"s1"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) require.NotNil(t, stream) stream.Close() @@ -48,7 +48,7 @@ func TestServerTLS(t *testing.T) { _, serverDef2 := setupServer(t, "s2", ca1, ca1, types.RoleProxy) err = client2.updateConnections([]types.Server{serverDef2}) require.NoError(t, err) // connection succeeds but is in transient failure state - _, _, err = client2.dial([]string{"s2"}, &proto.DialRequest{}) + _, _, err = client2.dial([]string{"s2"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.Error(t, err) // certificates with correct role from different CAs @@ -56,7 +56,7 @@ func TestServerTLS(t *testing.T) { _, serverDef3 := setupServer(t, "s3", ca2, ca1, types.RoleProxy) err = client3.updateConnections([]types.Server{serverDef3}) require.NoError(t, err) - stream, _, err = client3.dial([]string{"s3"}, &proto.DialRequest{}) + stream, _, err = client3.dial([]string{"s3"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) require.NotNil(t, stream) stream.Close() From b74222742e67a2e1e3589586bb1444c710de138e Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 4 Oct 2024 18:54:13 +0200 Subject: [PATCH 3/3] Mock implementation of QUIC proxy peering --- api/types/constants.go | 5 ++ go.mod | 5 ++ go.sum | 5 +- lib/proxy/peer/client.go | 15 ++++- lib/proxy/peer/client_test.go | 7 ++- lib/proxy/peer/quicserver.go | 85 ++++++++++++++++++++++++++++ lib/service/service.go | 103 ++++++++++++++++++++++++++++++++-- lib/service/signals.go | 28 +++++++++ 8 files changed, 242 insertions(+), 11 deletions(-) create mode 100644 lib/proxy/peer/quicserver.go diff --git a/api/types/constants.go b/api/types/constants.go index 3ba98907db59..1f929d436b0d 100644 --- a/api/types/constants.go +++ b/api/types/constants.go @@ -1028,6 +1028,11 @@ const ( // group they should attempt to be connected to. ProxyGroupGenerationLabel = TeleportInternalLabelPrefix + "proxygroup-gen" + // ProxyPeerQUICLabel is the internal-user label for proxy heartbeats that's + // used to signal that the proxy supports receiving proxy peering + // connections over QUIC. + ProxyPeerQUICLabel = TeleportInternalLabelPrefix + "proxy-peer-quic" + // OktaAppNameLabel is the individual app name label. OktaAppNameLabel = TeleportInternalLabelPrefix + "okta-app-name" diff --git a/go.mod b/go.mod index cd0e4dfd05c7..b9c299c81a84 100644 --- a/go.mod +++ b/go.mod @@ -164,6 +164,7 @@ require ( github.com/prometheus/client_golang v1.20.4 github.com/prometheus/client_model v0.6.1 github.com/prometheus/common v0.55.0 + github.com/quic-go/quic-go v0.47.0 github.com/redis/go-redis/v9 v9.5.1 // replaced github.com/russellhaering/gosaml2 v0.9.1 github.com/russellhaering/goxmldsig v1.4.0 @@ -350,6 +351,7 @@ require ( github.com/go-openapi/strfmt v0.23.0 // indirect github.com/go-openapi/swag v0.23.0 // indirect github.com/go-openapi/validate v0.24.0 // indirect + github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/go-webauthn/x v0.1.14 // indirect github.com/gobuffalo/flect v1.0.2 // indirect github.com/gobwas/glob v0.2.3 // indirect @@ -369,6 +371,7 @@ require ( github.com/google/go-configfs-tsm v0.2.2 // indirect github.com/google/go-tspi v0.3.0 // indirect github.com/google/gofuzz v1.2.0 // indirect + github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af // indirect github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect github.com/gorilla/handlers v1.5.2 // indirect @@ -449,6 +452,7 @@ require ( github.com/nsf/termbox-go v1.1.1 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect + github.com/onsi/ginkgo/v2 v2.19.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pascaldekloe/name v1.0.1 // indirect @@ -524,6 +528,7 @@ require ( go.opentelemetry.io/otel/sdk/metric v1.29.0 // indirect go.starlark.net v0.0.0-20230525235612-a134d8f9ddca // indirect go.uber.org/atomic v1.11.0 // indirect + go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/tools v0.24.0 // indirect golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect diff --git a/go.sum b/go.sum index 6ad194e89a40..9554c063666c 100644 --- a/go.sum +++ b/go.sum @@ -1281,7 +1281,6 @@ github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= @@ -2009,6 +2008,8 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/protocolbuffers/txtpbfmt v0.0.0-20231025115547-084445ff1adf h1:014O62zIzQwvoD7Ekj3ePDF5bv9Xxy0w6AZk0qYbjUk= github.com/protocolbuffers/txtpbfmt v0.0.0-20231025115547-084445ff1adf/go.mod h1:jgxiZysxFPM+iWKwQwPR+y+Jvo54ARd4EisXxKYpB5c= +github.com/quic-go/quic-go v0.47.0 h1:yXs3v7r2bm1wmPTYNLKAAJTHMYkPEsfYJmTazXrCZ7Y= +github.com/quic-go/quic-go v0.47.0/go.mod h1:3bCapYsJvXGZcipOHuu7plYtaV6tnF+z7wIFsU0WK9E= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/remyoudompheng/bigfft v0.0.0-20190728182440-6a916e37a237/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= @@ -2317,6 +2318,8 @@ go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= diff --git a/lib/proxy/peer/client.go b/lib/proxy/peer/client.go index 594795fddb5b..f6a9e0aff424 100644 --- a/lib/proxy/peer/client.go +++ b/lib/proxy/peer/client.go @@ -28,6 +28,7 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/quic-go/quic-go" "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" @@ -79,6 +80,9 @@ type ClientConfig struct { GracefulShutdownTimeout time.Duration // ClusterName is the name of the cluster. ClusterName string + // QUICTransport, if set, will be used to dial peer proxies that advertise + // support for peering connections over QUIC. + QUICTransport *quic.Transport // connShuffler determines the order client connections will be used. connShuffler connShuffler @@ -479,7 +483,8 @@ func (c *Client) updateConnections(proxies []types.Server) error { } // establish new connections - conn, err := c.connect(id, proxy.GetPeerAddr()) + _, supportsQuic := proxy.GetLabel(types.ProxyPeerQUICLabel) + conn, err := c.connect(id, proxy.GetPeerAddr(), supportsQuic) if err != nil { c.metrics.reportTunnelError(errorProxyPeerTunnelDial) c.config.Log.Debugf("Error dialing peer proxy %+v at %+v", id, proxy.GetPeerAddr()) @@ -679,7 +684,8 @@ func (c *Client) getConnections(proxyIDs []string) ([]clientConn, bool, error) { continue } - conn, err := c.connect(id, proxy.GetPeerAddr()) + _, supportsQuic := proxy.GetLabel(types.ProxyPeerQUICLabel) + conn, err := c.connect(id, proxy.GetPeerAddr(), supportsQuic) if err != nil { c.metrics.reportTunnelError(errorProxyPeerTunnelDirectDial) c.config.Log.Debugf("Error direct dialing peer proxy %+v at %+v", id, proxy.GetPeerAddr()) @@ -707,7 +713,10 @@ func (c *Client) getConnections(proxyIDs []string) ([]clientConn, bool, error) { } // connect dials a new connection to proxyAddr. -func (c *Client) connect(peerID string, peerAddr string) (clientConn, error) { +func (c *Client) connect(peerID string, peerAddr string, supportsQUIC bool) (clientConn, error) { + if supportsQUIC && c.config.QUICTransport != nil { + panic("QUIC proxy peering is not implemented") + } tlsConfig := utils.TLSConfig(c.config.TLSCipherSuites) tlsConfig.ServerName = apiutils.EncodeClusterName(c.config.ClusterName) tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { diff --git a/lib/proxy/peer/client_test.go b/lib/proxy/peer/client_test.go index 92a64904b7c0..49df7c97b28b 100644 --- a/lib/proxy/peer/client_test.go +++ b/lib/proxy/peer/client_test.go @@ -143,7 +143,8 @@ func TestCAChange(t *testing.T) { t.Cleanup(func() { server.Close() }) // dial server and send a test data frame - conn, err := client.connect("s1", ts.GetPeerAddr()) + const supportsQUICFalse = false + conn, err := client.connect("s1", ts.GetPeerAddr(), supportsQUICFalse) require.NoError(t, err) require.NotNil(t, conn) require.IsType(t, (*grpcClientConn)(nil), conn) @@ -161,7 +162,7 @@ func TestCAChange(t *testing.T) { // new connection should fail because client tls config still references old // RootCAs. - conn, err = client.connect("s1", ts.GetPeerAddr()) + conn, err = client.connect("s1", ts.GetPeerAddr(), supportsQUICFalse) require.NoError(t, err) require.NotNil(t, conn) require.IsType(t, (*grpcClientConn)(nil), conn) @@ -173,7 +174,7 @@ func TestCAChange(t *testing.T) { // RootCAs. currentServerCA.Store(newServerCA) - conn, err = client.connect("s1", ts.GetPeerAddr()) + conn, err = client.connect("s1", ts.GetPeerAddr(), supportsQUICFalse) require.NoError(t, err) require.NotNil(t, conn) require.IsType(t, (*grpcClientConn)(nil), conn) diff --git a/lib/proxy/peer/quicserver.go b/lib/proxy/peer/quicserver.go new file mode 100644 index 000000000000..574857533ba2 --- /dev/null +++ b/lib/proxy/peer/quicserver.go @@ -0,0 +1,85 @@ +/* + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package peer + +import ( + "context" + "crypto/tls" + "crypto/x509" + "log/slog" + + "github.com/gravitational/trace" + "github.com/quic-go/quic-go" + + "github.com/gravitational/teleport" +) + +type QUICServerConfig struct { + Log *slog.Logger + ClusterDialer ClusterDialer + + CipherSuites []uint16 + GetCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error) + GetClientCAs func(*tls.ClientHelloInfo) (*x509.CertPool, error) +} + +func (c *QUICServerConfig) checkAndSetDefaults() error { + if c.Log == nil { + c.Log = slog.Default() + } + c.Log = c.Log.With( + teleport.ComponentKey, + teleport.Component(teleport.ComponentProxy, "qpeer"), + ) + + if c.ClusterDialer == nil { + return trace.BadParameter("missing cluster dialer") + } + + if c.GetCertificate == nil { + return trace.BadParameter("missing GetCertificate") + } + if c.GetClientCAs == nil { + return trace.BadParameter("missing GetClientCAs") + } + + return nil +} + +// QUICServer is a proxy peering server that uses the QUIC protocol. +type QUICServer struct{} + +func NewQUICServer(cfg QUICServerConfig) (*QUICServer, error) { + if err := cfg.checkAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + panic("QUIC proxy peering is not implemented") +} + +func (s *QUICServer) Serve(t *quic.Transport) error { + panic("QUIC proxy peering is not implemented") +} + +func (s *QUICServer) Close() error { + panic("QUIC proxy peering is not implemented") +} + +func (s *QUICServer) Shutdown(ctx context.Context) error { + panic("QUIC proxy peering is not implemented") +} diff --git a/lib/service/service.go b/lib/service/service.go index b2b141a64b75..afd5b0065a77 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -55,6 +55,7 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/quic-go/quic-go" "github.com/sirupsen/logrus" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.opentelemetry.io/otel/attribute" @@ -4311,9 +4312,31 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { // from remote teleport nodes var tsrv reversetunnelclient.Server var peerClient *peer.Client - + var peerQUICTransport *quic.Transport if !process.Config.Proxy.DisableReverseTunnel { if listeners.proxyPeer != nil { + // TODO(espadolini): allow this when the implementation is merged + if false && os.Getenv("TELEPORT_UNSTABLE_QUIC_PROXY_PEERING") == "yes" { + // the stateless reset key is important in case there's a crash + // so peers can be told to close their side of the connections + // instead of having to wait for a timeout; for this reason, we + // store it in the datadir, which should persist just as much as + // the host ID and the cluster credentials + resetKey, err := process.readOrInitPeerStatelessResetKey() + if err != nil { + return trace.Wrap(err) + } + pc, err := process.createPacketConn(string(ListenerProxyPeer), listeners.proxyPeer.Addr().String()) + if err != nil { + return trace.Wrap(err) + } + peerQUICTransport = &quic.Transport{ + Conn: pc, + + StatelessResetKey: resetKey, + } + } + peerClient, err = peer.NewClient(peer.ClientConfig{ Context: process.ExitContext(), ID: process.Config.HostUUID, @@ -4325,6 +4348,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { Log: process.log, Clock: process.Clock, ClusterName: clusterName, + QUICTransport: peerQUICTransport, }) if err != nil { return trace.Wrap(err) @@ -4677,6 +4701,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { var peerAddrString string var peerServer *peer.Server + var peerQUICServer *peer.QUICServer if !process.Config.Proxy.DisableReverseTunnel && listeners.proxyPeer != nil { peerAddr, err := process.Config.Proxy.PublicPeerAddr() if err != nil { @@ -4705,11 +4730,11 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { process.RegisterCriticalFunc("proxy.peer", func() error { if _, err := process.WaitForEvent(process.ExitContext(), ProxyReverseTunnelReady); err != nil { - logger.DebugContext(process.ExitContext(), "Process exiting: failed to start peer proxy service waiting for reverse tunnel server") + logger.DebugContext(process.ExitContext(), "Process exiting: failed to start peer proxy service waiting for reverse tunnel server.") return nil } - logger.InfoContext(process.ExitContext(), "Starting peer proxy service", "listen_address", logutils.StringerAttr(listeners.proxyPeer.Addr())) + logger.InfoContext(process.ExitContext(), "Starting peer proxy service.", "listen_address", logutils.StringerAttr(listeners.proxyPeer.Addr())) err := peerServer.Serve(listeners.proxyPeer) if err != nil { return trace.Wrap(err) @@ -4717,9 +4742,45 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return nil }) + + if peerQUICTransport != nil { + peerQUICServer, err := peer.NewQUICServer(peer.QUICServerConfig{ + Log: process.logger, + ClusterDialer: clusterdial.NewClusterDialer(tsrv), + CipherSuites: cfg.CipherSuites, + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return conn.serverGetCertificate() + }, + GetClientCAs: func(chi *tls.ClientHelloInfo) (*x509.CertPool, error) { + pool, _, err := authclient.ClientCertPool(chi.Context(), accessPoint, clusterName, types.HostCA) + if err != nil { + return nil, trace.Wrap(err) + } + return pool, nil + }, + }) + if err != nil { + return trace.Wrap(err) + } + + process.RegisterCriticalFunc("proxy.peer.quic", func() error { + if _, err := process.WaitForEvent(process.ExitContext(), ProxyReverseTunnelReady); err != nil { + logger.DebugContext(process.ExitContext(), "Process exiting: failed to start QUIC peer proxy service waiting for reverse tunnel server.") + return nil + } + + logger.InfoContext(process.ExitContext(), "Starting QUIC peer proxy service.", "local_addr", logutils.StringerAttr(peerQUICTransport.Conn.LocalAddr())) + err := peerQUICServer.Serve(peerQUICTransport) + if err != nil { + return trace.Wrap(err) + } + + return nil + }) + } } - staticLabels := make(map[string]string, 2) + staticLabels := make(map[string]string, 3) if cfg.Proxy.ProxyGroupID != "" { staticLabels[types.ProxyGroupIDLabel] = cfg.Proxy.ProxyGroupID } @@ -4729,6 +4790,10 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { if len(staticLabels) > 0 { logger.InfoContext(process.ExitContext(), "Enabling proxy group labels.", "group_id", cfg.Proxy.ProxyGroupID, "generation", cfg.Proxy.ProxyGroupGeneration) } + if peerQUICTransport != nil { + staticLabels[types.ProxyPeerQUICLabel] = "x" + logger.InfoContext(process.ExitContext(), "Advertising proxy peering QUIC support.") + } sshProxy, err := regular.New( process.ExitContext(), @@ -5256,6 +5321,9 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { if peerServer != nil { warnOnErr(process.ExitContext(), peerServer.Close(), logger) } + if peerQUICServer != nil { + warnOnErr(process.ExitContext(), peerQUICServer.Close(), logger) + } if webServer != nil { warnOnErr(process.ExitContext(), webServer.Close(), logger) } @@ -5307,6 +5375,9 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { if peerServer != nil { warnOnErr(ctx, peerServer.Shutdown(), logger) } + if peerQUICServer != nil { + warnOnErr(ctx, peerQUICServer.Shutdown(ctx), logger) + } if peerClient != nil { peerClient.Shutdown(ctx) } @@ -5343,6 +5414,10 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { clientTLSConfigGenerator.Close() } } + if peerQUICTransport != nil { + _ = peerQUICTransport.Close() + _ = peerQUICTransport.Conn.Close() + } warnOnErr(process.ExitContext(), asyncEmitter.Close(), logger) warnOnErr(process.ExitContext(), conn.Close(), logger) logger.InfoContext(process.ExitContext(), "Exited.") @@ -5417,6 +5492,26 @@ func (process *TeleportProcess) initMinimalReverseTunnel(listeners *proxyListene return minimalWebServer, nil } +func (process *TeleportProcess) readOrInitPeerStatelessResetKey() (*quic.StatelessResetKey, error) { + resetKeyPath := filepath.Join(process.Config.DataDir, "peer_stateless_reset_key") + k := new(quic.StatelessResetKey) + stored, err := os.ReadFile(resetKeyPath) + if err == nil && len(stored) == len(k) { + copy(k[:], stored) + return k, nil + } + if !errors.Is(err, os.ErrNotExist) { + process.logger.WarnContext(process.ExitContext(), "Stateless reset key file unreadable or invalid.", "error", err) + } + if _, err := rand.Read(k[:]); err != nil { + return nil, trace.ConvertSystemError(err) + } + if err := renameio.WriteFile(resetKeyPath, k[:], 0o600); err != nil { + process.logger.WarnContext(process.ExitContext(), "Failed to persist stateless reset key.", "error", err) + } + return k, nil +} + // kubeDialAddr returns Proxy Kube service address used for dialing local kube service // by remote trusted cluster. // If the proxy is running with Multiplex mode the WebPort is returned diff --git a/lib/service/signals.go b/lib/service/signals.go index 941a95348728..8bb32f5675f0 100644 --- a/lib/service/signals.go +++ b/lib/service/signals.go @@ -344,6 +344,34 @@ func (process *TeleportProcess) createListener(typ ListenerType, address string) return listener, nil } +// createPacketConn opens a UDP socket with the given address. UDP sockets are +// never passed on to a different process, so they're not registered anywhere. +func (process *TeleportProcess) createPacketConn(typ string, address string) (net.PacketConn, error) { + listenersClosed := func() bool { + process.Lock() + defer process.Unlock() + return process.listenersClosed + } + + if listenersClosed() { + process.logger.DebugContext(process.ExitContext(), "Listening is blocked, not opening packet conn.", "type", typ, "address", address) + return nil, trace.BadParameter("listening is blocked") + } + + pc, err := net.ListenPacket("udp", address) + if err != nil { + return nil, trace.Wrap(err) + } + + if listenersClosed() { + _ = pc.Close() + process.logger.DebugContext(process.ExitContext(), "Listening is blocked, closing newly-created packet conn.", "type", typ, "address", address) + return nil, trace.BadParameter("listening is blocked") + } + + return pc, nil +} + // getListenerNeedsLock tries to get an existing listener that matches the type/addr. func (process *TeleportProcess) getListenerNeedsLock(typ ListenerType, address string) (listener net.Listener, ok bool) { for _, l := range process.registeredListeners {