diff --git a/server/server.go b/server/server.go index 6f52d0ddb2d19..608f387004635 100644 --- a/server/server.go +++ b/server/server.go @@ -32,6 +32,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io" "io/ioutil" "math/rand" "net" @@ -80,6 +81,7 @@ type Server struct { tlsConfig *tls.Config driver IDriver listener net.Listener + socket net.Listener rwlock *sync.RWMutex concurrentLimiter *TokenLimiter clients map[uint32]*clientConn @@ -133,6 +135,39 @@ func (s *Server) isUnixSocket() bool { return s.cfg.Socket != "" } +func (s *Server) forwardUnixSocketToTCP() { + addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port) + for { + if s.listener == nil { + return // server shutdown has started + } + if uconn, err := s.socket.Accept(); err == nil { + log.Infof("server socket forwarding from [%s] to [%s]", s.cfg.Socket, addr) + go s.handleForwardedConnection(uconn, addr) + } else { + if s.listener != nil { + log.Errorf("server failed to forward from [%s] to [%s], err: %s", s.cfg.Socket, addr, err) + } + } + } +} + +func (s *Server) handleForwardedConnection(uconn net.Conn, addr string) { + defer terror.Call(uconn.Close) + if tconn, err := net.Dial("tcp", addr); err == nil { + go func() { + if _, err := io.Copy(uconn, tconn); err != nil { + log.Warningf("copy server to socket failed: %s", err) + } + }() + if _, err := io.Copy(tconn, uconn); err != nil { + log.Warningf("socket forward copy failed: %s", err) + } + } else { + log.Warningf("socket forward failed: could not connect to [%s], err: %s", addr, err) + } +} + // NewServer creates a new Server. func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { s := &Server{ @@ -151,15 +186,24 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { } var err error - if cfg.Socket != "" { - if s.listener, err = net.Listen("unix", cfg.Socket); err == nil { - log.Infof("Server is running MySQL Protocol through Socket [%s]", cfg.Socket) - } - } else { + + if s.cfg.Host != "" && s.cfg.Port != 0 { addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port) if s.listener, err = net.Listen("tcp", addr); err == nil { log.Infof("Server is running MySQL Protocol at [%s]", addr) + if cfg.Socket != "" { + if s.socket, err = net.Listen("unix", s.cfg.Socket); err == nil { + log.Infof("Server redirecting [%s] to [%s]", s.cfg.Socket, addr) + go s.forwardUnixSocketToTCP() + } + } } + } else if cfg.Socket != "" { + if s.listener, err = net.Listen("unix", cfg.Socket); err == nil { + log.Infof("Server is running MySQL Protocol through Socket [%s]", cfg.Socket) + } + } else { + err = errors.New("Server not configured to listen on either -socket or -host and -port") } if cfg.ProxyProtocol.Networks != "" { @@ -292,6 +336,11 @@ func (s *Server) Close() { terror.Log(errors.Trace(err)) s.listener = nil } + if s.socket != nil { + err := s.socket.Close() + terror.Log(errors.Trace(err)) + s.socket = nil + } if s.statusServer != nil { err := s.statusServer.Close() terror.Log(errors.Trace(err)) @@ -419,7 +468,7 @@ func (s *Server) kickIdleConnection() { for _, cc := range conns { err := cc.Close() if err != nil { - log.Error("close connection error:", err) + log.Errorf("close connection error: %s", err) } } } diff --git a/server/tidb_test.go b/server/tidb_test.go index 2a7a18264f15e..200263caf5931 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -160,9 +160,34 @@ func (ts *TidbTestSuite) TestMultiStatements(c *C) { runTestMultiStatements(c) } +func (ts *TidbTestSuite) TestSocketForwarding(c *C) { + cfg := config.NewConfig() + cfg.Socket = "/tmp/tidbtest.sock" + cfg.Port = 3999 + os.Remove(cfg.Socket) + cfg.Status.ReportStatus = false + + server, err := NewServer(cfg, ts.tidbdrv) + c.Assert(err, IsNil) + go server.Run() + time.Sleep(time.Millisecond * 100) + defer server.Close() + + runTestRegression(c, func(config *mysql.Config) { + config.User = "root" + config.Net = "unix" + config.Addr = "/tmp/tidbtest.sock" + config.DBName = "test" + config.Strict = true + }, "SocketRegression") +} + func (ts *TidbTestSuite) TestSocket(c *C) { cfg := config.NewConfig() cfg.Socket = "/tmp/tidbtest.sock" + cfg.Port = 0 + os.Remove(cfg.Socket) + cfg.Host = "" cfg.Status.ReportStatus = false server, err := NewServer(cfg, ts.tidbdrv) @@ -178,6 +203,7 @@ func (ts *TidbTestSuite) TestSocket(c *C) { config.DBName = "test" config.Strict = true }, "SocketRegression") + } // generateCert generates a private key and a certificate in PEM format based on parameters.