Skip to content

Commit

Permalink
async ttrpc: Fix the bug caused by oversized packets
Browse files Browse the repository at this point in the history
Fix #198

Signed-off-by: Tim Zhang <[email protected]>
  • Loading branch information
Tim-Zhang committed Jul 12, 2023
1 parent 1587248 commit f7f35be
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 31 deletions.
3 changes: 3 additions & 0 deletions src/asynchronous/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ impl Client {
};

let msg = result?;
// Check received server response
msg.check()?;

let res = Response::decode(msg.payload)
.map_err(err_to_others_err!(e, "Unpack response error "))?;

Expand Down
14 changes: 7 additions & 7 deletions src/asynchronous/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,13 +413,13 @@ impl HandlerContext {
async fn handle_msg(&self, msg: GenMessage) {
let stream_id = msg.header.stream_id;

if (stream_id % 2) != 1 {
Self::respond_with_status(
self.tx.clone(),
stream_id,
get_status(Code::INVALID_ARGUMENT, "stream id must be odd"),
)
.await;
if let Err(e) = msg.check() {
Self::respond(self.tx.clone(), stream_id, e.into())
.await
.map_err(|e| {
error!("respond error got error {:?}", e);
})
.ok();
return;
}

Expand Down
91 changes: 67 additions & 24 deletions src/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ pub use compiled::ttrpc::*;
use byteorder::{BigEndian, ByteOrder};
use protobuf::{CodedInputStream, CodedOutputStream};

#[cfg(feature = "async")]
use crate::common::{check_oversize, DEFAULT_PAGE_SIZE};
#[cfg(feature = "async")]
use crate::error::{get_rpc_status, Error, Result as TtResult};

Expand All @@ -27,6 +29,27 @@ pub const FLAG_REMOTE_CLOSED: u8 = 0x1;
pub const FLAG_REMOTE_OPEN: u8 = 0x2;
pub const FLAG_NO_DATA: u8 = 0x4;

// Discard the unwanted message body
#[cfg(feature = "async")]
async fn discard_message_body(
mut reader: impl tokio::io::AsyncReadExt + Unpin,
header: &MessageHeader,
) -> TtResult<()> {
let mut need_discard = header.length as usize;

while need_discard > 0 {
let once_discard = std::cmp::min(DEFAULT_PAGE_SIZE, need_discard);
let mut content = vec![0; once_discard];
reader
.read_exact(&mut content)
.await
.map_err(|e| Error::Socket(e.to_string()))?;
need_discard -= once_discard;
}

Ok(())
}

/// Message header of ttrpc.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct MessageHeader {
Expand Down Expand Up @@ -180,14 +203,12 @@ impl GenMessage {
.await
.map_err(|e| Error::Socket(e.to_string()))?;

if header.length > MESSAGE_LENGTH_MAX as u32 {
return Err(get_rpc_status(
Code::INVALID_ARGUMENT,
format!(
"message length {} exceed maximum message size of {}",
header.length, MESSAGE_LENGTH_MAX
),
));
if check_oversize(header.length as usize, true).is_err() {
discard_message_body(reader, &header).await?;
return Ok(Self {
header,
payload: Vec::new(),
});
}

let mut content = vec![0; header.length as usize];
Expand All @@ -201,6 +222,20 @@ impl GenMessage {
payload: content,
})
}

pub fn check(&self) -> TtResult<()> {
check_oversize(self.header.length as usize, true)?;

// check stream_id
if (self.header.stream_id % 2) != 1 {
return Err(get_rpc_status(
Code::INVALID_ARGUMENT,
"stream id must be odd",
));
}

Ok(())
}
}

/// TTRPC codec, only protobuf is supported.
Expand Down Expand Up @@ -310,14 +345,12 @@ where
.await
.map_err(|e| Error::Socket(e.to_string()))?;

if header.length > MESSAGE_LENGTH_MAX as u32 {
return Err(get_rpc_status(
Code::INVALID_ARGUMENT,
format!(
"message length {} exceed maximum message size of {}",
header.length, MESSAGE_LENGTH_MAX
),
));
if check_oversize(header.length as usize, true).is_err() {
discard_message_body(reader, &header).await?;
return Ok(Self {
header,
payload: C::decode("").map_err(err_to_others_err!(e, "Decode payload failed."))?,
});
}

let mut content = vec![0; header.length as usize];
Expand Down Expand Up @@ -429,11 +462,15 @@ mod tests {
#[cfg(feature = "async")]
#[tokio::test]
async fn async_gen_message() {
// Test packet which exceeds maximum message size
let mut buf = Vec::from(MESSAGE_HEADER);
buf.extend_from_slice(&PROTOBUF_REQUEST);
let res = GenMessage::read_from(&*buf).await;
// exceed maximum message size
assert!(matches!(res, Err(Error::RpcStatus(_))));
let header = MessageHeader::read_from(&*buf).await.expect("read header");
buf.append(&mut vec![0x0; header.length as usize]);

let gen = GenMessage::read_from(&*buf).await.expect("read message");

assert_eq!(gen.header, header);
assert_eq!(gen.payload.len(), 0);

let mut buf = Vec::from(PROTOBUF_MESSAGE_HEADER);
buf.extend_from_slice(&PROTOBUF_REQUEST);
Expand All @@ -459,11 +496,17 @@ mod tests {
#[cfg(feature = "async")]
#[tokio::test]
async fn async_message() {
// Test packet which exceeds maximum message size
let mut buf = Vec::from(MESSAGE_HEADER);
buf.extend_from_slice(&PROTOBUF_REQUEST);
let res = Message::<Request>::read_from(&*buf).await;
// exceed maximum message size
assert!(matches!(res, Err(Error::RpcStatus(_))));
let header = MessageHeader::read_from(&*buf).await.expect("read header");
buf.append(&mut vec![0x0; header.length as usize]);

let gen = Message::<Request>::read_from(&*buf)
.await
.expect("read message");

assert_eq!(gen.header, header);
assert_eq!(protobuf::Message::compute_size(&gen.payload), 0);

let mut buf = Vec::from(PROTOBUF_MESSAGE_HEADER);
buf.extend_from_slice(&PROTOBUF_REQUEST);
Expand Down

0 comments on commit f7f35be

Please sign in to comment.