Skip to content

Commit

Permalink
Fix DTLS client role in long delay connections
Browse files Browse the repository at this point in the history
Fixes pion/webrtc#2089

When a retranmission from the remote side arrives after
the handshake is complete, the `finish` routine
puts it back into retransmit loop.

With Chrome, this fails after 15 seconds.
Firefox does not error out though.

Testing:
---------
- Tested with Firefox and Chrome with long delay
(500 ms up and down) in network link conditioner.
- Tested the above with no introduced delays too.
- Added test for slow server.
  • Loading branch information
boks1971 committed Feb 5, 2022
1 parent 17f86a3 commit 748c25c
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 23 deletions.
1 change: 1 addition & 0 deletions AUTHORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Arlo Breault <[email protected]>
Atsushi Watanabe <[email protected]>
backkem <[email protected]>
bjdgyc <[email protected]>
boks1971 <[email protected]>
Bragadeesh <[email protected]>
Carson Hoffman <[email protected]>
Cecylia Bocovich <[email protected]>
Expand Down
3 changes: 3 additions & 0 deletions handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,9 @@ func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState
if nextFlight == 0 {
break
}
if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight {
return handshakeFinished, nil
}
<-retransmitTimer.C
// Retransmit last flight
return handshakeSending, nil
Expand Down
204 changes: 181 additions & 23 deletions handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,20 @@ func TestHandshaker(t *testing.T) {
t.Fatal(err)
}

genFilters := map[string]func() (packetFilter, packetFilter, func(t *testing.T)){
"PassThrough": func() (packetFilter, packetFilter, func(t *testing.T)) {
return nil, nil, nil
genFilters := map[string]func() (TestEndpoint, TestEndpoint, func(t *testing.T)){
"PassThrough": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) {
return TestEndpoint{}, TestEndpoint{}, nil
},
"HelloVerifyRequestLost": func() (packetFilter, packetFilter, func(t *testing.T)) {

"HelloVerifyRequestLost": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) {
var (
cntHelloVerifyRequest = 0
cntClientHelloNoCookie = 0
)
const helloVerifyDrop = 5
return func(p *packet) bool {

clientEndpoint := TestEndpoint{
Filter: func(p *packet) bool {
h, ok := p.record.Content.(*handshake.Handshake)
if !ok {
return true
Expand All @@ -79,7 +82,10 @@ func TestHandshaker(t *testing.T) {
}
return true
},
func(p *packet) bool {
}

serverEndpoint := TestEndpoint{
Filter: func(p *packet) bool {
h, ok := p.record.Content.(*handshake.Handshake)
if !ok {
return true
Expand All @@ -90,31 +96,161 @@ func TestHandshaker(t *testing.T) {
}
return true
},
func(t *testing.T) {
if cntHelloVerifyRequest != helloVerifyDrop+1 {
t.Errorf("Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", helloVerifyDrop+1, cntHelloVerifyRequest)
}

report := func(t *testing.T) {
if cntHelloVerifyRequest != helloVerifyDrop+1 {
t.Errorf("Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", helloVerifyDrop+1, cntHelloVerifyRequest)
}
if cntClientHelloNoCookie != cntHelloVerifyRequest {
t.Errorf(
"HelloVerifyRequest must be triggered only by ClientHello, but HelloVerifyRequest was sent %d times and ClientHello was sent %d times",
cntHelloVerifyRequest, cntClientHelloNoCookie,
)
}
}

return clientEndpoint, serverEndpoint, report
},

"NoLatencyTest": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) {
var (
cntClientFinished = 0
cntServerFinished = 0
)

clientEndpoint := TestEndpoint{
Filter: func(p *packet) bool {
h, ok := p.record.Content.(*handshake.Handshake)
if !ok {
return true
}
if _, ok := h.Message.(*handshake.MessageFinished); ok {
cntClientFinished++
}
return true
},
}

serverEndpoint := TestEndpoint{
Filter: func(p *packet) bool {
h, ok := p.record.Content.(*handshake.Handshake)
if !ok {
return true
}
if _, ok := h.Message.(*handshake.MessageFinished); ok {
cntServerFinished++
}
return true
},
}

report := func(t *testing.T) {
if cntClientFinished != 1 {
t.Errorf("Number of client finished is wrong, expected: %d times, got: %d times", 1, cntClientFinished)
}
if cntServerFinished != 1 {
t.Errorf("Number of server finished is wrong, expected: %d times, got: %d times", 1, cntServerFinished)
}
}

return clientEndpoint, serverEndpoint, report
},

"SlowServerTest": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) {
var (
cntClientFinished = 0
isClientFinished = false
cntClientFinishedLastRetransmit = 0
cntServerFinished = 0
isServerFinished = false
cntServerFinishedLastRetransmit = 0
)

clientEndpoint := TestEndpoint{
Filter: func(p *packet) bool {
h, ok := p.record.Content.(*handshake.Handshake)
if !ok {
return true
}
if _, ok := h.Message.(*handshake.MessageFinished); ok {
if isClientFinished {
cntClientFinishedLastRetransmit++
} else {
cntClientFinished++
}
}
return true
},
Delay: 0,
OnFinished: func() {
isClientFinished = true
},
FinishWait: 2000 * time.Millisecond,
}

serverEndpoint := TestEndpoint{
Filter: func(p *packet) bool {
h, ok := p.record.Content.(*handshake.Handshake)
if !ok {
return true
}
if cntClientHelloNoCookie != cntHelloVerifyRequest {
t.Errorf(
"HelloVerifyRequest must be triggered only by ClientHello, but HelloVerifyRequest was sent %d times and ClientHello was sent %d times",
cntHelloVerifyRequest, cntClientHelloNoCookie,
)
if _, ok := h.Message.(*handshake.MessageFinished); ok {
if isServerFinished {
cntServerFinishedLastRetransmit++
} else {
cntServerFinished++
}
}
return true
},
Delay: 1000 * time.Millisecond,
OnFinished: func() {
isServerFinished = true
},
FinishWait: 2000 * time.Millisecond,
}

report := func(t *testing.T) {
// with one second server delay and 100 ms retransmit, there should be close to 10 `Finished` from client
// using a range of 9 - 11 for checking
if cntClientFinished < 8 || cntClientFinished > 11 {
t.Errorf("Number of client finished is wrong, expected: %d - %d times, got: %d times", 9, 11, cntClientFinished)
}
if !isClientFinished {
t.Errorf("Client is not finished")
}
// there should be no `Finished` last retransmit from client
if cntClientFinishedLastRetransmit != 0 {
t.Errorf("Number of client finished last retransmit is wrong, expected: %d times, got: %d times", 0, cntClientFinishedLastRetransmit)
}
if cntServerFinished < 1 {
t.Errorf("Number of server finished is wrong, expected: at least %d times, got: %d times", 1, cntServerFinished)
}
if !isServerFinished {
t.Errorf("Server is not finished")
}
// there should be `Finished` last retransmit from server. Because of slow server, client would have sent several `Finished`.
if cntServerFinishedLastRetransmit < 1 {
t.Errorf("Number of server finished last retransmit is wrong, expected: at least %d times, got: %d times", 1, cntServerFinishedLastRetransmit)
}
}

return clientEndpoint, serverEndpoint, report
},
}

for name, filters := range genFilters {
f1, f2, report := filters()
clientEndpoint, serverEndpoint, report := filters()
t.Run(name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()

if report != nil {
defer report(t)
}

ca, cb := flightTestPipe(ctx, f1, f2)
ca, cb := flightTestPipe(ctx, clientEndpoint, serverEndpoint)
ca.state.isClient = true

var wg sync.WaitGroup
Expand All @@ -132,7 +268,12 @@ func TestHandshaker(t *testing.T) {
log: logger,
onFlightState: func(f flightVal, s handshakeState) {
if s == handshakeFinished {
cancelCli()
if clientEndpoint.OnFinished != nil {
clientEndpoint.OnFinished()
}
time.AfterFunc(clientEndpoint.FinishWait, func() {
cancelCli()
})
}
},
retransmitInterval: nonZeroRetransmitInterval,
Expand All @@ -158,7 +299,12 @@ func TestHandshaker(t *testing.T) {
log: logger,
onFlightState: func(f flightVal, s handshakeState) {
if s == handshakeFinished {
cancelSrv()
if serverEndpoint.OnFinished != nil {
serverEndpoint.OnFinished()
}
time.AfterFunc(serverEndpoint.FinishWait, func() {
cancelSrv()
})
}
},
retransmitInterval: nonZeroRetransmitInterval,
Expand All @@ -183,9 +329,16 @@ func TestHandshaker(t *testing.T) {
}
}

type packetFilter func(*packet) bool
type packetFilter func(p *packet) bool

type TestEndpoint struct {
Filter packetFilter
Delay time.Duration
OnFinished func()
FinishWait time.Duration
}

func flightTestPipe(ctx context.Context, filter1 packetFilter, filter2 packetFilter) (*flightTestConn, *flightTestConn) {
func flightTestPipe(ctx context.Context, clientEndpoint TestEndpoint, serverEndpoint TestEndpoint) (*flightTestConn, *flightTestConn) {
ca := newHandshakeCache()
cb := newHandshakeCache()
chA := make(chan chan struct{})
Expand All @@ -196,14 +349,16 @@ func flightTestPipe(ctx context.Context, filter1 packetFilter, filter2 packetFil
recv: chA,
otherEndRecv: chB,
done: ctx.Done(),
filter: filter1,
filter: clientEndpoint.Filter,
delay: clientEndpoint.Delay,
}, &flightTestConn{
handshakeCache: cb,
otherEndCache: ca,
recv: chB,
otherEndRecv: chA,
done: ctx.Done(),
filter: filter2,
filter: serverEndpoint.Filter,
delay: serverEndpoint.Delay,
}
}

Expand All @@ -216,6 +371,8 @@ type flightTestConn struct {

filter packetFilter

delay time.Duration

otherEndCache *handshakeCache
otherEndRecv chan chan struct{}
}
Expand All @@ -233,6 +390,7 @@ func (c *flightTestConn) notify(ctx context.Context, level alert.Level, desc ale
}

func (c *flightTestConn) writePackets(ctx context.Context, pkts []*packet) error {
time.Sleep(c.delay)
for _, p := range pkts {
if c.filter != nil && !c.filter(p) {
continue
Expand Down

0 comments on commit 748c25c

Please sign in to comment.