From 748c25cf9a68e2f57febc4fec1a7f38763bc8fe1 Mon Sep 17 00:00:00 2001 From: boks1971 Date: Wed, 2 Feb 2022 19:15:44 +0530 Subject: [PATCH] Fix DTLS client role in long delay connections Fixes https://github.com/pion/webrtc/issues/2089 When a retransmission from the remote side arrives after the handshake is complete, the `finish` routine puts it back into the retransmit loop. With Chrome, this fails after 15 seconds. Firefox does not error out though. Testing: --------- - Tested with Firefox and Chrome with long delay (500 ms up and down) in network link conditioner. - Tested the above with no introduced delays too. - Added test for slow server. --- AUTHORS.txt | 1 + handshaker.go | 3 + handshaker_test.go | 204 ++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 185 insertions(+), 23 deletions(-) diff --git a/AUTHORS.txt b/AUTHORS.txt index 4b4e1d803..a8e6fb46f 100644 --- a/AUTHORS.txt +++ b/AUTHORS.txt @@ -9,6 +9,7 @@ Arlo Breault Atsushi Watanabe backkem bjdgyc +boks1971 Bragadeesh Carson Hoffman Cecylia Bocovich diff --git a/handshaker.go b/handshaker.go index f2cedfa97..1c7b9ffa2 100644 --- a/handshaker.go +++ b/handshaker.go @@ -326,6 +326,9 @@ func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState if nextFlight == 0 { break } + if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight { + return handshakeFinished, nil + } <-retransmitTimer.C // Retransmit last flight return handshakeSending, nil diff --git a/handshaker_test.go b/handshaker_test.go index 382b27b42..37bda9037 100644 --- a/handshaker_test.go +++ b/handshaker_test.go @@ -57,17 +57,20 @@ func TestHandshaker(t *testing.T) { t.Fatal(err) } - genFilters := map[string]func() (packetFilter, packetFilter, func(t *testing.T)){ - "PassThrough": func() (packetFilter, packetFilter, func(t *testing.T)) { - return nil, nil, nil + genFilters := map[string]func() (TestEndpoint, TestEndpoint, func(t *testing.T)){ + "PassThrough": func() (TestEndpoint, TestEndpoint, func(t 
*testing.T)) { + return TestEndpoint{}, TestEndpoint{}, nil }, - "HelloVerifyRequestLost": func() (packetFilter, packetFilter, func(t *testing.T)) { + + "HelloVerifyRequestLost": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) { var ( cntHelloVerifyRequest = 0 cntClientHelloNoCookie = 0 ) const helloVerifyDrop = 5 - return func(p *packet) bool { + + clientEndpoint := TestEndpoint{ + Filter: func(p *packet) bool { h, ok := p.record.Content.(*handshake.Handshake) if !ok { return true @@ -79,7 +82,10 @@ func TestHandshaker(t *testing.T) { } return true }, - func(p *packet) bool { + } + + serverEndpoint := TestEndpoint{ + Filter: func(p *packet) bool { h, ok := p.record.Content.(*handshake.Handshake) if !ok { return true @@ -90,31 +96,161 @@ func TestHandshaker(t *testing.T) { } return true }, - func(t *testing.T) { - if cntHelloVerifyRequest != helloVerifyDrop+1 { - t.Errorf("Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", helloVerifyDrop+1, cntHelloVerifyRequest) + } + + report := func(t *testing.T) { + if cntHelloVerifyRequest != helloVerifyDrop+1 { + t.Errorf("Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", helloVerifyDrop+1, cntHelloVerifyRequest) + } + if cntClientHelloNoCookie != cntHelloVerifyRequest { + t.Errorf( + "HelloVerifyRequest must be triggered only by ClientHello, but HelloVerifyRequest was sent %d times and ClientHello was sent %d times", + cntHelloVerifyRequest, cntClientHelloNoCookie, + ) + } + } + + return clientEndpoint, serverEndpoint, report + }, + + "NoLatencyTest": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) { + var ( + cntClientFinished = 0 + cntServerFinished = 0 + ) + + clientEndpoint := TestEndpoint{ + Filter: func(p *packet) bool { + h, ok := p.record.Content.(*handshake.Handshake) + if !ok { + return true + } + if _, ok := h.Message.(*handshake.MessageFinished); ok { + cntClientFinished++ + } + return true + }, + } + + serverEndpoint := 
TestEndpoint{ + Filter: func(p *packet) bool { + h, ok := p.record.Content.(*handshake.Handshake) + if !ok { + return true + } + if _, ok := h.Message.(*handshake.MessageFinished); ok { + cntServerFinished++ + } + return true + }, + } + + report := func(t *testing.T) { + if cntClientFinished != 1 { + t.Errorf("Number of client finished is wrong, expected: %d times, got: %d times", 1, cntClientFinished) + } + if cntServerFinished != 1 { + t.Errorf("Number of server finished is wrong, expected: %d times, got: %d times", 1, cntServerFinished) + } + } + + return clientEndpoint, serverEndpoint, report + }, + + "SlowServerTest": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) { + var ( + cntClientFinished = 0 + isClientFinished = false + cntClientFinishedLastRetransmit = 0 + cntServerFinished = 0 + isServerFinished = false + cntServerFinishedLastRetransmit = 0 + ) + + clientEndpoint := TestEndpoint{ + Filter: func(p *packet) bool { + h, ok := p.record.Content.(*handshake.Handshake) + if !ok { + return true + } + if _, ok := h.Message.(*handshake.MessageFinished); ok { + if isClientFinished { + cntClientFinishedLastRetransmit++ + } else { + cntClientFinished++ + } + } + return true + }, + Delay: 0, + OnFinished: func() { + isClientFinished = true + }, + FinishWait: 2000 * time.Millisecond, + } + + serverEndpoint := TestEndpoint{ + Filter: func(p *packet) bool { + h, ok := p.record.Content.(*handshake.Handshake) + if !ok { + return true } - if cntClientHelloNoCookie != cntHelloVerifyRequest { - t.Errorf( - "HelloVerifyRequest must be triggered only by ClientHello, but HelloVerifyRequest was sent %d times and ClientHello was sent %d times", - cntHelloVerifyRequest, cntClientHelloNoCookie, - ) + if _, ok := h.Message.(*handshake.MessageFinished); ok { + if isServerFinished { + cntServerFinishedLastRetransmit++ + } else { + cntServerFinished++ + } } + return true + }, + Delay: 1000 * time.Millisecond, + OnFinished: func() { + isServerFinished = true + }, + FinishWait: 
2000 * time.Millisecond, + } + + report := func(t *testing.T) { + // with one second server delay and 100 ms retransmit, there should be close to 10 `Finished` from client + // using a range of 9 - 11 for checking + if cntClientFinished < 8 || cntClientFinished > 11 { + t.Errorf("Number of client finished is wrong, expected: %d - %d times, got: %d times", 9, 11, cntClientFinished) + } + if !isClientFinished { + t.Errorf("Client is not finished") } + // there should be no `Finished` last retransmit from client + if cntClientFinishedLastRetransmit != 0 { + t.Errorf("Number of client finished last retransmit is wrong, expected: %d times, got: %d times", 0, cntClientFinishedLastRetransmit) + } + if cntServerFinished < 1 { + t.Errorf("Number of server finished is wrong, expected: at least %d times, got: %d times", 1, cntServerFinished) + } + if !isServerFinished { + t.Errorf("Server is not finished") + } + // there should be `Finished` last retransmit from server. Because of slow server, client would have sent several `Finished`. 
+ if cntServerFinishedLastRetransmit < 1 { + t.Errorf("Number of server finished last retransmit is wrong, expected: at least %d times, got: %d times", 1, cntServerFinishedLastRetransmit) + } + } + + return clientEndpoint, serverEndpoint, report }, } for name, filters := range genFilters { - f1, f2, report := filters() + clientEndpoint, serverEndpoint, report := filters() t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() if report != nil { defer report(t) } - ca, cb := flightTestPipe(ctx, f1, f2) + ca, cb := flightTestPipe(ctx, clientEndpoint, serverEndpoint) ca.state.isClient = true var wg sync.WaitGroup @@ -132,7 +268,12 @@ func TestHandshaker(t *testing.T) { log: logger, onFlightState: func(f flightVal, s handshakeState) { if s == handshakeFinished { - cancelCli() + if clientEndpoint.OnFinished != nil { + clientEndpoint.OnFinished() + } + time.AfterFunc(clientEndpoint.FinishWait, func() { + cancelCli() + }) } }, retransmitInterval: nonZeroRetransmitInterval, @@ -158,7 +299,12 @@ func TestHandshaker(t *testing.T) { log: logger, onFlightState: func(f flightVal, s handshakeState) { if s == handshakeFinished { - cancelSrv() + if serverEndpoint.OnFinished != nil { + serverEndpoint.OnFinished() + } + time.AfterFunc(serverEndpoint.FinishWait, func() { + cancelSrv() + }) } }, retransmitInterval: nonZeroRetransmitInterval, @@ -183,9 +329,16 @@ func TestHandshaker(t *testing.T) { } } -type packetFilter func(*packet) bool +type packetFilter func(p *packet) bool + +type TestEndpoint struct { + Filter packetFilter + Delay time.Duration + OnFinished func() + FinishWait time.Duration +} -func flightTestPipe(ctx context.Context, filter1 packetFilter, filter2 packetFilter) (*flightTestConn, *flightTestConn) { +func flightTestPipe(ctx context.Context, clientEndpoint TestEndpoint, serverEndpoint TestEndpoint) (*flightTestConn, 
*flightTestConn) { ca := newHandshakeCache() cb := newHandshakeCache() chA := make(chan chan struct{}) @@ -196,14 +349,16 @@ func flightTestPipe(ctx context.Context, filter1 packetFilter, filter2 packetFil recv: chA, otherEndRecv: chB, done: ctx.Done(), - filter: filter1, + filter: clientEndpoint.Filter, + delay: clientEndpoint.Delay, }, &flightTestConn{ handshakeCache: cb, otherEndCache: ca, recv: chB, otherEndRecv: chA, done: ctx.Done(), - filter: filter2, + filter: serverEndpoint.Filter, + delay: serverEndpoint.Delay, } } @@ -216,6 +371,8 @@ type flightTestConn struct { filter packetFilter + delay time.Duration + otherEndCache *handshakeCache otherEndRecv chan chan struct{} } @@ -233,6 +390,7 @@ func (c *flightTestConn) notify(ctx context.Context, level alert.Level, desc ale } func (c *flightTestConn) writePackets(ctx context.Context, pkts []*packet) error { + time.Sleep(c.delay) for _, p := range pkts { if c.filter != nil && !c.filter(p) { continue