This repository has been archived by the owner on Mar 11, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
channel: truncate twrite messages based on msize
While there are a few problems around handling of msize, the easiest to address and, arguably, the most problematic is that of Twrite. We now truncate Twrite.Data to the correct length if it will overflow the msize limit negotiated on the session. ErrShortWrite is returned by the `Session.Write` method if written data is truncated. In addition, we now reject incoming messages from `ReadFcall` that overflow the msize. Such messages are probably terminal in practice, but can be detected with the `Overflow` function. Tread is also handled accordingly, such that the Count field will be rewritten such that the response doesn't overflow the msize. Signed-off-by: Stephen J Day <[email protected]>
- Loading branch information
Showing
6 changed files
with
354 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
package p9p | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/binary" | ||
"net" | ||
"testing" | ||
"time" | ||
) | ||
|
||
// TestWriteOverflow ensures that a Twrite message will have the data field | ||
// truncated if the msize would be exceeded. | ||
func TestWriteOverflow(t *testing.T) { | ||
const ( | ||
msize = 512 | ||
overflowMSize = msize * 3 / 2 | ||
) | ||
|
||
var ( | ||
ctx = context.Background() | ||
conn = &mockConn{} | ||
ch = NewChannel(conn, msize) | ||
data = bytes.Repeat([]byte{'A'}, overflowMSize) | ||
fcall = newFcall(1, MessageTwrite{ | ||
Data: data, | ||
}) | ||
messageSize uint32 | ||
) | ||
|
||
if err := ch.WriteFcall(ctx, fcall); err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
if err := binary.Read(bytes.NewReader(conn.buf.Bytes()), binary.LittleEndian, &messageSize); err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
if messageSize != msize { | ||
t.Fatalf("should have truncated size header: %d != %d", messageSize, msize) | ||
} | ||
|
||
if conn.buf.Len() != msize { | ||
t.Fatalf("should have truncated message: conn.buf.Len(%v) != msize(%v)", conn.buf.Len(), msize) | ||
} | ||
} | ||
|
||
// TestWriteOverflowError ensures that we return an error in cases when there | ||
// will certainly be an overflow and it cannot be resolved. | ||
func TestWriteOverflowError(t *testing.T) { | ||
const ( | ||
msize = 4 | ||
overflowMSize = msize + 1 | ||
) | ||
|
||
var ( | ||
ctx = context.Background() | ||
conn = &mockConn{} | ||
ch = NewChannel(conn, msize) | ||
data = bytes.Repeat([]byte{'A'}, 4) | ||
fcall = newFcall(1, MessageTwrite{ | ||
Data: data, | ||
}) | ||
messageSize = 4 + ch.(*channel).codec.Size(fcall) | ||
) | ||
|
||
err := ch.WriteFcall(ctx, fcall) | ||
if err == nil { | ||
t.Fatal("error expected when overflowing message") | ||
} | ||
|
||
if Overflow(err) != messageSize-msize { | ||
t.Fatalf("overflow should reflect messageSize and msize, %d != %d", Overflow(err), messageSize-msize) | ||
} | ||
} | ||
|
||
// TestReadOverflow ensures that messages coming over a network connection do | ||
// not overflow the msize. Invalid messages will cause `ReadFcall` to return an | ||
// Overflow error. | ||
func TestReadOverflow(t *testing.T) { | ||
const ( | ||
msize = 256 | ||
overflowMSize = msize + 1 | ||
) | ||
|
||
var ( | ||
ctx = context.Background() | ||
conn = &mockConn{} | ||
ch = NewChannel(conn, msize) | ||
data = bytes.Repeat([]byte{'A'}, overflowMSize) | ||
fcall = newFcall(1, MessageTwrite{ | ||
Data: data, | ||
}) | ||
messageSize = 4 + ch.(*channel).codec.Size(fcall) | ||
) | ||
|
||
// prepare the raw message | ||
p, err := ch.(*channel).codec.Marshal(fcall) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
// "send" the message into the buffer | ||
// this message is crafted to overflow the read buffer. | ||
if err := sendmsg(&conn.buf, p); err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
var incoming Fcall | ||
err = ch.ReadFcall(ctx, &incoming) | ||
if err == nil { | ||
t.Fatal("expected error on fcall") | ||
} | ||
|
||
if Overflow(err) != messageSize-msize { | ||
t.Fatalf("unexpected overflow on error: %v !=%v", Overflow(err), messageSize-msize) | ||
} | ||
} | ||
|
||
// TestTreadRewrite ensures that messages that whose response would overflow | ||
// the msize will have be adjusted before sending. | ||
func TestTreadRewrite(t *testing.T) { | ||
const ( | ||
msize = 256 | ||
overflowMSize = msize + 1 | ||
) | ||
|
||
var ( | ||
ctx = context.Background() | ||
conn = &mockConn{} | ||
ch = NewChannel(conn, msize) | ||
buf = make([]byte, overflowMSize) | ||
// data = bytes.Repeat([]byte{'A'}, overflowMSize) | ||
fcall = newFcall(1, MessageTread{ | ||
Count: overflowMSize, | ||
}) | ||
responseMSize = ch.(*channel).msgmsize(newFcall(1, MessageRread{ | ||
Data: buf, | ||
})) | ||
) | ||
|
||
if err := ch.WriteFcall(ctx, fcall); err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
// just read the message off the buffer | ||
n, err := readmsg(&conn.buf, buf) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
*fcall = Fcall{} | ||
if err := ch.(*channel).codec.Unmarshal(buf[:n], fcall); err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
tread, ok := fcall.Message.(MessageTread) | ||
if !ok { | ||
t.Fatalf("unexpected message: %v", fcall) | ||
} | ||
|
||
if tread.Count != overflowMSize-(uint32(responseMSize)-msize) { | ||
t.Fatalf("count not rewritten: %v != %v", tread.Count, overflowMSize-(uint32(responseMSize)-msize)) | ||
} | ||
} | ||
|
||
type mockConn struct { | ||
net.Conn | ||
buf bytes.Buffer | ||
} | ||
|
||
func (m mockConn) SetWriteDeadline(t time.Time) error { return nil } | ||
func (m mockConn) SetReadDeadline(t time.Time) error { return nil } | ||
|
||
func (m *mockConn) Write(p []byte) (int, error) { | ||
return m.buf.Write(p) | ||
} | ||
|
||
func (m *mockConn) Read(p []byte) (int, error) { | ||
return m.buf.Read(p) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.