Skip to content

Commit

Permalink
Fix context canceled error via manual cancelling (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg authored Jun 8, 2021
1 parent 980f5db commit 78f85d7
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 28 deletions.
42 changes: 30 additions & 12 deletions hedged.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,32 @@ func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error)
}

errOverall := &MultiError{}
resultCh := make(chan *http.Response, ht.upto)
resultCh := make(chan indexedResp, ht.upto)
errorCh := make(chan error, ht.upto)

resultIdx := -1
cancels := make([]func(), ht.upto)

defer runInPool(func() {
for i, cancel := range cancels {
if i != resultIdx && cancel != nil {
cancel()
}
}
})

for sent := 0; len(errOverall.Errors) < ht.upto; sent++ {
if sent < ht.upto {
runInPool(func() {
req, cancel := reqWithCtx(req, mainCtx)
defer cancel()
idx := sent
subReq, cancel := reqWithCtx(req, mainCtx)
cancels[idx] = cancel

resp, err := ht.rt.RoundTrip(req)
runInPool(func() {
resp, err := ht.rt.RoundTrip(subReq)
if err != nil {
errorCh <- err
} else {
resultCh <- resp
resultCh <- indexedResp{idx, resp}
}
})
}
Expand All @@ -80,8 +92,9 @@ func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error)
resp, err := waitResult(mainCtx, resultCh, errorCh, timeout)

switch {
case resp != nil:
return resp, nil
case resp.Resp != nil:
resultIdx = resp.Index
return resp.Resp, nil
case mainCtx.Err() != nil:
return nil, mainCtx.Err()
case err != nil:
Expand All @@ -93,7 +106,7 @@ func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error)
return nil, errOverall
}

func waitResult(ctx context.Context, resultCh <-chan *http.Response, errorCh <-chan error, timeout time.Duration) (*http.Response, error) {
func waitResult(ctx context.Context, resultCh <-chan indexedResp, errorCh <-chan error, timeout time.Duration) (indexedResp, error) {
// try to read result first before blocking on all other channels
select {
case res := <-resultCh:
Expand All @@ -107,17 +120,22 @@ func waitResult(ctx context.Context, resultCh <-chan *http.Response, errorCh <-c
return res, nil

case reqErr := <-errorCh:
return nil, reqErr
return indexedResp{}, reqErr

case <-ctx.Done():
return nil, ctx.Err()
return indexedResp{}, ctx.Err()

case <-timer.C:
return nil, nil // it's not a request timeout, it's timeout BETWEEN consecutive requests
return indexedResp{}, nil // it's not a request timeout, it's timeout BETWEEN consecutive requests
}
}
}

type indexedResp struct {
Index int
Resp *http.Response
}

func reqWithCtx(r *http.Request, ctx context.Context) (*http.Request, func()) {
ctx, cancel := context.WithCancel(ctx)
req := r.WithContext(ctx)
Expand Down
38 changes: 22 additions & 16 deletions hedged_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"io"
"io/ioutil"
"math/rand"
"net/http"
"net/http/httptest"
"strings"
Expand All @@ -16,10 +15,10 @@ import (
)

func TestUpto(t *testing.T) {
gotRequests := 0
var gotRequests int64

url := testServerURL(t, func(w http.ResponseWriter, r *http.Request) {
gotRequests++
atomic.AddInt64(&gotRequests, 1)
time.Sleep(100 * time.Millisecond)
})

Expand All @@ -31,17 +30,17 @@ func TestUpto(t *testing.T) {
const upto = 7
_, _ = NewClient(10*time.Millisecond, upto, nil).Do(req)

if gotRequests != upto {
if gotRequests := atomic.LoadInt64(&gotRequests); gotRequests != upto {
t.Fatalf("want %v, got %v", upto, gotRequests)
}
}

func TestNoTimeout(t *testing.T) {
const sleep = 10 * time.Millisecond
var gotRequests = 0
var gotRequests int64

url := testServerURL(t, func(w http.ResponseWriter, r *http.Request) {
gotRequests++
atomic.AddInt64(&gotRequests, 1)
time.Sleep(sleep)
})

Expand All @@ -60,7 +59,7 @@ func TestNoTimeout(t *testing.T) {
if float64(passed) > want {
t.Fatalf("want %v, got %v", time.Duration(want), passed)
}
if gotRequests != upto {
if gotRequests := atomic.LoadInt64(&gotRequests); gotRequests != upto {
t.Fatalf("want %v, got %v", upto, gotRequests)
}
}
Expand Down Expand Up @@ -92,26 +91,33 @@ func TestFirstIsOK(t *testing.T) {
}

func TestBestResponse(t *testing.T) {
timeout := []time.Duration{7000 * time.Millisecond, 100 * time.Millisecond, 20 * time.Millisecond}
shortest := shortestFrom(timeout)
const shortest = 20 * time.Millisecond
timeouts := [...]time.Duration{30 * shortest, 5 * shortest, shortest, shortest, shortest}
timeoutCh := make(chan time.Duration, len(timeouts))
for _, t := range timeouts {
timeoutCh <- t
}

url := testServerURL(t, func(w http.ResponseWriter, r *http.Request) {
time.Sleep(timeout[rand.Int()%len(timeout)])
time.Sleep(<-timeoutCh)
})

start := time.Now()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

req, err := http.NewRequestWithContext(ctx, "GET", url, http.NoBody)
if err != nil {
t.Fatal(err)
}

start := time.Now()
_, _ = NewClient(10*time.Millisecond, 10, nil).Do(req)
_, err = NewClient(10*time.Millisecond, 5, nil).Do(req)
if err != nil {
t.Fatal(err)
}
passed := time.Since(start)

if float64(passed) > float64(shortest)*1.2 {
if float64(passed) > float64(shortest)*2.5 {
t.Fatalf("want %v, got %v", shortest, passed)
}
}
Expand Down Expand Up @@ -197,8 +203,8 @@ func TestHangAllExceptLast(t *testing.T) {
defer close(blockCh)

url := testServerURL(t, func(w http.ResponseWriter, r *http.Request) {
gotRequests++
if gotRequests == upto {
idx := atomic.AddUint64(&gotRequests, 1)
if idx == upto {
time.Sleep(100 * time.Millisecond)
return
}
Expand Down

0 comments on commit 78f85d7

Please sign in to comment.