diff --git a/httpcache.go b/httpcache.go index b41a63d..4fe5322 100644 --- a/httpcache.go +++ b/httpcache.go @@ -9,6 +9,7 @@ package httpcache import ( "bufio" "bytes" + "context" "errors" "io" "io/ioutil" @@ -19,12 +20,17 @@ import ( "time" ) +type Freshness string + const ( - stale = iota - fresh - transparent + stale Freshness = "stale" + staleWhileRevalidate Freshness = "stale-while-revalidate" + fresh Freshness = "fresh" + transparent Freshness = "transparent" + // XFromCache is the header added to responses that are returned from the cache XFromCache = "X-From-Cache" + XFreshness = "X-Cache-Freshness" ) // A Cache interface is used by the Transport to store and retrieve responses. @@ -103,6 +109,8 @@ type Transport struct { Cache Cache // If true, responses returned from the cache will be given an extra header, X-From-Cache MarkCachedResponses bool + // Context timeout for async requests triggered by stale-while-revalidate + AsyncRevalidateTimeout time.Duration } // NewTransport returns a new Transport with the @@ -160,8 +168,34 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error if varyMatches(cachedResp, req) { // Can only use cached value if the new request doesn't Vary significantly freshness := getFreshness(cachedResp.Header, req.Header) + if t.MarkCachedResponses { + cachedResp.Header.Set(XFreshness, string(freshness)) + } + if freshness == fresh { return cachedResp, nil + } else if freshness == staleWhileRevalidate { + bgContext := context.Background() + cancelContext := func() {} + if t.AsyncRevalidateTimeout > 0 { + bgContext, cancelContext = context.WithTimeout(bgContext, t.AsyncRevalidateTimeout) + } + noCacheRequest := req.Clone(bgContext) + noCacheRequest.Header.Set("cache-control", "no-cache") + go func() { + defer cancelContext() + resp, err := t.RoundTrip(noCacheRequest) + if err == nil { + defer resp.Body.Close() + buffer := make([]byte, 4096) + for { + if _, err = resp.Body.Read(buffer); err == io.EOF { + break + } + } + } + }() + return cachedResp, nil } if freshness == stale { @@ -288,7 +322,7 @@ var clock timer = &realClock{} // // Because this is only a private cache, 'public' and 'private' in cache-control aren't // signficant. Similarly, smax-age isn't used. -func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { +func getFreshness(respHeaders, reqHeaders http.Header) Freshness { respCacheControl := parseCacheControl(respHeaders) reqCacheControl := parseCacheControl(reqHeaders) if _, ok := reqCacheControl["no-cache"]; ok { @@ -366,6 +400,16 @@ func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { return fresh } + if stalewhilerevalidate, ok := respCacheControl["stale-while-revalidate"]; ok { + // If the cached response isn't too stale, we can return it and refresh asynchronously + stalewhilerevalidateDuration, err := time.ParseDuration(stalewhilerevalidate + "s") + if err == nil { + if lifetime+stalewhilerevalidateDuration > currentAge { + return staleWhileRevalidate + } + } + } + return stale } diff --git a/httpcache_test.go b/httpcache_test.go index a504641..7239332 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -71,6 +71,13 @@ func setup() { w.Write([]byte("Some text content")) })) + staleWhileRevalidateCounter := 0 + mux.HandleFunc("/stale-while-revalidate", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + staleWhileRevalidateCounter++ + w.Header().Set("X-Counter", strconv.Itoa(staleWhileRevalidateCounter)) + w.Header().Set("Cache-Control", "max-age=100, stale-while-revalidate=100") + })) + mux.HandleFunc("/nostore", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "no-store") })) @@ -1034,6 +1041,31 @@ func TestMaxAge(t *testing.T) { } } +func TestFreshnessStaleWhileRevalidate(t *testing.T) { + resetTest() + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("Cache-Control", "max-age=100, stale-while-revalidate=100") + + reqHeaders := http.Header{} + + clock = &fakeClock{elapsed: 50 * time.Second} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 150 * time.Second} + if getFreshness(respHeaders, reqHeaders) != staleWhileRevalidate { + t.Fatal("freshness isn't staleWhileRevalidate") + } + + clock = &fakeClock{elapsed: 250 * time.Second} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + func TestMaxAgeZero(t *testing.T) { resetTest() now := time.Now() @@ -1473,3 +1505,115 @@ func TestClientTimeout(t *testing.T) { t.Error("client.Do took 2+ seconds, want < 2 seconds") } } + +func TestStaleWhileRevalidate(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/stale-while-revalidate", nil) + if err != nil { + t.Fatal(err) + } + var counter1 string + { + // 1st request: Not cached + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.Header.Get(XFromCache) != "" { + t.Fatalf(`XFromCache header isn't absent: %v`, resp.Header.Get(XFromCache)) + } + if resp.Header.Get(XFreshness) != "" { + t.Fatalf(`X-Cache-Freshness header isn't absent: %v`, resp.Header.Get(XFreshness)) + } + + counter1 = resp.Header.Get("x-counter") + + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + // 2nd request: Fresh + clock = &fakeClock{elapsed: 50 * time.Second} + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + if resp.Header.Get(XFreshness) != string(fresh) { + t.Fatalf(`X-Cache-Freshness header isn't "%v": %v`, fresh, resp.Header.Get(XFreshness)) + } + + counter := resp.Header.Get("x-counter") + if counter1 != counter { + t.Fatalf(`"x-counter" values are different: %v %v`, counter1, counter) + } + + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + // 3rd request: Stale-While-Revalidate + clock = &fakeClock{elapsed: 150 * time.Second} + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + if resp.Header.Get(XFreshness) != string(staleWhileRevalidate) { + t.Fatalf(`X-Cache-Freshness header isn't "%v": %v`, staleWhileRevalidate, resp.Header.Get(XFreshness)) + } + + counter := resp.Header.Get("x-counter") + if counter1 != counter { + t.Fatalf(`"x-counter" values are different: %v %v`, counter1, counter) + } + + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + // Revalidate is asynchronous, make sure it completes executing + time.Sleep(1 * time.Second) + } + { + // 4th request: Return the response cached just now + clock = &fakeClock{elapsed: 50 * time.Second} + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + if resp.Header.Get(XFreshness) != string(fresh) { + t.Fatalf(`X-Cache-Freshness header isn't "%v": %v`, fresh, resp.Header.Get(XFreshness)) + } + + counter := resp.Header.Get("x-counter") + if counter1 == counter { + t.Fatalf(`"x-counter" values are equal: %v %v`, counter1, counter) + } + + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } +}