Skip to content

Commit

Permalink
Use bufio reader
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniomika committed Apr 26, 2024
1 parent 5f4bc7a commit c8003d4
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 68 deletions.
19 changes: 6 additions & 13 deletions httpmuxer/https.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (pL *proxyListener) Accept() (net.Conn, error) {
continue
}

tlsHello, buf, teeConn, peekErr := utils.PeekTLSHello(cl)
tlsHello, teeConn, peekErr := utils.PeekTLSHello(cl)
if peekErr != nil && tlsHello == nil {
return teeConn, nil
}
Expand Down Expand Up @@ -59,20 +59,20 @@ func (pL *proxyListener) Accept() (net.Conn, error) {
connectionLocation, err := balancer.NextServer()
if err != nil {
log.Println("Unable to load connection location:", err)
cl.Close()
teeConn.Close()
continue
}

host, err := base64.StdEncoding.DecodeString(connectionLocation.Host)
if err != nil {
log.Println("Unable to decode connection location:", err)
cl.Close()
teeConn.Close()
continue
}

hostAddr := string(host)

logLine := fmt.Sprintf("Accepted connection from %s -> %s", cl.RemoteAddr().String(), cl.LocalAddr().String())
logLine := fmt.Sprintf("Accepted connection from %s -> %s", teeConn.RemoteAddr().String(), teeConn.LocalAddr().String())
log.Println(logLine)

if viper.GetBool("log-to-client") {
Expand All @@ -94,18 +94,11 @@ func (pL *proxyListener) Accept() (net.Conn, error) {
conn, err := net.Dial("unix", hostAddr)
if err != nil {
log.Println("Error connecting to tcp balancer:", err)
cl.Close()
continue
}

_, err = conn.Write(buf.Bytes())
if err != nil {
log.Println("Unable to write to conn:", err)
cl.Close()
teeConn.Close()
continue
}

go utils.CopyBoth(conn, cl)
go utils.CopyBoth(conn, teeConn)
}
}

Expand Down
68 changes: 33 additions & 35 deletions utils/conn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package utils

import (
"bufio"
"bytes"
"crypto/tls"
"io"
Expand Down Expand Up @@ -88,44 +89,22 @@ func (s *SSHConnection) CleanUp(state *State) {

// TeeConn represents a simple net.Conn interface for SNI Processing.
type TeeConn struct {
Conn net.Conn
Reader io.Reader
Buffer *bytes.Buffer
FirstRead bool
Flushed bool
Conn net.Conn
Buffer *bufio.ReadWriter
}

// Read implements a reader ontop of the TeeReader.
func (conn *TeeConn) Read(p []byte) (int, error) {
if !conn.FirstRead {
conn.FirstRead = true
return conn.Reader.Read(p)
}

if conn.FirstRead && !conn.Flushed {
conn.Flushed = true
copy(p[0:conn.Buffer.Len()], conn.Buffer.Bytes())
return conn.Buffer.Len(), nil
}

return conn.Conn.Read(p)
return conn.Buffer.Read(p)
}

// Write is a shim function to fit net.Conn.
func (conn *TeeConn) Write(p []byte) (int, error) {
if !conn.Flushed {
return 0, io.ErrClosedPipe
}

return conn.Conn.Write(p)
return conn.Buffer.Write(p)
}

// Close is a shim function to fit net.Conn.
func (conn *TeeConn) Close() error {
if !conn.Flushed {
return nil
}

return conn.Conn.Close()
}

Expand All @@ -145,22 +124,19 @@ func (conn *TeeConn) SetReadDeadline(t time.Time) error { return conn.Conn.SetRe
func (conn *TeeConn) SetWriteDeadline(t time.Time) error { return conn.Conn.SetWriteDeadline(t) }

// GetBuffer returns the tee'd buffer.
func (conn *TeeConn) GetBuffer() *bytes.Buffer { return conn.Buffer }
func (conn *TeeConn) GetBuffer() *bufio.ReadWriter { return conn.Buffer }

func NewTeeConn(conn net.Conn) *TeeConn {
teeConn := &TeeConn{
Conn: conn,
Buffer: bytes.NewBuffer([]byte{}),
Flushed: false,
Conn: conn,
Buffer: bufio.NewReadWriter(bufio.NewReaderSize(conn, 8192), bufio.NewWriterSize(conn, 8192)),
}

teeConn.Reader = io.TeeReader(conn, teeConn.Buffer)

return teeConn
}

// PeekTLSHello peeks the TLS Connection Hello to proxy based on SNI.
func PeekTLSHello(conn net.Conn) (*tls.ClientHelloInfo, *bytes.Buffer, *TeeConn, error) {
func PeekTLSHello(conn net.Conn) (*tls.ClientHelloInfo, *TeeConn, error) {
var tlsHello *tls.ClientHelloInfo

tlsConfig := &tls.Config{
Expand All @@ -172,11 +148,33 @@ func PeekTLSHello(conn net.Conn) (*tls.ClientHelloInfo, *bytes.Buffer, *TeeConn,

teeConn := NewTeeConn(conn)

err := tls.Server(teeConn, tlsConfig).Handshake()
header, err := teeConn.GetBuffer().Peek(5)
if err != nil {
return tlsHello, teeConn, err
}

if header[0] != 0x16 {
return tlsHello, teeConn, err
}

helloBytes, err := teeConn.GetBuffer().Peek(len(header) + (int(header[3])<<8 | int(header[4])))
if err != nil {
return tlsHello, teeConn, err
}

err = tls.Server(bufConn{reader: bytes.NewReader(helloBytes)}, tlsConfig).Handshake()

return tlsHello, teeConn.GetBuffer(), teeConn, err
return tlsHello, teeConn, err
}

type bufConn struct {
reader io.Reader
net.Conn
}

func (b bufConn) Read(p []byte) (int, error) { return b.reader.Read(p) }
func (bufConn) Write(p []byte) (int, error) { return 0, io.EOF }

// IdleTimeoutConn handles the connection with a context deadline.
// code adapted from https://qiita.com/kwi/items/b38d6273624ad3f6ae79
type IdleTimeoutConn struct {
Expand Down
29 changes: 9 additions & 20 deletions utils/state.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package utils

import (
"bytes"
"encoding/base64"
"fmt"
"io"
Expand Down Expand Up @@ -94,19 +93,18 @@ func (tH *TCPHolder) Handle(state *State) {
continue
}

var firstWrite *bytes.Buffer
realConn := cl

balancerName := ""
if tH.SNIProxy {
tlsHello, buf, _, err := PeekTLSHello(cl)
tlsHello, realConn, err := PeekTLSHello(cl)
if err != nil && tlsHello == nil {
log.Printf("Unable to read TLS hello: %s", err)
cl.Close()
realConn.Close()
continue
}

balancerName = tlsHello.ServerName
firstWrite = buf
}

pB, ok := tH.Balancers.Load(balancerName)
Expand All @@ -121,7 +119,7 @@ func (tH *TCPHolder) Handle(state *State) {

if pB == nil {
log.Printf("Unable to load connection location: %s not found on TCP listener %s", balancerName, tH.TCPHost)
cl.Close()
realConn.Close()
continue
}
}
Expand All @@ -131,20 +129,20 @@ func (tH *TCPHolder) Handle(state *State) {
connectionLocation, err := balancer.NextServer()
if err != nil {
log.Println("Unable to load connection location:", err)
cl.Close()
realConn.Close()
continue
}

host, err := base64.StdEncoding.DecodeString(connectionLocation.Host)
if err != nil {
log.Println("Unable to decode connection location:", err)
cl.Close()
realConn.Close()
continue
}

hostAddr := string(host)

logLine := fmt.Sprintf("Accepted connection from %s -> %s", cl.RemoteAddr().String(), cl.LocalAddr().String())
logLine := fmt.Sprintf("Accepted connection from %s -> %s", realConn.RemoteAddr().String(), realConn.LocalAddr().String())
log.Println(logLine)

if viper.GetBool("log-to-client") {
Expand All @@ -166,20 +164,11 @@ func (tH *TCPHolder) Handle(state *State) {
conn, err := net.Dial("unix", hostAddr)
if err != nil {
log.Println("Error connecting to tcp balancer:", err)
cl.Close()
realConn.Close()
continue
}

if firstWrite != nil {
_, err := conn.Write(firstWrite.Bytes())
if err != nil {
log.Println("Unable to write to conn:", err)
cl.Close()
continue
}
}

go CopyBoth(conn, cl)
go CopyBoth(conn, realConn)
}
}

Expand Down

0 comments on commit c8003d4

Please sign in to comment.