diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index fd5a68fbc..89c471910 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -747,6 +747,14 @@ type socketMount struct { dialOpts []cloudsqlconn.DialOption } +func networkType(conf *Config, inst InstanceConnConfig) string { + if (conf.UnixSocket == "" && inst.UnixSocket == "" && inst.UnixSocketPath == "") || + (inst.Addr != "" || inst.Port != 0) { + return "tcp" + } + return "unix" +} + func (c *Client) newSocketMount(ctx context.Context, conf *Config, pc *portConfig, inst InstanceConnConfig) (*socketMount, error) { var ( // network is one of "tcp" or "unix" @@ -765,8 +773,7 @@ func (c *Client) newSocketMount(ctx context.Context, conf *Config, pc *portConfi // instance) // use a TCP listener. // Otherwise, use a Unix socket. - if (conf.UnixSocket == "" && inst.UnixSocket == "" && inst.UnixSocketPath == "") || - (inst.Addr != "" || inst.Port != 0) { + if networkType(conf, inst) == "tcp" { network = "tcp" a := conf.Addr @@ -782,8 +789,9 @@ func (c *Client) newSocketMount(ctx context.Context, conf *Config, pc *portConfi np = pc.nextPort() default: version, err := c.dialer.EngineVersion(ctx, inst.Name) + // Exit if the port is not specified for inactive instance if err != nil { - c.logger.Errorf("could not resolve version for %q: %v", inst.Name, err) + c.logger.Errorf("[%v] could not resolve instance version: %v", inst.Name, err) return nil, err } np = pc.nextDBPort(version) @@ -795,12 +803,13 @@ func (c *Client) newSocketMount(ctx context.Context, conf *Config, pc *portConfi version, err := c.dialer.EngineVersion(ctx, inst.Name) if err != nil { - c.logger.Errorf("could not resolve version for %q: %v", inst.Name, err) + c.logger.Errorf("[%v] could not resolve instance version: %v", inst.Name, err) return nil, err } address, err = newUnixSocketMount(inst, conf.UnixSocket, strings.HasPrefix(version, "POSTGRES")) if err != nil { + c.logger.Errorf("[%v] could not mount unix socket %q: %v", inst.Name, conf.UnixSocket, err) return nil, err } } @@ -808,6 +817,7 @@ func (c *Client) newSocketMount(ctx context.Context, conf *Config, pc *portConfi lc := net.ListenConfig{KeepAlive: 30 * time.Second} ln, err := lc.Listen(ctx, network, address) if err != nil { + c.logger.Errorf("[%v] could not listen to address %v: %v", inst.Name, address, err) return nil, err } // Change file permissions to allow access for user, group, and other. diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 8d509295e..807a27264 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -17,6 +17,7 @@ package proxy_test import ( "context" "errors" + "fmt" "io" "net" "os" @@ -83,6 +84,8 @@ func (f *fakeDialer) EngineVersion(_ context.Context, inst string) (string, erro return "MYSQL_8_0", nil case strings.Contains(inst, "sqlserver"): return "SQLSERVER_2019_STANDARD", nil + case strings.Contains(inst, "fakeserver"): + return "", fmt.Errorf("non existing server") default: return "POSTGRES_14", nil } @@ -306,6 +309,17 @@ func TestClientInitialization(t *testing.T) { filepath.Join(testUnixSocketPathPg), }, }, + { + desc: "with TCP port for non functional instance", + in: &proxy.Config{ + Instances: []proxy.InstanceConnConfig{ + {Name: "proj:region:fakeserver", Port: 50000}, + }, + }, + wantTCPAddrs: []string{ + "127.0.0.1:50000", + }, + }, } for _, tc := range tcs { @@ -711,3 +725,86 @@ func TestRunConnectionCheck(t *testing.T) { } } + +func TestProxyInitializationWithFailedUnixSocket(t *testing.T) { + ctx := context.Background() + testDir, _ := createTempDir(t) + testUnixSocketPath := path.Join(testDir, "db") + + tcs := []struct { + desc string + in *proxy.Config + }{ + { + desc: "with unix socket for non functional instance", + in: &proxy.Config{ + Instances: []proxy.InstanceConnConfig{ + { + Name: "proj:region:fakeserver", + UnixSocketPath: testUnixSocketPath, + }, + }, + }, + }, + { + desc: "without TCP port or unix socket for non functional instance", + in: &proxy.Config{ + Instances: []proxy.InstanceConnConfig{ + {Name: "proj:region:fakeserver"}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + _, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in) + if err == nil { + t.Fatalf("want non nil error, got = %v", err) + } + }) + } +} + +func TestProxyMultiInstances(t *testing.T) { + ctx := context.Background() + testDir, _ := createTempDir(t) + testUnixSocketPath := path.Join(testDir, "db") + + tcs := []struct { + desc string + in *proxy.Config + wantSuccess bool + }{ + { + desc: "with tcp socket and unix for non functional instance", + in: &proxy.Config{ + Instances: []proxy.InstanceConnConfig{ + { + Name: "proj:region:fakeserver", + UnixSocketPath: testUnixSocketPath, + }, + {Name: mysql, Port: 3306}, + }, + }, + wantSuccess: false, + }, + { + desc: "with two tcp socket instances and conflicting ports", + in: &proxy.Config{ + Instances: []proxy.InstanceConnConfig{ + {Name: "proj:region:fakeserver", Port: 60000}, + {Name: mysql, Port: 60000}, + }, + }, + wantSuccess: false, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + _, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in) + if tc.wantSuccess != (err == nil) { + t.Fatalf("want return = %v, got = %v", tc.wantSuccess, err == nil) + } + }) + } +}