Skip to content

Commit

Permalink
async ttrpc: Add size checks where packets are created
Browse files Browse the repository at this point in the history
server: check before sending to client
client: check before sending to server

Signed-off-by: Tim Zhang <[email protected]>
  • Loading branch information
Tim-Zhang committed Jul 12, 2023
1 parent 6d649a2 commit bf0b176
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/asynchronous/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl Client {
let timeout_nano = req.timeout_nano;
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);

let msg: GenMessage = Message::new_request(stream_id, req)
let msg: GenMessage = Message::new_request(stream_id, req)?
.try_into()
.map_err(|e: protobuf::Error| Error::Others(e.to_string()))?;

Expand Down Expand Up @@ -119,7 +119,7 @@ impl Client {
) -> Result<StreamInner> {
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);

let mut msg: GenMessage = Message::new_request(stream_id, req)
let mut msg: GenMessage = Message::new_request(stream_id, req)?
.try_into()
.map_err(|e: protobuf::Error| Error::Others(e.to_string()))?;

Expand Down
12 changes: 9 additions & 3 deletions src/asynchronous/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use async_trait::async_trait;
use futures::stream::Stream;
use futures::StreamExt as _;
use nix::unistd;
use protobuf::Message as _;
use tokio::{
self,
io::{AsyncRead, AsyncWrite},
Expand All @@ -31,7 +32,7 @@ use tokio::{
use tokio_vsock::VsockListener;

use crate::asynchronous::unix_incoming::UnixIncoming;
use crate::common::{self, Domain};
use crate::common::{self, check_oversize, Domain};
use crate::context;
use crate::error::{get_status, Error, Result};
use crate::proto::{
Expand Down Expand Up @@ -426,8 +427,13 @@ impl HandlerContext {
match msg.header.type_ {
MESSAGE_TYPE_REQUEST => match self.handle_request(msg).await {
Ok(opt_msg) => match opt_msg {
Some(msg) => {
Self::respond(self.tx.clone(), stream_id, msg)
Some(mut resp) => {
// Server: check size before sending to client
if let Err(e) = check_oversize(resp.compute_size() as usize, true) {
resp = e.into();
}

Self::respond(self.tx.clone(), stream_id, resp)
.await
.map_err(|e| {
error!("respond got error {:?}", e);
Expand Down
6 changes: 6 additions & 0 deletions src/asynchronous/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,9 @@ impl StreamSender {
header,
payload: buf,
};

msg.check()?;

_send(&self.tx, msg).await?;

Ok(())
Expand Down Expand Up @@ -447,6 +450,9 @@ impl StreamReceiver {
return Err(Error::RemoteClosed);
}
let msg = _recv(&mut self.rx).await?;

msg.check()?;

let payload = match msg.header.type_ {
MESSAGE_TYPE_RESPONSE => {
debug_assert_eq!(self.kind, Kind::Client);
Expand Down
20 changes: 11 additions & 9 deletions src/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ pub use compiled::ttrpc::*;
use byteorder::{BigEndian, ByteOrder};
use protobuf::{CodedInputStream, CodedOutputStream};

use crate::common::check_oversize;
use crate::error::Result as TtResult;
#[cfg(feature = "async")]
use crate::common::{check_oversize, DEFAULT_PAGE_SIZE};
#[cfg(feature = "async")]
use crate::error::{get_rpc_status, Error, Result as TtResult};
use crate::error::{get_rpc_status, Error};

pub const MESSAGE_HEADER_LENGTH: usize = 10;
pub const MESSAGE_LENGTH_MAX: usize = 4 << 20;
Expand All @@ -38,7 +38,7 @@ async fn discard_message_body(
let mut need_discard = header.length as usize;

while need_discard > 0 {
let once_discard = std::cmp::min(DEFAULT_PAGE_SIZE, need_discard);
let once_discard = std::cmp::min(crate::common::DEFAULT_PAGE_SIZE, need_discard);
let mut content = vec![0; once_discard];
reader
.read_exact(&mut content)
Expand Down Expand Up @@ -305,11 +305,13 @@ where
}

impl<C: Codec> Message<C> {
pub fn new_request(stream_id: u32, message: C) -> Self {
Self {
pub fn new_request(stream_id: u32, message: C) -> TtResult<Self> {
check_oversize(message.size() as usize, false)?;

Ok(Self {
header: MessageHeader::new_request(stream_id, message.size()),
payload: message,
}
})
}
}

Expand Down Expand Up @@ -438,7 +440,7 @@ mod tests {
#[test]
fn gen_message_to_message() {
let req = new_protobuf_request();
let msg = Message::new_request(3, req);
let msg = Message::new_request(3, req).unwrap();
let msg_clone = msg.clone();
let gen: GenMessage = msg.try_into().unwrap();
let dmsg = Message::<Request>::try_from(gen).unwrap();
Expand Down Expand Up @@ -529,7 +531,7 @@ mod tests {
assert_eq!(&msg.payload.metadata[0].value, "test_value1");

let req = new_protobuf_request();
let mut dmsg = Message::new_request(u32::MAX, req);
let mut dmsg = Message::new_request(u32::MAX, req).unwrap();
dmsg.header.set_stream_id(0x123456);
dmsg.header.set_flags(0xe0);
dmsg.header.add_flags(0x0f);
Expand Down

0 comments on commit bf0b176

Please sign in to comment.