From 28f49cd50d5d9f1364745d9fe8873e3b22c9e177 Mon Sep 17 00:00:00 2001 From: "Gerasimos (Makis) Maropoulos" Date: Tue, 26 Sep 2023 21:14:57 +0300 Subject: [PATCH] improve cache handler, embracing #2210 too --- HISTORY.md | 8 ++ .../response-writer/cache/simple/main.go | 4 + cache/cache.go | 33 ++++- cache/cache_test.go | 10 +- cache/client/handler.go | 122 +++++++++-------- cache/entry/entry.go | 125 ++++-------------- cache/entry/pool.go | 69 ++++++++++ cache/entry/response.go | 55 +++++++- cache/entry/store.go | 52 ++++++++ cache/entry/util.go | 24 ---- context/context.go | 2 +- {sessions => core/memstore}/lifetime.go | 66 +++++++-- sessions/database.go | 6 +- sessions/session.go | 2 +- sessions/sessiondb/badger/database.go | 7 +- sessions/sessiondb/boltdb/database.go | 7 +- sessions/sessiondb/redis/database.go | 7 +- 17 files changed, 382 insertions(+), 217 deletions(-) create mode 100644 cache/entry/pool.go create mode 100644 cache/entry/store.go delete mode 100644 cache/entry/util.go rename {sessions => core/memstore}/lifetime.go (58%) diff --git a/HISTORY.md b/HISTORY.md index 3695c594a..4a6b5a18c 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -23,6 +23,14 @@ Developers are not forced to upgrade if they don't really need it. Upgrade whene Changes apply to `main` branch. +- The `cache` sub-package has an update, after 4 years: + + - Add support for custom storage on `cache` package, through the `Handler#Store` method. + - Add support for custom expiration duration on `cache` package, trough the `Handler#MaxAge` method. + - Improve the overral performance of the `cache` package. + - The `cache.Handler` input and output arguments remain as it is. + - The `cache.Cache` input argument changed from `time.Duration` to `func(iris.Context) time.Duration`. + # Mon, 25 Sep 2023 | v12.2.7 Minor bug fixes and support of multiple `block` and `define` directives in multiple layouts and templates in the `Blocks` view engine. diff --git a/_examples/response-writer/cache/simple/main.go b/_examples/response-writer/cache/simple/main.go index 03ff31ee3..9617712cb 100644 --- a/_examples/response-writer/cache/simple/main.go +++ b/_examples/response-writer/cache/simple/main.go @@ -29,6 +29,10 @@ func main() { app := iris.New() app.Logger().SetLevel("debug") app.Get("/", cache.Handler(10*time.Second), writeMarkdown) + // To customize the cache handler: + // cache.Cache(nil).MaxAge(func(ctx iris.Context) time.Duration { + // return time.Duration(ctx.MaxAge()) * time.Second + // }).AddRule(...).Store(...) // saves its content on the first request and serves it instead of re-calculating the content. // After 10 seconds it will be cleared and reset. diff --git a/cache/cache.go b/cache/cache.go index 3031d2d57..96950632c 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -46,8 +46,25 @@ func WithKey(key string) context.Handler { } } +// DefaultMaxAge is a function which returns the +// `context#MaxAge` as time.Duration. +// It's the default expiration function for the cache handler. +var DefaultMaxAge = func(ctx *context.Context) time.Duration { + return time.Duration(ctx.MaxAge()) * time.Second +} + +// MaxAge is a shortcut to set a simple duration as a MaxAgeFunc. +// +// Usage: +// app.Get("/", cache.Cache(cache.MaxAge(1*time.Minute), mainHandler) +func MaxAge(dur time.Duration) client.MaxAgeFunc { + return func(*context.Context) time.Duration { + return dur + } +} + // Cache accepts the cache expiration duration. -// If the "expiration" input argument is invalid, <=2 seconds, +// If the "maxAgeFunc" input argument is nil, // then expiration is taken by the "cache-control's maxage" header. // Returns a Handler structure which you can use to customize cache further. // @@ -57,8 +74,12 @@ func WithKey(key string) context.Handler { // may be more suited to your needs. // // You can add validators with this function. -func Cache(expiration time.Duration) *client.Handler { - return client.NewHandler(expiration) +func Cache(maxAgeFunc client.MaxAgeFunc) *client.Handler { + if maxAgeFunc == nil { + maxAgeFunc = DefaultMaxAge + } + + return client.NewHandler(maxAgeFunc) } // Handler like `Cache` but returns an Iris Handler to be used as a middleware. @@ -66,6 +87,10 @@ func Cache(expiration time.Duration) *client.Handler { // // Examples can be found at: https://github.com/kataras/iris/tree/main/_examples/response-writer/cache func Handler(expiration time.Duration) context.Handler { - h := Cache(expiration).ServeHTTP + maxAgeFunc := func(*context.Context) time.Duration { + return expiration + } + + h := Cache(maxAgeFunc).ServeHTTP return h } diff --git a/cache/cache_test.go b/cache/cache_test.go index 619ad5982..a71b50e57 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -162,10 +162,10 @@ func TestCacheValidator(t *testing.T) { ctx.Write([]byte(expectedBodyStr)) } - validCache := cache.Cache(cacheDuration) - app.Get("/", validCache.ServeHTTP, h) + validCache := cache.Handler(cacheDuration) + app.Get("/", validCache, h) - managedCache := cache.Cache(cacheDuration) + managedCache := cache.Cache(cache.MaxAge(cacheDuration)) managedCache.AddRule(rule.Validator([]rule.PreValidator{ func(ctx *context.Context) bool { // should always invalid for cache, don't bother to go to try to get or set cache @@ -173,7 +173,7 @@ func TestCacheValidator(t *testing.T) { }, }, nil)) - managedCache2 := cache.Cache(cacheDuration) + managedCache2 := cache.Cache(cache.MaxAge(cacheDuration)) managedCache2.AddRule(rule.Validator(nil, []rule.PostValidator{ func(ctx *context.Context) bool { @@ -183,7 +183,7 @@ func TestCacheValidator(t *testing.T) { }, )) - app.Get("/valid", validCache.ServeHTTP, h) + app.Get("/valid", validCache, h) app.Get("/invalid", managedCache.ServeHTTP, h) app.Get("/invalid2", managedCache2.ServeHTTP, func(ctx *context.Context) { diff --git a/cache/client/handler.go b/cache/client/handler.go index b28056854..6e4edd0a6 100644 --- a/cache/client/handler.go +++ b/cache/client/handler.go @@ -1,8 +1,8 @@ package client import ( + "net/http" "strings" - "sync" "time" "github.com/kataras/iris/v12/cache/client/rule" @@ -23,19 +23,23 @@ type Handler struct { // See more at ruleset.go rule rule.Rule // when expires. - expiration time.Duration + maxAgeFunc MaxAgeFunc // entries the memory cache stored responses. - entries map[string]*entry.Entry - mu sync.RWMutex + entryPool *entry.Pool + entryStore entry.Store } +type MaxAgeFunc func(*context.Context) time.Duration + // NewHandler returns a new cached handler for the "bodyHandler" // which expires every "expiration". -func NewHandler(expiration time.Duration) *Handler { +func NewHandler(maxAgeFunc MaxAgeFunc) *Handler { return &Handler{ rule: DefaultRuleSet, - expiration: expiration, - entries: make(map[string]*entry.Entry), + maxAgeFunc: maxAgeFunc, + + entryPool: entry.NewPool(), + entryStore: entry.NewMemStore(), } } @@ -64,14 +68,20 @@ func (h *Handler) AddRule(r rule.Rule) *Handler { return h } -var emptyHandler = func(ctx *context.Context) { - ctx.StopWithText(500, "cache: empty body handler") +// Store sets a custom store for this handler. +func (h *Handler) Store(store entry.Store) *Handler { + h.entryStore = store + return h } -func parseLifeChanger(ctx *context.Context) entry.LifeChanger { - return func() time.Duration { - return time.Duration(ctx.MaxAge()) * time.Second - } +// MaxAge customizes the expiration duration for this handler. +func (h *Handler) MaxAge(fn MaxAgeFunc) *Handler { + h.maxAgeFunc = fn + return h +} + +var emptyHandler = func(ctx *context.Context) { + ctx.StopWithText(500, "cache: empty body handler") } const entryKeyContextKey = "iris.cache.server.entry.key" @@ -133,33 +143,10 @@ func (h *Handler) ServeHTTP(ctx *context.Context) { return } - var ( - response *entry.Response - valid = false - // unique per subdomains and paths with different url query. - key = getOrSetKey(ctx) - ) - - h.mu.RLock() - e, found := h.entries[key] - h.mu.RUnlock() - - if found { - // the entry is here, .Response will give us - // if it's expired or no - response, valid = e.Response() - } else { - // create the entry now. - // fmt.Printf("create new cache entry\n") - // fmt.Printf("key: %s\n", key) - - e = entry.NewEntry(h.expiration) - h.mu.Lock() - h.entries[key] = e - h.mu.Unlock() - } + key := getOrSetKey(ctx) // unique per subdomains and paths with different url query. - if !valid { + e := h.entryStore.Get(key) + if e == nil { // if it's expired, then execute the original handler // with our custom response recorder response writer // because the net/http doesn't give us @@ -182,30 +169,61 @@ func (h *Handler) ServeHTTP(ctx *context.Context) { return } - // check for an expiration time if the - // given expiration was not valid then check for GetMaxAge & - // update the response & release the recorder - e.Reset( - recorder.StatusCode(), - recorder.Header(), - body, - parseLifeChanger(ctx), - ) - // fmt.Printf("reset cache entry\n") // fmt.Printf("key: %s\n", key) // fmt.Printf("content type: %s\n", recorder.Header().Get(cfg.ContentTypeHeader)) // fmt.Printf("body len: %d\n", len(body)) + + r := entry.NewResponse(recorder.StatusCode(), recorder.Header(), body) + e = h.entryPool.Acquire(h.maxAgeFunc(ctx), r, func() { + h.entryStore.Delete(key) + }) + + h.entryStore.Set(key, e) return } // if it's valid then just write the cached results - entry.CopyHeaders(ctx.ResponseWriter().Header(), response.Headers()) + r := e.Response() + // if !ok { + // // it shouldn't be happen because if it's not valid (= expired) + // // then it shouldn't be found on the store, we return as it is, the body was written. + // return + // } + + copyHeaders(ctx.ResponseWriter().Header(), r.Headers()) ctx.SetLastModified(e.LastModified) - ctx.StatusCode(response.StatusCode()) - ctx.Write(response.Body()) + ctx.StatusCode(r.StatusCode()) + ctx.Write(r.Body()) // fmt.Printf("key: %s\n", key) // fmt.Printf("write content type: %s\n", response.Headers()["ContentType"]) // fmt.Printf("write body len: %d\n", len(response.Body())) } + +func copyHeaders(dst, src http.Header) { + // Clone returns a copy of h or nil if h is nil. + if src == nil { + return + } + + // Find total number of values. + nv := 0 + for _, vv := range src { + nv += len(vv) + } + + sv := make([]string, nv) // shared backing array for headers' values + for k, vv := range src { + if vv == nil { + // Preserve nil values. ReverseProxy distinguishes + // between nil and zero-length header values. + dst[k] = nil + continue + } + + n := copy(sv, vv) + dst[k] = sv[:n:n] + sv = sv[n:] + } +} diff --git a/cache/entry/entry.go b/cache/entry/entry.go index f03b8342a..4f7036d2a 100644 --- a/cache/entry/entry.go +++ b/cache/entry/entry.go @@ -3,15 +3,14 @@ package entry import ( "time" - "github.com/kataras/iris/v12/cache/cfg" + "github.com/kataras/iris/v12/core/memstore" ) // Entry is the cache entry // contains the expiration datetime and the response type Entry struct { - life time.Duration // ExpiresAt is the time which this cache will not be available - expiresAt time.Time + lifeTime *memstore.LifeTime // when `Reset` this value is reseting to time.Now(), // it's used to send the "Last-Modified" header, @@ -25,104 +24,30 @@ type Entry struct { // of store map } -// NewEntry returns a new cache entry -// it doesn't sets the expiresAt & the response -// because these are setting each time on Reset -func NewEntry(duration time.Duration) *Entry { - // if given duration is not <=0 (which means finds from the headers) - // then we should check for the MinimumCacheDuration here - if duration >= 0 && duration < cfg.MinimumCacheDuration { - duration = cfg.MinimumCacheDuration - } - - return &Entry{ - life: duration, - response: &Response{}, - } -} - -// Response gets the cache response contents -// if it's valid returns them with a true value -// otherwise returns nil, false -func (e *Entry) Response() (*Response, bool) { - if !e.valid() { - // it has been expired - return nil, false - } - return e.response, true +// reset called each time a new entry is acquired from the pool. +func (e *Entry) reset(lt *memstore.LifeTime, r *Response) { + e.response = r + e.LastModified = lt.Begun } -// valid returns true if this entry's response is still valid -// or false if the expiration time passed -func (e *Entry) valid() bool { - return !time.Now().After(e.expiresAt) +// Response returns the cached response as it's. +func (e *Entry) Response() *Response { + return e.response } -// LifeChanger is the function which returns -// a duration which will be compared with the current -// entry's (cache life) duration -// and execute the LifeChanger func -// to set the new life time -type LifeChanger func() time.Duration - -// ChangeLifetime modifies the life field -// which is the life duration of the cached response -// of this cache entry -// -// useful when we find a max-age header from the handler -func (e *Entry) ChangeLifetime(fdur LifeChanger) { - if e.life < cfg.MinimumCacheDuration { - newLifetime := fdur() - if newLifetime > e.life { - e.life = newLifetime - } else { - // if even the new lifetime is less than MinimumCacheDuration - // then change set it explicitly here - e.life = cfg.MinimumCacheDuration - } - } -} - -// CopyHeaders clones headers "src" to "dst" . -func CopyHeaders(dst map[string][]string, src map[string][]string) { - if dst == nil || src == nil { - return - } - - for k, vv := range src { - v := make([]string, len(vv)) - copy(v, vv) - dst[k] = v - } -} - -// Reset called each time the entry is expired -// and the handler calls this after the original handler executed -// to re-set the response with the new handler's content result -func (e *Entry) Reset(statusCode int, headers map[string][]string, - body []byte, lifeChanger LifeChanger) { - if e.response == nil { - e.response = &Response{} - } - if statusCode > 0 { - e.response.statusCode = statusCode - } - - if len(headers) > 0 { - newHeaders := make(map[string][]string, len(headers)) - CopyHeaders(newHeaders, headers) - e.response.headers = newHeaders - } - - e.response.body = make([]byte, len(body)) - copy(e.response.body, body) - // check if a given life changer provided - // and if it does then execute the change life time - if lifeChanger != nil { - e.ChangeLifetime(lifeChanger) - } - - now := time.Now() - e.expiresAt = now.Add(e.life) - e.LastModified = now -} +// // Response gets the cache response contents +// // if it's valid returns them with a true value +// // otherwise returns nil, false +// func (e *Entry) Response() (*Response, bool) { +// if !e.isValid() { +// // it has been expired +// return nil, false +// } +// return e.response, true +// } + +// // isValid reports whether this entry's response is still valid or expired. +// // If the entry exists in the store then it should be valid anyways. +// func (e *Entry) isValid() bool { +// return !e.lifeTime.HasExpired() +// } diff --git a/cache/entry/pool.go b/cache/entry/pool.go new file mode 100644 index 000000000..0a0b039d9 --- /dev/null +++ b/cache/entry/pool.go @@ -0,0 +1,69 @@ +package entry + +import ( + "sync" + "time" + + "github.com/kataras/iris/v12/cache/cfg" + "github.com/kataras/iris/v12/core/memstore" +) + +// Pool is the context pool, it's used inside router and the framework by itself. +type Pool struct { + pool *sync.Pool +} + +// NewPool creates and returns a new context pool. +func NewPool() *Pool { + return &Pool{pool: &sync.Pool{New: func() any { return &Entry{} }}} +} + +// func NewPool(newFunc func() any) *Pool { +// return &Pool{pool: &sync.Pool{New: newFunc}} +// } + +// Acquire returns an Entry from pool. +// See Release. +func (c *Pool) Acquire(lifeDuration time.Duration, r *Response, onExpire func()) *Entry { + // If the given duration is not <=0 (which means finds from the headers) + // then we should check for the MinimumCacheDuration here + if lifeDuration >= 0 && lifeDuration < cfg.MinimumCacheDuration { + lifeDuration = cfg.MinimumCacheDuration + } + + e := c.pool.Get().(*Entry) + + lt := memstore.NewLifeTime() + lt.Begin(lifeDuration, func() { + onExpire() + c.release(e) + }) + + e.reset(lt, r) + return e +} + +// Release puts an Entry back to its pull, this function releases its resources. +// See Acquire. +func (c *Pool) release(e *Entry) { + e.response.body = nil + e.response.headers = nil + e.response.statusCode = 0 + e.response = nil + + // do not call it, it contains a lock too, release is controlled only inside the Acquire itself when the entry is expired. + // if e.lifeTime != nil { + // e.lifeTime.ExpireNow() // stop any opening timers if force released. + // } + + c.pool.Put(e) +} + +// Release can be called by custom stores to release an entry. +func (c *Pool) Release(e *Entry) { + if e.lifeTime != nil { + e.lifeTime.ExpireNow() // stop any opening timers if force released. + } + + c.release(e) +} diff --git a/cache/entry/response.go b/cache/entry/response.go index 5e103dcb5..33443832d 100644 --- a/cache/entry/response.go +++ b/cache/entry/response.go @@ -1,6 +1,9 @@ package entry -import "net/http" +import ( + "io" + "net/http" +) // Response is the cached response will be send to the clients // its fields set at runtime on each of the non-cached executions @@ -15,11 +18,28 @@ type Response struct { headers http.Header } +// NewResponse returns a new cached Response. +func NewResponse(statusCode int, headers http.Header, body []byte) *Response { + r := new(Response) + + r.SetStatusCode(statusCode) + r.SetHeaders(headers) + r.SetBody(body) + + return r +} + +// SetStatusCode sets a valid status code. +func (r *Response) SetStatusCode(statusCode int) { + if statusCode <= 0 { + statusCode = http.StatusOK + } + + r.statusCode = statusCode +} + // StatusCode returns a valid status code. func (r *Response) StatusCode() int { - if r.statusCode <= 0 { - r.statusCode = 200 - } return r.statusCode } @@ -31,12 +51,39 @@ func (r *Response) StatusCode() int { // return r.contentType // } +// SetHeaders sets a clone of headers of the cached response. +func (r *Response) SetHeaders(h http.Header) { + r.headers = h.Clone() +} + // Headers returns the total headers of the cached response. func (r *Response) Headers() http.Header { return r.headers } +// SetBody consumes "b" and sets the body of the cached response. +func (r *Response) SetBody(body []byte) { + r.body = make([]byte, len(body)) + copy(r.body, body) +} + // Body returns contents will be served by the cache handler. func (r *Response) Body() []byte { return r.body } + +// Read implements the io.Reader interface. +func (r *Response) Read(b []byte) (int, error) { + if len(r.body) == 0 { + return 0, io.EOF + } + + n := copy(b, r.body) + r.body = r.body[n:] + return n, nil +} + +// Bytes returns a copy of the cached response body. +func (r *Response) Bytes() []byte { + return append([]byte(nil), r.body...) +} diff --git a/cache/entry/store.go b/cache/entry/store.go new file mode 100644 index 000000000..8c9d1bfdd --- /dev/null +++ b/cache/entry/store.go @@ -0,0 +1,52 @@ +package entry + +import ( + "sync" +) + +// Store is the interface which is responsible to store the cache entries. +type Store interface { + // Get returns an entry based on its key. + Get(key string) *Entry + // Set sets an entry based on its key. + Set(key string, e *Entry) + // Delete deletes an entry based on its key. + Delete(key string) +} + +// memStore is the default in-memory store for the cache entries. +type memStore struct { + entries map[string]*Entry + mu sync.RWMutex +} + +var _ Store = (*memStore)(nil) + +// NewMemStore returns a new in-memory store for the cache entries. +func NewMemStore() Store { + return &memStore{ + entries: make(map[string]*Entry), + } +} + +// Get returns an entry based on its key. +func (s *memStore) Get(key string) *Entry { + s.mu.RLock() + e := s.entries[key] + s.mu.RUnlock() + return e +} + +// Set sets an entry based on its key. +func (s *memStore) Set(key string, e *Entry) { + s.mu.Lock() + s.entries[key] = e + s.mu.Unlock() +} + +// Delete deletes an entry based on its key. +func (s *memStore) Delete(key string) { + s.mu.Lock() + delete(s.entries, key) + s.mu.Unlock() +} diff --git a/cache/entry/util.go b/cache/entry/util.go deleted file mode 100644 index df9d5b1a3..000000000 --- a/cache/entry/util.go +++ /dev/null @@ -1,24 +0,0 @@ -package entry - -import ( - "regexp" - "strconv" -) - -var maxAgeExp = regexp.MustCompile(`maxage=(\d+)`) - -// ParseMaxAge parses the max age from the receiver parameter, "cache-control" header -// returns seconds as int64 -// if header not found or parse failed then it returns -1 -func ParseMaxAge(header string) int64 { - if header == "" { - return -1 - } - m := maxAgeExp.FindStringSubmatch(header) - if len(m) == 2 { - if v, err := strconv.Atoi(m[1]); err == nil { - return int64(v) - } - } - return -1 -} diff --git a/context/context.go b/context/context.go index df9ccf206..871309634 100644 --- a/context/context.go +++ b/context/context.go @@ -5865,7 +5865,7 @@ func (ctx *Context) GetRequestCookie(name string, options ...CookieOption) (*htt var ( // CookieExpireDelete may be set on Cookie.Expire for expiring the given cookie. - CookieExpireDelete = time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) + CookieExpireDelete = memstore.ExpireDelete // CookieExpireUnlimited indicates that does expires after 24 years. CookieExpireUnlimited = time.Now().AddDate(24, 10, 10) diff --git a/sessions/lifetime.go b/core/memstore/lifetime.go similarity index 58% rename from sessions/lifetime.go rename to core/memstore/lifetime.go index b92c4d8d6..070f372f1 100644 --- a/sessions/lifetime.go +++ b/core/memstore/lifetime.go @@ -1,10 +1,19 @@ -package sessions +package memstore import ( "sync" "time" +) + +var ( + // Clock is the default clock to get the current time, + // it can be used for testing purposes too. + // + // Defaults to time.Now. + Clock func() time.Time = time.Now - "github.com/kataras/iris/v12/context" + // ExpireDelete may be set on Cookie.Expire for expiring the given cookie. + ExpireDelete = time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) ) // LifeTime controls the session expiration datetime. @@ -17,18 +26,29 @@ type LifeTime struct { time.Time timer *time.Timer + // StartTime holds the Now of the Begin. + Begun time.Time + mu sync.RWMutex } -// Begin will begin the life based on the time.Now().Add(d). +// NewLifeTime returns a pointer to an empty LifeTime instance. +func NewLifeTime() *LifeTime { + return &LifeTime{} +} + +// Begin will begin the life based on the Clock (time.Now()).Add(d). // Use `Continue` to continue from a stored time(database-based session does that). func (lt *LifeTime) Begin(d time.Duration, onExpire func()) { if d <= 0 { return } + now := Clock() + lt.mu.Lock() - lt.Time = time.Now().Add(d) + lt.Begun = now + lt.Time = now.Add(d) lt.timer = time.AfterFunc(d, onExpire) lt.mu.Unlock() } @@ -36,15 +56,23 @@ func (lt *LifeTime) Begin(d time.Duration, onExpire func()) { // Revive will continue the life based on the stored Time. // Other words that could be used for this func are: Continue, Restore, Resc. func (lt *LifeTime) Revive(onExpire func()) { - if lt.Time.IsZero() { + lt.mu.RLock() + t := lt.Time + lt.mu.RUnlock() + + if t.IsZero() { return } - now := time.Now() - if lt.Time.After(now) { - d := lt.Time.Sub(now) + now := Clock() + if t.After(now) { + d := t.Sub(now) lt.mu.Lock() - lt.timer = time.AfterFunc(d, onExpire) + if lt.timer != nil { + lt.timer.Stop() // Stop the existing timer, if any. + } + lt.Begun = now + lt.timer = time.AfterFunc(d, onExpire) // and execute on-time the new onExpire function. lt.mu.Unlock() } } @@ -53,7 +81,9 @@ func (lt *LifeTime) Revive(onExpire func()) { func (lt *LifeTime) Shift(d time.Duration) { lt.mu.Lock() if d > 0 && lt.timer != nil { - lt.Time = time.Now().Add(d) + now := Clock() + lt.Begun = now + lt.Time = now.Add(d) lt.timer.Reset(d) } lt.mu.Unlock() @@ -62,7 +92,7 @@ func (lt *LifeTime) Shift(d time.Duration) { // ExpireNow reduce the lifetime completely. func (lt *LifeTime) ExpireNow() { lt.mu.Lock() - lt.Time = context.CookieExpireDelete + lt.Time = ExpireDelete if lt.timer != nil { lt.timer.Stop() } @@ -71,15 +101,23 @@ func (lt *LifeTime) ExpireNow() { // HasExpired reports whether "lt" represents is expired. func (lt *LifeTime) HasExpired() bool { - if lt.IsZero() { + lt.mu.RLock() + t := lt.Time + lt.mu.RUnlock() + + if t.IsZero() { return false } - return lt.Time.Before(time.Now()) + return t.Before(Clock()) } // DurationUntilExpiration returns the duration until expires, it can return negative number if expired, // a call to `HasExpired` may be useful before calling this `Dur` function. func (lt *LifeTime) DurationUntilExpiration() time.Duration { - return time.Until(lt.Time) + lt.mu.RLock() + t := lt.Time + lt.mu.RUnlock() + + return t.Sub(Clock()) } diff --git a/sessions/database.go b/sessions/database.go index 071d292bd..055e4d68c 100644 --- a/sessions/database.go +++ b/sessions/database.go @@ -30,7 +30,7 @@ type Database interface { SetLogger(*golog.Logger) // Acquire receives a session's lifetime from the database, // if the return value is LifeTime{} then the session manager sets the life time based on the expiration duration lives in configuration. - Acquire(sid string, expires time.Duration) LifeTime + Acquire(sid string, expires time.Duration) memstore.LifeTime // OnUpdateExpiration should re-set the expiration (ttl) of the session entry inside the database, // it is fired on `ShiftExpiration` and `UpdateExpiration`. // If the database does not support change of ttl then the session entry will be cloned to another one @@ -81,11 +81,11 @@ func newMemDB() Database { return &mem{values: make(map[string]*memstore.Store)} func (s *mem) SetLogger(*golog.Logger) {} -func (s *mem) Acquire(sid string, expires time.Duration) LifeTime { +func (s *mem) Acquire(sid string, expires time.Duration) memstore.LifeTime { s.mu.Lock() s.values[sid] = new(memstore.Store) s.mu.Unlock() - return LifeTime{} + return memstore.LifeTime{} } // Do nothing, the `LifeTime` of the Session will be managed by the callers automatically on memory-based storage. diff --git a/sessions/session.go b/sessions/session.go index 068b86537..4fa3e8968 100644 --- a/sessions/session.go +++ b/sessions/session.go @@ -21,7 +21,7 @@ type ( mu sync.RWMutex // for flashes. // Lifetime it contains the expiration data, use it for read-only information. // See `Sessions.UpdateExpiration` too. - Lifetime *LifeTime + Lifetime *memstore.LifeTime // Man is the sessions manager that this session created of. Man *Sessions diff --git a/sessions/sessiondb/badger/database.go b/sessions/sessiondb/badger/database.go index 2864fc4f3..dc49e69d9 100644 --- a/sessions/sessiondb/badger/database.go +++ b/sessions/sessiondb/badger/database.go @@ -8,6 +8,7 @@ import ( "time" "github.com/kataras/iris/v12/context" + "github.com/kataras/iris/v12/core/memstore" "github.com/kataras/iris/v12/sessions" "github.com/dgraph-io/badger/v2" @@ -82,7 +83,7 @@ func (db *Database) SetLogger(logger *golog.Logger) { // Acquire receives a session's lifetime from the database, // if the return value is LifeTime{} then the session manager sets the life time based on the expiration duration lives in configuration. -func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime { +func (db *Database) Acquire(sid string, expires time.Duration) memstore.LifeTime { txn := db.Service.NewTransaction(true) defer txn.Commit() @@ -90,7 +91,7 @@ func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime item, err := txn.Get(bsid) if err == nil { // found, return the expiration. - return sessions.LifeTime{Time: time.Unix(int64(item.ExpiresAt()), 0)} + return memstore.LifeTime{Time: time.Unix(int64(item.ExpiresAt()), 0)} } // not found, create an entry with ttl and return an empty lifetime, session manager will do its job. @@ -105,7 +106,7 @@ func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime db.logger.Error(err) } - return sessions.LifeTime{} // session manager will handle the rest. + return memstore.LifeTime{} // session manager will handle the rest. } // OnUpdateExpiration not implemented here, yet. diff --git a/sessions/sessiondb/boltdb/database.go b/sessions/sessiondb/boltdb/database.go index 11b667e9b..92dfeb3b5 100644 --- a/sessions/sessiondb/boltdb/database.go +++ b/sessions/sessiondb/boltdb/database.go @@ -6,6 +6,7 @@ import ( "path/filepath" "time" + "github.com/kataras/iris/v12/core/memstore" "github.com/kataras/iris/v12/sessions" "github.com/kataras/golog" @@ -159,7 +160,7 @@ var expirationKey = []byte("exp") // it can be random. // Acquire receives a session's lifetime from the database, // if the return value is LifeTime{} then the session manager sets the life time based on the expiration duration lives in configuration. -func (db *Database) Acquire(sid string, expires time.Duration) (lifetime sessions.LifeTime) { +func (db *Database) Acquire(sid string, expires time.Duration) (lifetime memstore.LifeTime) { bsid := []byte(sid) err := db.Service.Update(func(tx *bolt.Tx) (err error) { root := db.getBucket(tx) @@ -204,7 +205,7 @@ func (db *Database) Acquire(sid string, expires time.Duration) (lifetime session return } - lifetime = sessions.LifeTime{Time: expirationTime} + lifetime = memstore.LifeTime{Time: expirationTime} return nil } @@ -214,7 +215,7 @@ func (db *Database) Acquire(sid string, expires time.Duration) (lifetime session }) if err != nil { db.logger.Debugf("unable to acquire session '%s': %v", sid, err) - return sessions.LifeTime{} + return memstore.LifeTime{} } return diff --git a/sessions/sessiondb/redis/database.go b/sessions/sessiondb/redis/database.go index 3307b1af4..230bb0de6 100644 --- a/sessions/sessiondb/redis/database.go +++ b/sessions/sessiondb/redis/database.go @@ -6,6 +6,7 @@ import ( "fmt" "time" + "github.com/kataras/iris/v12/core/memstore" "github.com/kataras/iris/v12/sessions" "github.com/kataras/golog" @@ -138,7 +139,7 @@ const SessionIDKey = "session_id" // Acquire receives a session's lifetime from the database, // if the return value is LifeTime{} then the session manager sets the life time based on the expiration duration lives in configuration. -func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime { +func (db *Database) Acquire(sid string, expires time.Duration) memstore.LifeTime { sidKey := db.makeSID(sid) if !db.c.Driver.Exists(sidKey) { if err := db.Set(sid, SessionIDKey, sid, 0, false); err != nil { @@ -149,11 +150,11 @@ func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime } } - return sessions.LifeTime{} // session manager will handle the rest. + return memstore.LifeTime{} // session manager will handle the rest. } untilExpire := db.c.Driver.TTL(sidKey) - return sessions.LifeTime{Time: time.Now().Add(untilExpire)} + return memstore.LifeTime{Time: time.Now().Add(untilExpire)} } // OnUpdateExpiration will re-set the database's session's entry ttl.