Skip to content
This repository has been archived by the owner on Apr 20, 2023. It is now read-only.

Define ContextCache for caches that need a context.Context or that mi… #113

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 111 additions & 9 deletions httpcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ package httpcache
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -38,6 +40,41 @@ type Cache interface {
Delete(key string)
}

// ContextCache the same as Cache except that its functions accept a
// context.Context argument and return an additional error object.
type ContextCache interface {
// Get returns the []byte representation of a cached response and a bool
// set to true if the value isn't empty
Get(ctx context.Context, key string) (responseBytes []byte, ok bool, err error)
// Set stores the []byte representation of a response against a key
Set(ctx context.Context, key string, responseBytes []byte) error
// Delete removes the value associated with the key
Delete(ctx context.Context, key string) error
}

// cacheAsContextCache is an implementation of ContextCache that wraps a regular Cache.
type cacheAsContextCache struct{ cache Cache }

var _ ContextCache = cacheAsContextCache{}

// Delete implements ContextCache
func (c cacheAsContextCache) Delete(_ context.Context, key string) error {
c.cache.Delete(key)
return nil
}

// Get implements ContextCache
func (c cacheAsContextCache) Get(_ context.Context, key string) (responseBytes []byte, ok bool, err error) {
got, ok := c.cache.Get(key)
return got, ok, nil
}

// Set implements ContextCache
func (c cacheAsContextCache) Set(_ context.Context, key string, responseBytes []byte) error {
c.cache.Set(key, responseBytes)
return nil
}

// cacheKey returns the cache key for req.
func cacheKey(req *http.Request) string {
if req.Method == http.MethodGet {
Expand All @@ -59,6 +96,21 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error)
return http.ReadResponse(bufio.NewReader(b), req)
}

// contextCachedResponse returns the cached http.Response for req if present, and nil
// otherwise.
func contextCachedResponse(c ContextCache, req *http.Request) (resp *http.Response, err error) {
cachedVal, ok, err := c.Get(req.Context(), cacheKey(req))
if err != nil {
return nil, fmt.Errorf("httpcache Get error: %w", err)
}
if !ok {
return
}

b := bytes.NewBuffer(cachedVal)
return http.ReadResponse(bufio.NewReader(b), req)
}

// MemoryCache is an implemtation of Cache that stores responses in an in-memory map.
type MemoryCache struct {
mu sync.RWMutex
Expand Down Expand Up @@ -101,6 +153,14 @@ type Transport struct {
// If nil, http.DefaultTransport is used
Transport http.RoundTripper
Cache Cache
// ContextCache, if set, will be used instead of Cache by the transport.
//
// The Context() method of http.Request is used to obtain the
// context.Context argument for the cache. Errors from the ContextCache
// cause the Transport's RoundTrip method to return errors.
//
// If ContextCache is non-nil, Cache may be nil.
ContextCache ContextCache
// If true, responses returned from the cache will be given an extra header, X-From-Cache
MarkCachedResponses bool
}
Expand Down Expand Up @@ -141,10 +201,20 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == ""
var cachedResp *http.Response
if cacheable {
cachedResp, err = CachedResponse(t.Cache, req)
if t.ContextCache != nil {
cachedResp, err = contextCachedResponse(t.ContextCache, req)
} else {
cachedResp, err = CachedResponse(t.Cache, req)
}
} else {
// Need to invalidate an existing value
t.Cache.Delete(cacheKey)
if t.ContextCache != nil {
if err := t.ContextCache.Delete(req.Context(), cacheKey); err != nil {
return nil, fmt.Errorf("httpcache Delete error: %w", err)
}
} else {
t.Cache.Delete(cacheKey)
}
}

transport := t.Transport
Expand Down Expand Up @@ -200,7 +270,14 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
return cachedResp, nil
} else {
if err != nil || resp.StatusCode != http.StatusOK {
t.Cache.Delete(cacheKey)
if t.ContextCache != nil {
// Don't overwrite non-nil err.
if cacheErr := t.ContextCache.Delete(req.Context(), cacheKey); err == nil {
err = cacheErr
}
} else {
t.Cache.Delete(cacheKey)
}
}
if err != nil {
return nil, err
Expand Down Expand Up @@ -232,23 +309,42 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
// Delay caching until EOF is reached.
resp.Body = &cachingReadCloser{
R: resp.Body,
OnEOF: func(r io.Reader) {
OnEOF: func(r io.Reader) error {
resp := *resp
resp.Body = ioutil.NopCloser(r)
respBytes, err := httputil.DumpResponse(&resp, true)
if err == nil {
t.Cache.Set(cacheKey, respBytes)
if t.ContextCache != nil {
if err := t.ContextCache.Set(req.Context(), cacheKey, respBytes); err != nil {
return fmt.Errorf("httpcache Set error: %w", err)
}
} else {
t.Cache.Set(cacheKey, respBytes)
}
}
return err
},
}
default:
respBytes, err := httputil.DumpResponse(resp, true)
if err == nil {
t.Cache.Set(cacheKey, respBytes)
if t.ContextCache != nil {
if err := t.ContextCache.Set(req.Context(), cacheKey, respBytes); err != nil {
return nil, err
}
} else {
t.Cache.Set(cacheKey, respBytes)
}
}
}
} else {
t.Cache.Delete(cacheKey)
if t.ContextCache != nil {
if err := t.ContextCache.Delete(req.Context(), cacheKey); err != nil {
return nil, fmt.Errorf("httpcache Delete error: %w", err)
}
} else {
t.Cache.Delete(cacheKey)
}
}
return resp, nil
}
Expand Down Expand Up @@ -473,6 +569,10 @@ func cloneRequest(r *http.Request) *http.Request {
for k, s := range r.Header {
r2.Header[k] = s
}
ctx := r.Context()
if ctx != nil {
r2 = r2.WithContext(ctx)
}
return r2
}

Expand Down Expand Up @@ -521,7 +621,7 @@ type cachingReadCloser struct {
// Underlying ReadCloser.
R io.ReadCloser
// OnEOF is called with a copy of the content of R when EOF is reached.
OnEOF func(io.Reader)
OnEOF func(io.Reader) error

buf bytes.Buffer // buf stores a copy of the content of R.
}
Expand All @@ -534,7 +634,9 @@ func (r *cachingReadCloser) Read(p []byte) (n int, err error) {
n, err = r.R.Read(p)
r.buf.Write(p[:n])
if err == io.EOF {
r.OnEOF(bytes.NewReader(r.buf.Bytes()))
if err := r.OnEOF(bytes.NewReader(r.buf.Bytes())); err != nil {
return n, err
}
}
return n, err
}
Expand Down
25 changes: 17 additions & 8 deletions httpcache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ import (
"time"
)

var (
testContextCache = flag.Bool("test-context-cache", false, "if true, tests the functionality of the ContextCache property")
)

var s struct {
server *httptest.Server
client http.Client
Expand Down Expand Up @@ -167,6 +171,11 @@ func teardown() {

func resetTest() {
s.transport.Cache = NewMemoryCache()
if *testContextCache {
s.transport.ContextCache = cacheAsContextCache{s.transport.Cache}
} else {
s.transport.ContextCache = nil
}
clock = &realClock{}
}

Expand Down Expand Up @@ -223,8 +232,8 @@ func TestCacheableMethod(t *testing.T) {
if resp.StatusCode != http.StatusOK {
t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode)
}
if resp.Header.Get(XFromCache) != "" {
t.Errorf("XFromCache header isn't blank")
if got := resp.Header.Get(XFromCache); got != "" {
t.Errorf("XFromCache header isn't blank: %q", got)
}
}
}
Expand Down Expand Up @@ -305,8 +314,8 @@ func TestDontStorePartialRangeInCache(t *testing.T) {
if resp.StatusCode != http.StatusOK {
t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode)
}
if resp.Header.Get(XFromCache) != "" {
t.Error("XFromCache header isn't blank")
if got := resp.Header.Get(XFromCache); got != "" {
t.Errorf("XFromCache header isn't blank: %q", got)
}
}
{
Expand Down Expand Up @@ -469,8 +478,8 @@ func TestGetOnlyIfCachedMiss(t *testing.T) {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.Header.Get(XFromCache) != "" {
t.Fatal("XFromCache header isn't blank")
if got := resp.Header.Get(XFromCache); got != "" {
t.Errorf("XFromCache header isn't blank: %q", got)
}
if resp.StatusCode != http.StatusGatewayTimeout {
t.Fatalf("response status code isn't 504 GatewayTimeout: %v", resp.StatusCode)
Expand All @@ -490,8 +499,8 @@ func TestGetNoStoreRequest(t *testing.T) {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.Header.Get(XFromCache) != "" {
t.Fatal("XFromCache header isn't blank")
if got := resp.Header.Get(XFromCache); got != "" {
t.Errorf("XFromCache header isn't blank: %q", got)
}
}
{
Expand Down