Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tls: fix flaky TestInvalidCerts on Windows #1560

Merged
merged 3 commits into from
Jun 1, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions p2p/security/tls/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"math/big"
mrand "math/rand"
"net"
"runtime"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -47,18 +49,36 @@ func createPeer(t *testing.T) (peer.ID, ic.PrivKey) {
}

func connect(t *testing.T) (net.Conn, net.Conn) {
ln, err := net.Listen("tcp", "localhost:0")
ln, err := net.ListenTCP("tcp", nil)
require.NoError(t, err)
defer ln.Close()
serverConnChan := make(chan net.Conn)
serverConnChan := make(chan *net.TCPConn)
go func() {
conn, err := ln.Accept()
assert.NoError(t, err)
serverConnChan <- conn
sconn := conn.(*net.TCPConn)
serverConnChan <- sconn
}()
conn, err := net.Dial("tcp", ln.Addr().String())
conn, err := net.DialTCP("tcp", nil, ln.Addr().(*net.TCPAddr))
require.NoError(t, err)
return conn, <-serverConnChan
sconn := <-serverConnChan
// On Windows we have to set linger to 0, otherwise we'll occasionally run into errors like the following:
// "connectex: Only one usage of each socket address (protocol/network address/port) is normally permitted."
// See https://github.com/libp2p/go-libp2p/issues/1529.
conn.SetLinger(0)
sconn.SetLinger(0)
t.Cleanup(func() {
conn.Close()
sconn.Close()
})
return conn, sconn
}

func isWindowsTCPCloseError(err error) bool {
if runtime.GOOS != "windows" {
return false
}
return strings.Contains(err.Error(), "wsarecv: An existing connection was forcibly closed by the remote host")
}

func TestHandshakeSucceeds(t *testing.T) {
Expand Down Expand Up @@ -482,26 +502,25 @@ func TestInvalidCerts(t *testing.T) {
_, err := conn.Read([]byte{0})
clientErrChan <- err
}()
var clientErr error
select {
case clientErr = <-clientErrChan:
case err := <-clientErrChan:
require.Error(t, err)
if err.Error() != "remote error: tls: error decrypting message" &&
err.Error() != "remote error: tls: bad certificate" &&
!isWindowsTCPCloseError(err) {
t.Errorf("unexpected error: %s", err.Error())
}
case <-time.After(250 * time.Millisecond):
t.Fatal("expected the server handshake to return")
}
require.Error(t, clientErr)
if clientErr.Error() != "remote error: tls: error decrypting message" &&
clientErr.Error() != "remote error: tls: bad certificate" {
t.Fatalf("unexpected error: %s", err.Error())
}

var serverErr error
select {
case serverErr = <-serverErrChan:
case err := <-serverErrChan:
require.Error(t, err)
tr.checkErr(t, err)
case <-time.After(250 * time.Millisecond):
t.Fatal("expected the server handshake to return")
}
require.Error(t, serverErr)
tr.checkErr(t, serverErr)
})

t.Run(fmt.Sprintf("server offending: %s", tr.name), func(t *testing.T) {
Expand Down Expand Up @@ -530,7 +549,9 @@ func TestInvalidCerts(t *testing.T) {
t.Fatal("expected the server handshake to return")
}
require.Error(t, serverErr)
require.Contains(t, serverErr.Error(), "remote error: tls:")
if !isWindowsTCPCloseError(serverErr) {
require.Contains(t, serverErr.Error(), "remote error: tls:")
}
})
}
}