diff --git a/relay.go b/relay.go index 9f32da4..1df5183 100644 --- a/relay.go +++ b/relay.go @@ -59,6 +59,9 @@ type Relay struct { m sync.Mutex protocolSwitching map[int]string + + // connPool contains a list of ACTIVE connections + connPool []net.Conn } const ( @@ -79,7 +82,7 @@ const ( // VERSION uses semantic versioning // this version number is for the library not the CLI - VERSION = "v1.3.4" + VERSION = "v1.4.0" ) var ( @@ -284,3 +287,26 @@ func (r *Relay) Serve(l net.Listener) error { return ErrUnknownProxyType } } + +// storeConn places the provided net.Conn into the connPoll. +// To remove this conn from the pool, provide it to popConn() +func (r *Relay) storeConn(conn net.Conn) { + r.m.Lock() + defer r.m.Unlock() + + r.connPool = append(r.connPool, conn) +} + +// popConn removes the provided connection from the conn pool +func (r *Relay) popConn(conn net.Conn) { + r.m.Lock() + defer r.m.Unlock() + + for i := 0; i < len(r.connPool); i++ { + if r.connPool[i] == conn { + // remove conn + r.connPool = append(r.connPool[:i], r.connPool[i+1:]...) + return + } + } +} diff --git a/relay_test.go b/relay_test.go new file mode 100644 index 0000000..67165b5 --- /dev/null +++ b/relay_test.go @@ -0,0 +1,113 @@ +package localrelay + +import ( + "io" + "net" + "sync" + "testing" + "time" +) + +func TestConnPoolBasic(t *testing.T) { + conns := []net.Conn{} + connAmount := 50 + relay := New("test-relay", "127.0.0.1:23838", "127.0.0.1:23838", io.Discard) + + for i := 0; i < connAmount; i++ { + conn := &net.TCPConn{} + + conns = append(conns, conn) + relay.storeConn(conn) + } + + for i := 0; i < connAmount; i++ { + relay.popConn(conns[i]) + } + + if len(relay.connPool) != 0 { + t.Fatal("connPool is not empty") + } +} + +func TestConnPool(t *testing.T) { + // create channel to receive errors from another goroutine + errCh := make(chan error) + go startTCPServer(errCh) + + // wait for error or nil error indicating server launched fine + if err := <-errCh; err != nil { + t.Fatal(err) + } + + relay := New("test-relay", "127.0.0.1:23838", "127.0.0.1:23838", io.Discard) + + wg := sync.WaitGroup{} + + // open 10 conns and append to the conn pool + for i := 0; i < 10; i++ { + wg.Add(1) + + conn, err := net.Dial("tcp", "127.0.0.1:23838") + if err != nil { + t.Fatal(err) + } + + relay.storeConn(conn) + + // handle conn + go func(conn net.Conn, i int) { + for { + time.Sleep(time.Millisecond * (10 * time.Duration(i))) + _, err := conn.Write([]byte("test")) + if err != nil { + relay.popConn(conn) + + for _, c := range relay.connPool { + if c == conn { + t.Fatal("correct conn was not removed") + } + } + + wg.Done() + return + } + } + }(conn, i) + } + + wg.Wait() +} + +func startTCPServer(errCh chan error) { + l, err := net.Listen("tcp", ":23838") + if err != nil { + errCh <- err + return + } + + errCh <- nil + + for { + conn, err := l.Accept() + if err != nil { + continue + } + + // handle conn with echo server + go func(conn net.Conn) { + for i := 0; i <= 5; i++ { + buf := make([]byte, 1048) + n, err := conn.Read(buf) + if err != nil { + conn.Close() + return + } + + conn.Write(buf[:n]) + } + + // close conn after 5 messages + conn.Close() + }(conn) + } +} diff --git a/relayfailovertcp.go b/relayfailovertcp.go index 451ed77..794b021 100644 --- a/relayfailovertcp.go +++ b/relayfailovertcp.go @@ -35,8 +35,14 @@ func relayFailOverTCP(r *Relay, l net.Listener) error { } func handleFailOver(r *Relay, conn net.Conn, network string) { + r.storeConn(conn) + defer func() { conn.Close() + + // remove conn from connPool + r.popConn(conn) + r.Metrics.connections(-1) }() diff --git a/relayhttp.go b/relayhttp.go index e5920be..1880190 100644 --- a/relayhttp.go +++ b/relayhttp.go @@ -25,7 +25,6 @@ func HandleHTTP(relay *Relay) http.HandlerFunc { } func handleHTTP(w http.ResponseWriter, r *http.Request, re *Relay) { - re.Metrics.requests(1) remoteURL := re.ForwardAddr + r.URL.Path + "?" + r.URL.Query().Encode() diff --git a/relaytcp.go b/relaytcp.go index 6344cc8..3296570 100644 --- a/relaytcp.go +++ b/relaytcp.go @@ -45,8 +45,14 @@ func relayTCP(r *Relay, l net.Listener) error { } func handleConn(r *Relay, conn net.Conn, network string) { + r.storeConn(conn) + defer func() { conn.Close() + + // remove conn from connPool + r.popConn(conn) + r.Metrics.connections(-1) }()