pool: add pool raciness tests, fix p.usedTotal house-keeping #1340

Merged: 8 commits, Jul 23, 2019
pkg/pool/pool.go (21 additions, 6 deletions)
@@ -2,7 +2,6 @@ package pool

import (
"sync"
"sync/atomic"

"github.com/pkg/errors"
)
@@ -15,6 +14,7 @@ type BytesPool struct {
sizes []int
maxTotal uint64
usedTotal uint64
mtx sync.Mutex

new func(s int) *[]byte
}
@@ -55,11 +55,13 @@ var ErrPoolExhausted = errors.New("pool exhausted")

// Get returns a new byte slice that fits the given size.
func (p *BytesPool) Get(sz int) (*[]byte, error) {
used := atomic.LoadUint64(&p.usedTotal)
p.mtx.Lock()
defer p.mtx.Unlock()

if p.maxTotal > 0 && used+uint64(sz) > p.maxTotal {
if p.maxTotal > 0 && p.usedTotal+uint64(sz) > p.maxTotal {
return nil, ErrPoolExhausted
}

for i, bktSize := range p.sizes {
if sz > bktSize {
continue
@@ -68,12 +70,13 @@ func (p *BytesPool) Get(sz int) (*[]byte, error) {
if !ok {
b = p.new(bktSize)
}
atomic.AddUint64(&p.usedTotal, uint64(cap(*b)))

p.usedTotal += uint64(cap(*b))
return b, nil
}

// The requested size exceeds that of our highest bucket, allocate it directly.
atomic.AddUint64(&p.usedTotal, uint64(sz))
p.usedTotal += uint64(sz)
return p.new(sz), nil
}

@@ -82,6 +85,7 @@ func (p *BytesPool) Put(b *[]byte) {
if b == nil {
return
}

for i, bktSize := range p.sizes {
if cap(*b) > bktSize {
continue
@@ -90,5 +94,16 @@ func (p *BytesPool) Put(b *[]byte) {
p.buckets[i].Put(b)
break
}
atomic.AddUint64(&p.usedTotal, ^uint64(p.usedTotal-1))

p.mtx.Lock()
defer p.mtx.Unlock()

// We could assume here that our users will not make the slices larger,
// but let's be on the safe side to avoid an underflow of p.usedTotal.
sz := uint64(cap(*b))
if sz >= p.usedTotal {
p.usedTotal = 0
} else {
p.usedTotal -= sz
}
}
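
Why the mutex instead of atomics: the removed Get path did the limit check and the counter update as two separate atomic operations, so several goroutines could all observe usedTotal below maxTotal and then all add to it. The standalone sketch below (not part of the PR; the counter and limit names are illustrative) reproduces that check-then-act window on a plain uint64, which is the window that holding p.mtx across both steps closes.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	const limit = 100
	var used uint64

	var wg sync.WaitGroup
	for i := 0; i < 50; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Racy check-then-act: another goroutine may add to `used`
			// between this load and the add below.
			if atomic.LoadUint64(&used)+10 > limit {
				return
			}
			atomic.AddUint64(&used, 10)
		}()
	}
	wg.Wait()
	fmt.Println("used:", used, "limit:", limit) // used can end up well above limit
}

The removed decrement in Put had the related house-keeping problem named in the title: it derived the value to subtract from a plain read of p.usedTotal itself rather than from cap(*b), so the replacement subtracts the slice's capacity under the mutex and clamps at zero to avoid underflow.
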
pkg/pool/pool_test.go (70 additions, 0 deletions)
@@ -1,8 +1,14 @@
package pool

import (
"bytes"
"fmt"
"sync"
"testing"
"time"

"github.com/fortytw2/leaktest"
"github.com/pkg/errors"
"github.com/thanos-io/thanos/pkg/testutil"
)

@@ -50,3 +56,67 @@ func TestBytesPool(t *testing.T) {

testutil.Equals(t, uint64(0), chunkPool.usedTotal)
}

func TestRacePutGet(t *testing.T) {
chunkPool, err := NewBytesPool(3, 100, 2, 5000)
testutil.Ok(t, err)
defer leaktest.CheckTimeout(t, 10*time.Second)()

s := sync.WaitGroup{}

// Start two goroutines that continuously Get and Put byte slices, write
// 'foo' / 'barbazbaz' into them, and check that the data is still there
// after writing it, before putting the slices back.
errs := make(chan error, 2)
stop := make(chan bool, 2)

f := func(txt string) {
for {
select {
case <-stop:
s.Done()
return
default:
c, err := chunkPool.Get(3)
if err != nil {
errs <- errors.Wrapf(err, "goroutine %s", txt)
s.Done()
return
}

buf := bytes.NewBuffer(*c)

_, err = fmt.Fprintf(buf, "%s", txt)
if err != nil {
errs <- errors.Wrapf(err, "goroutine %s", txt)
s.Done()
return
}

if buf.String() != txt {
errs <- errors.New("expected to get the data just written")
s.Done()
return
}

b := buf.Bytes()
chunkPool.Put(&b)
}
}
}

s.Add(2)
go f("foo")
go f("barbazbaz")

time.Sleep(5 * time.Second)
stop <- true
stop <- true

s.Wait()
select {
case err := <-errs:
testutil.Ok(t, err)
default:
}
}
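
For reference, here is a minimal usage sketch of the pool after this change (not part of the PR). The import path is assumed from the testutil import above, and the NewBytesPool arguments mirror the test's NewBytesPool(3, 100, 2, 5000) call, read here as bucket sizes from 3 to 100 bytes growing by a factor of 2, with at most 5000 bytes handed out at once.

package main

import (
	"fmt"

	"github.com/thanos-io/thanos/pkg/pool"
)

func main() {
	p, err := pool.NewBytesPool(3, 100, 2, 5000)
	if err != nil {
		panic(err)
	}

	b, err := p.Get(10)
	if err == pool.ErrPoolExhausted {
		// Granting this request would push usedTotal past maxTotal.
		fmt.Println("pool exhausted, backing off")
		return
	}
	if err != nil {
		panic(err)
	}

	// ... write into *b ...

	// Put returns the slice to its bucket and, with this PR, decrements
	// usedTotal under the mutex without risking underflow.
	p.Put(b)
}

When extending tests like TestRacePutGet, running them under Go's race detector (go test -race) surfaces unsynchronized access to fields such as usedTotal even when the assertions happen to pass.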