diff --git a/stream.go b/stream.go index b630b78..037c22d 100644 --- a/stream.go +++ b/stream.go @@ -422,15 +422,9 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { return nil } - // Validate it's okay to copy - if !s.recvBuf.TryReserve(length) { - s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvBuf.Cap(), length) - return ErrRecvWindowExceeded - } - // Copy into buffer - if err := s.recvBuf.Append(conn, int(length)); err != nil { - s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err) + if err := s.recvBuf.Append(conn, length); err != nil { + s.session.logger.Printf("[ERR] yamux: Failed to read stream data on stream %d: %v", s.id, err) return err } // Unblock the reader diff --git a/util.go b/util.go index 45da28c..177eb98 100644 --- a/util.go +++ b/util.go @@ -1,6 +1,7 @@ package yamux import ( + "fmt" "io" "sync" @@ -42,14 +43,12 @@ func min(values ...uint32) uint32 { // | data | empty space | // < window (10) > // < len (5) > < cap (5) > -// < pending (4) > // // As data is read, the buffer gets updated like so: // // | data | empty space | // < window (8) > // < len (3) > < cap (5) > -// < pending (4) > // // It can then grow as follows (given a "max" of 10): // @@ -57,21 +56,18 @@ func min(values ...uint32) uint32 { // | data | empty space | // < window (10) > // < len (3) > < cap (7) > -// < pending (4) > // -// Data can then be written into the pending space, expanding len, and shrinking -// cap and pending: +// Data can then be written into the empty space, expanding len, +// and shrinking cap: // // | data | empty space | // < window (10) > // < len (5) > < cap (5) > -// < pending (2)> // type segmentedBuffer struct { - cap uint32 - pending uint32 - len uint32 - bm sync.Mutex + cap uint32 + len uint32 + bm sync.Mutex // read position in b[0]. // We must not reslice any of the buffers in b, as we need to put them back into the pool. readPos int @@ -84,22 +80,10 @@ func newSegmentedBuffer(initialCapacity uint32) segmentedBuffer { } // Len is the amount of data in the receive buffer. -func (s *segmentedBuffer) Len() int { +func (s *segmentedBuffer) Len() uint32 { s.bm.Lock() - len := s.len - s.bm.Unlock() - return int(len) -} - -// Cap is the remaining capacity in the receive buffer. -// -// Note: this is _not_ the same as go's 'cap' function. The total size of the -// buffer is len+cap. -func (s *segmentedBuffer) Cap() uint32 { - s.bm.Lock() - cap := s.cap - s.bm.Unlock() - return cap + defer s.bm.Unlock() + return s.len } // If the space to write into + current buffer size has grown to half of the window size, @@ -122,16 +106,6 @@ func (s *segmentedBuffer) GrowTo(max uint32, force bool) (bool, uint32) { return true, delta } -func (s *segmentedBuffer) TryReserve(space uint32) bool { - s.bm.Lock() - defer s.bm.Unlock() - if s.cap < s.pending+space { - return false - } - s.pending += space - return true -} - func (s *segmentedBuffer) Read(b []byte) (int, error) { s.bm.Lock() defer s.bm.Unlock() @@ -154,8 +128,21 @@ func (s *segmentedBuffer) Read(b []byte) (int, error) { return n, nil } -func (s *segmentedBuffer) Append(input io.Reader, length int) error { - dst := pool.Get(length) +func (s *segmentedBuffer) checkOverflow(l uint32) error { + s.bm.Lock() + defer s.bm.Unlock() + if s.cap < l { + return fmt.Errorf("receive window exceeded (remain: %d, recv: %d)", s.cap, l) + } + return nil +} + +func (s *segmentedBuffer) Append(input io.Reader, length uint32) error { + if err := s.checkOverflow(length); err != nil { + return err + } + + dst := pool.Get(int(length)) n, err := io.ReadFull(input, dst) if err == io.EOF { err = io.ErrUnexpectedEOF @@ -165,7 +152,6 @@ func (s *segmentedBuffer) Append(input io.Reader, length int) error { if n > 0 { s.len += uint32(n) s.cap -= uint32(n) - s.pending = s.pending - uint32(length) s.b = append(s.b, dst[0:n]) } return err diff --git a/util_test.go b/util_test.go index 68fb3ba..90b9cbe 100644 --- a/util_test.go +++ b/util_test.go @@ -54,18 +54,17 @@ func TestMin(t *testing.T) { func TestSegmentedBuffer(t *testing.T) { buf := newSegmentedBuffer(100) - assert := func(len, cap int) { + assert := func(len, cap uint32) { if buf.Len() != len { t.Fatalf("expected length %d, got %d", len, buf.Len()) } - if buf.Cap() != uint32(cap) { + buf.bm.Lock() + defer buf.bm.Unlock() + if buf.cap != cap { t.Fatalf("expected length %d, got %d", len, buf.Len()) } } assert(0, 100) - if !buf.TryReserve(3) { - t.Fatal("reservation should have worked") - } if err := buf.Append(bytes.NewReader([]byte("fooo")), 3); err != nil { t.Fatal(err) } @@ -87,9 +86,6 @@ func TestSegmentedBuffer(t *testing.T) { t.Fatal("should have grown by 2") } - if !buf.TryReserve(50) { - t.Fatal("reservation should have worked") - } if err := buf.Append(bytes.NewReader(make([]byte, 50)), 50); err != nil { t.Fatal(err) } @@ -104,9 +100,7 @@ func TestSegmentedBuffer(t *testing.T) { if read != 50 { t.Fatal("expected to read 50 bytes") } - if !buf.TryReserve(49) { - t.Fatal("should have been able to reserve rest of space") - } + assert(1, 49) if grew, amount := buf.GrowTo(100, false); !grew || amount != 50 { t.Fatal("should have grown when below half, even with reserved space")