Skip to content

Commit

Permalink
Allow fragmented REALITY Client Hello & Simplify logic
Browse files Browse the repository at this point in the history
It's mainly for defending against certain attacks.
  • Loading branch information
RPRX committed Aug 28, 2023
1 parent e07c3b0 commit e426190
Showing 1 changed file with 53 additions and 76 deletions.
129 changes: 53 additions & 76 deletions tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,31 +41,43 @@ import (
"golang.org/x/crypto/hkdf"
)

type WeakConn struct {
type MirrorConn struct {
*sync.Mutex
net.Conn
Target net.Conn
}

func (c *WeakConn) Read(b []byte) (int, error) {
return 0, fmt.Errorf("Read(%v)", len(b))
func (c *MirrorConn) Read(b []byte) (int, error) {
c.Unlock()
runtime.Gosched()
n, err := c.Conn.Read(b)
c.Lock() // calling c.Lock() before c.Target.Write(), to make sure that this goroutine has the priority to make the next move
if n != 0 {
c.Target.Write(b[:n])
}
if err != nil {
c.Target.Close()
}
return n, err
}

func (c *WeakConn) Write(b []byte) (int, error) {
func (c *MirrorConn) Write(b []byte) (int, error) {
return 0, fmt.Errorf("Write(%v)", len(b))
}

func (c *WeakConn) Close() error {
func (c *MirrorConn) Close() error {
return fmt.Errorf("Close()")
}

func (c *WeakConn) SetDeadline(t time.Time) error {
func (c *MirrorConn) SetDeadline(t time.Time) error {
return nil
}

func (c *WeakConn) SetReadDeadline(t time.Time) error {
func (c *MirrorConn) SetReadDeadline(t time.Time) error {
return nil
}

func (c *WeakConn) SetWriteDeadline(t time.Time) error {
func (c *MirrorConn) SetWriteDeadline(t time.Time) error {
return nil
}

Expand Down Expand Up @@ -116,68 +128,33 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {

underlying := conn
if pc, ok := underlying.(*proxyproto.Conn); ok {
underlying = pc.Raw()
underlying = pc.Raw() // for TCP splicing in io.Copy()
}

hs := serverHandshakeStateTLS13{ctx: context.Background()}
mutex := new(sync.Mutex)

c2sSaved := make([]byte, 0, size)
s2cSaved := make([]byte, 0, size)
hs := serverHandshakeStateTLS13{
c: &Conn{
conn: &MirrorConn{
Mutex: mutex,
Conn: conn,
Target: target,
},
config: config,
},
ctx: context.Background(),
}

copying := false
handled := false

waitGroup := new(sync.WaitGroup)
waitGroup.Add(2)

mutex := new(sync.Mutex)

go func() {
done := false
buf := make([]byte, size)
clientHelloLen := 0
for {
runtime.Gosched()
n, err := conn.Read(buf)
if n == 0 {
if err != nil {
target.Close()
waitGroup.Done()
return
}
continue
}
mutex.Lock()
c2sSaved = append(c2sSaved, buf[:n]...)
if _, err = target.Write(buf[:n]); err != nil {
done = true
break
}
if len(c2sSaved) > size || copying { // too long; follow
break
}
if clientHelloLen == 0 && len(c2sSaved) > recordHeaderLen {
if recordType(c2sSaved[0]) != recordTypeHandshake || Value(c2sSaved[1:3]...) != VersionTLS10 || c2sSaved[5] != typeClientHello {
break
}
clientHelloLen = recordHeaderLen + Value(c2sSaved[3:5]...)
}
if clientHelloLen > size { // too long
break
}
if clientHelloLen == 0 || len(c2sSaved) < clientHelloLen {
mutex.Unlock()
continue
}
hs.c = &Conn{
conn: &WeakConn{conn},
config: config,
rawInput: *bytes.NewBuffer(c2sSaved),
}
if hs.clientHello, err = hs.c.readClientHello(context.Background()); err != nil {
break
}
if hs.c.vers != VersionTLS13 || !config.ServerNames[hs.clientHello.serverName] {
hs.clientHello, err = hs.c.readClientHello(context.Background()) // TODO: Change some rules in this function.
if copying || err != nil || hs.c.vers != VersionTLS13 || !config.ServerNames[hs.clientHello.serverName] {
break
}
for i, keyShare := range hs.clientHello.keyShares {
Expand Down Expand Up @@ -228,20 +205,17 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
if config.Show {
fmt.Printf("REALITY remoteAddr: %v\ths.c.conn == conn: %v\n", remoteAddr, hs.c.conn == conn)
}
if hs.c.conn == conn {
done = true
}
break
}
mutex.Unlock()
if !done {
io.CopyBuffer(target, underlying, buf)
if hs.c.conn != conn {
io.Copy(target, underlying)
}
waitGroup.Done()
}()

go func() {
done := false
s2cSaved := make([]byte, 0, size)
buf := make([]byte, size)
handshakeLen := 0
f:
Expand All @@ -258,14 +232,10 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
}
mutex.Lock()
s2cSaved = append(s2cSaved, buf[:n]...)
if hs.c == nil || hs.c.conn != conn {
copying = true
if _, err = conn.Write(buf[:n]); err != nil {
done = true
}
if hs.c.conn != conn {
copying = true // if the target already sent some data, just start bidirectional direct forwarding
break
}
done = true // special
if len(s2cSaved) > size {
break
}
Expand Down Expand Up @@ -349,26 +319,33 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
break
}
atomic.StoreUint32(&hs.c.handshakeStatus, 1)
handled = true
break
}
mutex.Unlock()
if !done {
io.CopyBuffer(underlying, target, buf)
if hs.c.out.handshakeLen[0] == 0 { // if the target sent an incorrect Server Hello, or before that
if hs.c.conn == conn { // if we processed the Client Hello successfully but the target did not
waitGroup.Add(1)
go func() {
io.Copy(target, underlying)
waitGroup.Done()
}()
}
conn.Write(s2cSaved)
io.Copy(underlying, target)
}
waitGroup.Done()
}()

waitGroup.Wait()
target.Close()
if config.Show {
fmt.Printf("REALITY remoteAddr: %v\thandled: %v\n", remoteAddr, handled)
fmt.Printf("REALITY remoteAddr: %v\ths.c.handshakeStatus: %v\n", remoteAddr, atomic.LoadUint32(&hs.c.handshakeStatus))
}
if handled {
if atomic.LoadUint32(&hs.c.handshakeStatus) == 1 {
return hs.c, nil
}
conn.Close()
return nil, errors.New("REALITY: processed invalid connection")
return nil, errors.New("REALITY: processed invalid connection") // TODO: Add details.

/*
c := &Conn{
Expand Down

0 comments on commit e426190

Please sign in to comment.