diff --git a/tls.go b/tls.go index 38651a7..ea02484 100644 --- a/tls.go +++ b/tls.go @@ -41,31 +41,43 @@ import ( "golang.org/x/crypto/hkdf" ) -type WeakConn struct { +type MirrorConn struct { + *sync.Mutex net.Conn + Target net.Conn } -func (c *WeakConn) Read(b []byte) (int, error) { - return 0, fmt.Errorf("Read(%v)", len(b)) +func (c *MirrorConn) Read(b []byte) (int, error) { + c.Unlock() + runtime.Gosched() + n, err := c.Conn.Read(b) + c.Lock() // calling c.Lock() before c.Target.Write(), to make sure that this goroutine has the priority to make the next move + if n != 0 { + c.Target.Write(b[:n]) + } + if err != nil { + c.Target.Close() + } + return n, err } -func (c *WeakConn) Write(b []byte) (int, error) { +func (c *MirrorConn) Write(b []byte) (int, error) { return 0, fmt.Errorf("Write(%v)", len(b)) } -func (c *WeakConn) Close() error { +func (c *MirrorConn) Close() error { return fmt.Errorf("Close()") } -func (c *WeakConn) SetDeadline(t time.Time) error { +func (c *MirrorConn) SetDeadline(t time.Time) error { return nil } -func (c *WeakConn) SetReadDeadline(t time.Time) error { +func (c *MirrorConn) SetReadDeadline(t time.Time) error { return nil } -func (c *WeakConn) SetWriteDeadline(t time.Time) error { +func (c *MirrorConn) SetWriteDeadline(t time.Time) error { return nil } @@ -116,68 +128,33 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { underlying := conn if pc, ok := underlying.(*proxyproto.Conn); ok { - underlying = pc.Raw() + underlying = pc.Raw() // for TCP splicing in io.Copy() } - hs := serverHandshakeStateTLS13{ctx: context.Background()} + mutex := new(sync.Mutex) - c2sSaved := make([]byte, 0, size) - s2cSaved := make([]byte, 0, size) + hs := serverHandshakeStateTLS13{ + c: &Conn{ + conn: &MirrorConn{ + Mutex: mutex, + Conn: conn, + Target: target, + }, + config: config, + }, + ctx: context.Background(), + } copying := false - handled := false waitGroup := new(sync.WaitGroup) waitGroup.Add(2) - mutex := new(sync.Mutex) - go func() { - done := false - buf := make([]byte, size) - clientHelloLen := 0 for { - runtime.Gosched() - n, err := conn.Read(buf) - if n == 0 { - if err != nil { - target.Close() - waitGroup.Done() - return - } - continue - } mutex.Lock() - c2sSaved = append(c2sSaved, buf[:n]...) - if _, err = target.Write(buf[:n]); err != nil { - done = true - break - } - if len(c2sSaved) > size || copying { // too long; follow - break - } - if clientHelloLen == 0 && len(c2sSaved) > recordHeaderLen { - if recordType(c2sSaved[0]) != recordTypeHandshake || Value(c2sSaved[1:3]...) != VersionTLS10 || c2sSaved[5] != typeClientHello { - break - } - clientHelloLen = recordHeaderLen + Value(c2sSaved[3:5]...) - } - if clientHelloLen > size { // too long - break - } - if clientHelloLen == 0 || len(c2sSaved) < clientHelloLen { - mutex.Unlock() - continue - } - hs.c = &Conn{ - conn: &WeakConn{conn}, - config: config, - rawInput: *bytes.NewBuffer(c2sSaved), - } - if hs.clientHello, err = hs.c.readClientHello(context.Background()); err != nil { - break - } - if hs.c.vers != VersionTLS13 || !config.ServerNames[hs.clientHello.serverName] { + hs.clientHello, err = hs.c.readClientHello(context.Background()) // TODO: Change some rules in this function. + if copying || err != nil || hs.c.vers != VersionTLS13 || !config.ServerNames[hs.clientHello.serverName] { break } for i, keyShare := range hs.clientHello.keyShares { @@ -228,20 +205,17 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { if config.Show { fmt.Printf("REALITY remoteAddr: %v\ths.c.conn == conn: %v\n", remoteAddr, hs.c.conn == conn) } - if hs.c.conn == conn { - done = true - } break } mutex.Unlock() - if !done { - io.CopyBuffer(target, underlying, buf) + if hs.c.conn != conn { + io.Copy(target, underlying) } waitGroup.Done() }() go func() { - done := false + s2cSaved := make([]byte, 0, size) buf := make([]byte, size) handshakeLen := 0 f: @@ -258,14 +232,10 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { } mutex.Lock() s2cSaved = append(s2cSaved, buf[:n]...) - if hs.c == nil || hs.c.conn != conn { - copying = true - if _, err = conn.Write(buf[:n]); err != nil { - done = true - } + if hs.c.conn != conn { + copying = true // if the target already sent some data, just start bidirectional direct forwarding break } - done = true // special if len(s2cSaved) > size { break } @@ -349,12 +319,19 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { break } atomic.StoreUint32(&hs.c.handshakeStatus, 1) - handled = true break } mutex.Unlock() - if !done { - io.CopyBuffer(underlying, target, buf) + if hs.c.out.handshakeLen[0] == 0 { // if the target sent an incorrect Server Hello, or before that + if hs.c.conn == conn { // if we processed the Client Hello successfully but the target did not + waitGroup.Add(1) + go func() { + io.Copy(target, underlying) + waitGroup.Done() + }() + } + conn.Write(s2cSaved) + io.Copy(underlying, target) } waitGroup.Done() }() @@ -362,13 +339,13 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { waitGroup.Wait() target.Close() if config.Show { - fmt.Printf("REALITY remoteAddr: %v\thandled: %v\n", remoteAddr, handled) + fmt.Printf("REALITY remoteAddr: %v\ths.c.handshakeStatus: %v\n", remoteAddr, atomic.LoadUint32(&hs.c.handshakeStatus)) } - if handled { + if atomic.LoadUint32(&hs.c.handshakeStatus) == 1 { return hs.c, nil } conn.Close() - return nil, errors.New("REALITY: processed invalid connection") + return nil, errors.New("REALITY: processed invalid connection") // TODO: Add details. /* c := &Conn{