diff --git a/session.go b/session.go index 238c2a9..501c59b 100644 --- a/session.go +++ b/session.go @@ -220,14 +220,18 @@ func (s *Session) Accept() (net.Conn, error) { // AcceptStream is used to block until the next available stream // is ready to be accepted. func (s *Session) AcceptStream() (*Stream, error) { - select { - case stream := <-s.acceptCh: - if err := stream.sendWindowUpdate(); err != nil { - return nil, err + for { + select { + case stream := <-s.acceptCh: + if err := stream.sendWindowUpdate(); err != nil { + // don't return accept errors. + s.logger.Printf("[WARN] error sending window update before accepting: %s", err) + continue + } + return stream, nil + case <-s.shutdownCh: + return nil, s.shutdownErr } - return stream, nil - case <-s.shutdownCh: - return nil, s.shutdownErr } } diff --git a/session_test.go b/session_test.go index 8c1a175..c090ed8 100644 --- a/session_test.go +++ b/session_test.go @@ -407,6 +407,7 @@ func TestSendData_Small(t *testing.T) { t.Errorf("err: %v", err) return } + defer stream.Close() if server.NumStreams() != 1 { t.Errorf("bad") @@ -430,7 +431,7 @@ func TestSendData_Small(t *testing.T) { } } - if err := stream.Close(); err != nil { + if err := stream.CloseWrite(); err != nil { t.Errorf("err: %v", err) return } @@ -442,11 +443,12 @@ func TestSendData_Small(t *testing.T) { go func() { defer wg.Done() - stream, err := client.Open() + stream, err := client.OpenStream() if err != nil { t.Errorf("err: %v", err) return } + defer stream.Close() if client.NumStreams() != 1 { t.Errorf("bad") @@ -465,7 +467,7 @@ func TestSendData_Small(t *testing.T) { } } - if err := stream.Close(); err != nil { + if err := stream.CloseWrite(); err != nil { t.Errorf("err: %v", err) return } @@ -785,12 +787,12 @@ func TestManyStreams_PingPong(t *testing.T) { wg.Wait() } -func TestHalfClose(t *testing.T) { +func TestCloseRead(t *testing.T) { client, server := testClientServer() defer client.Close() defer server.Close() - stream, err := client.Open() + stream, err := client.OpenStream() if err != nil { t.Fatalf("err: %v", err) } @@ -798,17 +800,43 @@ func TestHalfClose(t *testing.T) { t.Fatalf("err: %v", err) } - stream2, err := server.Accept() + stream2, err := server.AcceptStream() if err != nil { t.Fatalf("err: %v", err) } - stream2.Close() // Half close + stream2.CloseRead() buf := make([]byte, 4) n, err := stream2.Read(buf) + if n != 0 || err == nil { + t.Fatalf("read after close: %d %s", n, err) + } +} + +func TestHalfClose(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + stream, err := client.OpenStream() if err != nil { t.Fatalf("err: %v", err) } + if _, err = stream.Write([]byte("a")); err != nil { + t.Fatalf("err: %v", err) + } + + stream2, err := server.AcceptStream() + if err != nil { + t.Fatalf("err: %v", err) + } + stream2.CloseWrite() // Half close + + buf := make([]byte, 4) + n, err := io.ReadAtLeast(stream2, buf, 1) + if err != nil && err != io.EOF { + t.Fatalf("err: %v", err) + } if n != 1 { t.Fatalf("bad: %v", n) } @@ -817,11 +845,17 @@ func TestHalfClose(t *testing.T) { if _, err = stream.Write([]byte("bcd")); err != nil { t.Fatalf("err: %v", err) } - stream.Close() + stream.CloseWrite() + + // write after close + n, err = stream.Write([]byte("foobar")) + if n != 0 || err == nil { + t.Fatalf("wrote after close: %d %s", n, err) + } // Read after close - n, err = stream2.Read(buf) - if err != nil { + n, err = io.ReadAtLeast(stream2, buf, 3) + if err != nil && err != io.EOF { t.Fatalf("err: %v", err) } if n != 3 { @@ -1131,7 +1165,6 @@ func TestSession_PartialReadWindowUpdate(t *testing.T) { t.Errorf("err: %v", err) return } - defer wr.Close() sendWindow := atomic.LoadUint32(&wr.sendWindow) if sendWindow != client.config.MaxStreamWindowSize { @@ -1352,8 +1385,9 @@ func TestStreamHalfClose2(t *testing.T) { if err != nil { t.Error(err) } + defer stream.Close() - stream.Close() + stream.CloseWrite() wait <- struct{}{} buf, err := ioutil.ReadAll(stream) diff --git a/stream.go b/stream.go index b3a5e7c..b3adde9 100644 --- a/stream.go +++ b/stream.go @@ -14,10 +14,15 @@ const ( streamSYNSent streamSYNReceived streamEstablished - streamLocalClose - streamRemoteClose - streamClosed - streamReset + streamFinished +) + +type halfStreamState int + +const ( + halfOpen halfStreamState = iota + halfClosed + halfReset ) // Stream is used to represent a logical stream @@ -28,8 +33,9 @@ type Stream struct { id uint32 session *Session - state streamState - stateLock sync.Mutex + state streamState + writeState, readState halfStreamState + stateLock sync.Mutex recvLock sync.Mutex recvBuf segmentedBuffer @@ -74,19 +80,22 @@ func (s *Stream) Read(b []byte) (n int, err error) { defer asyncNotify(s.recvNotifyCh) START: s.stateLock.Lock() - state := s.state + state := s.readState s.stateLock.Unlock() switch state { - case streamRemoteClose: - fallthrough - case streamClosed: + case halfOpen: + // Open -> read + case halfClosed: empty := s.recvBuf.Len() == 0 if empty { return 0, io.EOF } - case streamReset: + // Closed, but we have data pending -> read. + case halfReset: return 0, ErrStreamReset + default: + panic("unknown state") } // If there is no data available, block @@ -138,16 +147,18 @@ func (s *Stream) write(b []byte) (n int, err error) { START: s.stateLock.Lock() - state := s.state + state := s.writeState s.stateLock.Unlock() switch state { - case streamLocalClose: - fallthrough - case streamClosed: + case halfOpen: + // Open for writing -> write + case halfClosed: return 0, ErrStreamClosed - case streamReset: + case halfReset: return 0, ErrStreamReset + default: + panic("unknown state") } // If there is no data available, block @@ -239,75 +250,117 @@ func (s *Stream) sendReset() error { // Reset resets the stream (forcibly closes the stream) func (s *Stream) Reset() error { + sendReset := false s.stateLock.Lock() switch s.state { - case streamInit: - // No need to send anything. - s.state = streamReset - s.stateLock.Unlock() - return nil - case streamClosed, streamReset: + case streamFinished: s.stateLock.Unlock() return nil + case streamInit: + // we haven't sent anything, so we don't need to send a reset. case streamSYNSent, streamSYNReceived, streamEstablished: - case streamLocalClose, streamRemoteClose: + sendReset = true default: panic("unhandled state") } - s.state = streamReset - s.stateLock.Unlock() - err := s.sendReset() + // at least one direction is open, we need to reset. + + // If we've already sent/received an EOF, no need to reset that side. + if s.writeState == halfOpen { + s.writeState = halfReset + } + if s.readState == halfOpen { + s.readState = halfReset + } + s.state = streamFinished s.notifyWaiting() + s.stateLock.Unlock() + if sendReset { + _ = s.sendReset() + } s.cleanup() - - return err + return nil } -// Close is used to close the stream -func (s *Stream) Close() error { - closeStream := false +// CloseWrite is used to close the stream for writing. +func (s *Stream) CloseWrite() error { s.stateLock.Lock() - switch s.state { - case streamInit, streamSYNSent, streamSYNReceived, streamEstablished: - s.state = streamLocalClose - goto SEND_CLOSE + switch s.writeState { + case halfOpen: + // Open for writing -> close write + case halfClosed: + s.stateLock.Unlock() + return nil + case halfReset: + s.stateLock.Unlock() + return ErrStreamReset + default: + panic("invalid state") + } + s.writeState = halfClosed + cleanup := s.readState != halfOpen + if cleanup { + s.state = streamFinished + } + s.stateLock.Unlock() + s.notifyWaiting() - case streamLocalClose: - case streamRemoteClose: - s.state = streamClosed - closeStream = true - goto SEND_CLOSE + err := s.sendClose() + if cleanup { + // we're fully closed, might as well be nice to the user and + // free everything early. + s.cleanup() + } + return err +} - case streamClosed: - case streamReset: +// CloseRead is used to close the stream for writing. +func (s *Stream) CloseRead() error { + cleanup := false + s.stateLock.Lock() + switch s.readState { + case halfOpen: + // Open for reading -> close read + case halfClosed, halfReset: + s.stateLock.Unlock() + return nil default: - panic("unhandled state") + panic("invalid state") + } + s.readState = halfReset + cleanup = s.writeState != halfOpen + if cleanup { + s.state = streamFinished } s.stateLock.Unlock() - return nil -SEND_CLOSE: - s.stateLock.Unlock() - err := s.sendClose() s.notifyWaiting() - if closeStream { + if cleanup { + // we're fully closed, might as well be nice to the user and + // free everything early. s.cleanup() } - return err + return nil +} + +// Close is used to close the stream. +func (s *Stream) Close() error { + _ = s.CloseRead() // can't fail. + return s.CloseWrite() } // forceClose is used for when the session is exiting func (s *Stream) forceClose() { s.stateLock.Lock() - switch s.state { - case streamClosed: - // Already successfully closed. It just hasn't been removed from - // the list of streams yet. - default: - s.state = streamReset + if s.readState == halfOpen { + s.readState = halfReset } - s.stateLock.Unlock() + if s.writeState == halfOpen { + s.writeState = halfReset + } + s.state = streamFinished s.notifyWaiting() + s.stateLock.Unlock() s.readDeadline.set(time.Time{}) s.writeDeadline.set(time.Time{}) @@ -340,25 +393,24 @@ func (s *Stream) processFlags(flags uint16) error { s.session.establishStream(s.id) } if flags&flagFIN == flagFIN { - switch s.state { - case streamSYNSent: - fallthrough - case streamSYNReceived: - fallthrough - case streamEstablished: - s.state = streamRemoteClose + if s.readState == halfOpen { + s.readState = halfClosed + if s.writeState != halfOpen { + // We're now fully closed. + closeStream = true + s.state = streamFinished + } s.notifyWaiting() - case streamLocalClose: - s.state = streamClosed - closeStream = true - s.notifyWaiting() - default: - s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state) - return ErrUnexpectedFlag } } if flags&flagRST == flagRST { - s.state = streamReset + if s.readState == halfOpen { + s.readState = halfReset + } + if s.writeState == halfOpen { + s.writeState = halfReset + } + s.state = streamFinished closeStream = true s.notifyWaiting() } @@ -426,11 +478,9 @@ func (s *Stream) SetDeadline(t time.Time) error { func (s *Stream) SetReadDeadline(t time.Time) error { s.stateLock.Lock() defer s.stateLock.Unlock() - switch s.state { - case streamClosed, streamRemoteClose, streamReset: - return nil + if s.readState == halfOpen { + s.readDeadline.set(t) } - s.readDeadline.set(t) return nil } @@ -438,11 +488,9 @@ func (s *Stream) SetReadDeadline(t time.Time) error { func (s *Stream) SetWriteDeadline(t time.Time) error { s.stateLock.Lock() defer s.stateLock.Unlock() - switch s.state { - case streamClosed, streamLocalClose, streamReset: - return nil + if s.writeState == halfOpen { + s.writeDeadline.set(t) } - s.writeDeadline.set(t) return nil }