From 78f85d7dcf97511b1931a32da18cb5d833cd3305 Mon Sep 17 00:00:00 2001 From: Oleg Kovalov Date: Wed, 9 Jun 2021 00:01:09 +0200 Subject: [PATCH] Fix context canceled error via manual cancelling (#9) --- hedged.go | 42 ++++++++++++++++++++++++++++++------------ hedged_test.go | 38 ++++++++++++++++++++++---------------- 2 files changed, 52 insertions(+), 28 deletions(-) diff --git a/hedged.go b/hedged.go index 0185155..a90a6de 100644 --- a/hedged.go +++ b/hedged.go @@ -55,20 +55,32 @@ func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error) } errOverall := &MultiError{} - resultCh := make(chan *http.Response, ht.upto) + resultCh := make(chan indexedResp, ht.upto) errorCh := make(chan error, ht.upto) + resultIdx := -1 + cancels := make([]func(), ht.upto) + + defer runInPool(func() { + for i, cancel := range cancels { + if i != resultIdx && cancel != nil { + cancel() + } + } + }) + for sent := 0; len(errOverall.Errors) < ht.upto; sent++ { if sent < ht.upto { - runInPool(func() { - req, cancel := reqWithCtx(req, mainCtx) - defer cancel() + idx := sent + subReq, cancel := reqWithCtx(req, mainCtx) + cancels[idx] = cancel - resp, err := ht.rt.RoundTrip(req) + runInPool(func() { + resp, err := ht.rt.RoundTrip(subReq) if err != nil { errorCh <- err } else { - resultCh <- resp + resultCh <- indexedResp{idx, resp} } }) } @@ -80,8 +92,9 @@ func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error) resp, err := waitResult(mainCtx, resultCh, errorCh, timeout) switch { - case resp != nil: - return resp, nil + case resp.Resp != nil: + resultIdx = resp.Index + return resp.Resp, nil case mainCtx.Err() != nil: return nil, mainCtx.Err() case err != nil: @@ -93,7 +106,7 @@ func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error) return nil, errOverall } -func waitResult(ctx context.Context, resultCh <-chan *http.Response, errorCh <-chan error, timeout time.Duration) (*http.Response, error) { +func waitResult(ctx context.Context, resultCh <-chan indexedResp, errorCh <-chan error, timeout time.Duration) (indexedResp, error) { // try to read result first before blocking on all other channels select { case res := <-resultCh: @@ -107,17 +120,22 @@ func waitResult(ctx context.Context, resultCh <-chan *http.Response, errorCh <-c return res, nil case reqErr := <-errorCh: - return nil, reqErr + return indexedResp{}, reqErr case <-ctx.Done(): - return nil, ctx.Err() + return indexedResp{}, ctx.Err() case <-timer.C: - return nil, nil // it's not a request timeout, it's timeout BETWEEN consecutive requests + return indexedResp{}, nil // it's not a request timeout, it's timeout BETWEEN consecutive requests } } } +type indexedResp struct { + Index int + Resp *http.Response +} + func reqWithCtx(r *http.Request, ctx context.Context) (*http.Request, func()) { ctx, cancel := context.WithCancel(ctx) req := r.WithContext(ctx) diff --git a/hedged_test.go b/hedged_test.go index 3fd3ebf..eaf8e0a 100644 --- a/hedged_test.go +++ b/hedged_test.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "io/ioutil" - "math/rand" "net/http" "net/http/httptest" "strings" @@ -16,10 +15,10 @@ import ( ) func TestUpto(t *testing.T) { - gotRequests := 0 + var gotRequests int64 url := testServerURL(t, func(w http.ResponseWriter, r *http.Request) { - gotRequests++ + atomic.AddInt64(&gotRequests, 1) time.Sleep(100 * time.Millisecond) }) @@ -31,17 +30,17 @@ func TestUpto(t *testing.T) { const upto = 7 _, _ = NewClient(10*time.Millisecond, upto, nil).Do(req) - if gotRequests != upto { + if gotRequests := atomic.LoadInt64(&gotRequests); gotRequests != upto { t.Fatalf("want %v, got %v", upto, gotRequests) } } func TestNoTimeout(t *testing.T) { const sleep = 10 * time.Millisecond - var gotRequests = 0 + var gotRequests int64 url := testServerURL(t, func(w http.ResponseWriter, r *http.Request) { - gotRequests++ + atomic.AddInt64(&gotRequests, 1) time.Sleep(sleep) }) @@ -60,7 +59,7 @@ func TestNoTimeout(t *testing.T) { if float64(passed) > want { t.Fatalf("want %v, got %v", time.Duration(want), passed) } - if gotRequests != upto { + if gotRequests := atomic.LoadInt64(&gotRequests); gotRequests != upto { t.Fatalf("want %v, got %v", upto, gotRequests) } } @@ -92,13 +91,19 @@ func TestFirstIsOK(t *testing.T) { } func TestBestResponse(t *testing.T) { - timeout := []time.Duration{7000 * time.Millisecond, 100 * time.Millisecond, 20 * time.Millisecond} - shortest := shortestFrom(timeout) + const shortest = 20 * time.Millisecond + timeouts := [...]time.Duration{30 * shortest, 5 * shortest, shortest, shortest, shortest} + timeoutCh := make(chan time.Duration, len(timeouts)) + for _, t := range timeouts { + timeoutCh <- t + } url := testServerURL(t, func(w http.ResponseWriter, r *http.Request) { - time.Sleep(timeout[rand.Int()%len(timeout)]) + time.Sleep(<-timeoutCh) }) + start := time.Now() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -106,12 +111,13 @@ func TestBestResponse(t *testing.T) { if err != nil { t.Fatal(err) } - - start := time.Now() - _, _ = NewClient(10*time.Millisecond, 10, nil).Do(req) + _, err = NewClient(10*time.Millisecond, 5, nil).Do(req) + if err != nil { + t.Fatal(err) + } passed := time.Since(start) - if float64(passed) > float64(shortest)*1.2 { + if float64(passed) > float64(shortest)*2.5 { t.Fatalf("want %v, got %v", shortest, passed) } } @@ -197,8 +203,8 @@ func TestHangAllExceptLast(t *testing.T) { defer close(blockCh) url := testServerURL(t, func(w http.ResponseWriter, r *http.Request) { - gotRequests++ - if gotRequests == upto { + idx := atomic.AddUint64(&gotRequests, 1) + if idx == upto { time.Sleep(100 * time.Millisecond) return }