From dbf2b88bdad09ab9e37f66cabbcab75f1640246c Mon Sep 17 00:00:00 2001 From: Simon Ferquel Date: Mon, 14 Nov 2016 11:14:44 +0100 Subject: [PATCH] Added a test for write overflows Signed-off-by: Simon Ferquel --- channel_test.go | 78 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 channel_test.go diff --git a/channel_test.go b/channel_test.go new file mode 100644 index 0000000..a0e5367 --- /dev/null +++ b/channel_test.go @@ -0,0 +1,78 @@ +package p9p + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "net" + "testing" + "time" +) + +type fakeAddr struct{} + +func (a fakeAddr) Network() string { + return "" +} +func (a fakeAddr) String() string { + return "fake address" +} + +type writeConnMock struct { + data []byte +} + +func (c *writeConnMock) Read(b []byte) (n int, err error) { + return 0, errors.New("not implemented") +} + +func (c *writeConnMock) Write(b []byte) (n int, err error) { + c.data = append(c.data, b...) + + n = len(b) + return n, nil +} + +func (c *writeConnMock) Close() error { + return nil +} + +func (c *writeConnMock) LocalAddr() net.Addr { + return fakeAddr{} +} + +func (c *writeConnMock) RemoteAddr() net.Addr { + return fakeAddr{} +} + +func (c *writeConnMock) SetDeadline(t time.Time) error { + return nil +} + +func (c *writeConnMock) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *writeConnMock) SetWriteDeadline(t time.Time) error { + return nil +} + +func TestWriteOverflow(t *testing.T) { + const testMsize = 500 + conn := writeConnMock{} + channel := newChannel(&conn, NewCodec(), testMsize) + writeRequest := MessageTwrite{1, 0, make([]byte, 2*testMsize)} + ctx := context.Background() + channel.WriteFcall(ctx, newFcall(Tag(1), writeRequest)) + reader := bytes.NewReader(conn.data) + var writtenSize uint32 + err := binary.Read(reader, binary.LittleEndian, &writtenSize) + if err != nil { + t.Errorf("error reading result: %v", err) + } + // as there is an overflow, written size should have been truncated such that the message size is equal to channel's msize + if int(writtenSize) != testMsize { + t.Errorf("message should have been truncated to size %v. written message has size %v", testMsize, writtenSize) + } +}