Skip to content

Commit

Permalink
async ttrpc: Add size checks where packets are created
Browse files Browse the repository at this point in the history
server: check before sending to client
client: check before sending to server

Signed-off-by: Tim Zhang <[email protected]>
  • Loading branch information
Tim-Zhang committed Jul 12, 2023
1 parent ae06d00 commit 5d8be72
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/asynchronous/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl Client {
let timeout_nano = req.timeout_nano;
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);

let msg: GenMessage = Message::new_request(stream_id, req)
let msg: GenMessage = Message::new_request(stream_id, req)?
.try_into()
.map_err(|e: protobuf::Error| Error::Others(e.to_string()))?;

Expand Down Expand Up @@ -120,7 +120,7 @@ impl Client {
) -> Result<StreamInner> {
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);

let mut msg: GenMessage = Message::new_request(stream_id, req)
let mut msg: GenMessage = Message::new_request(stream_id, req)?
.try_into()
.map_err(|e: protobuf::Error| Error::Others(e.to_string()))?;

Expand Down
14 changes: 10 additions & 4 deletions src/asynchronous/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use async_trait::async_trait;
use futures::stream::Stream;
use futures::StreamExt as _;
use nix::unistd;
use protobuf::Message as _;
use tokio::{
self,
io::{AsyncRead, AsyncWrite},
Expand All @@ -35,8 +36,8 @@ use crate::common::{self, Domain};
use crate::context;
use crate::error::{get_status, Error, Result};
use crate::proto::{
Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status, FLAG_NO_DATA,
FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN, MESSAGE_TYPE_DATA, MESSAGE_TYPE_REQUEST,
check_oversize, Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status,
FLAG_NO_DATA, FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN, MESSAGE_TYPE_DATA, MESSAGE_TYPE_REQUEST,
};
use crate::r#async::connection::*;
use crate::r#async::shutdown;
Expand Down Expand Up @@ -426,8 +427,13 @@ impl HandlerContext {
match msg.header.type_ {
MESSAGE_TYPE_REQUEST => match self.handle_request(msg).await {
Ok(opt_msg) => match opt_msg {
Some(msg) => {
Self::respond(self.tx.clone(), stream_id, msg)
Some(mut resp) => {
// Server: check size before sending to client
if let Err(e) = check_oversize(resp.compute_size() as usize, true) {
resp = e.into();
}

Self::respond(self.tx.clone(), stream_id, resp)
.await
.map_err(|e| {
error!("respond got error {:?}", e);
Expand Down
6 changes: 6 additions & 0 deletions src/asynchronous/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,9 @@ impl StreamSender {
header,
payload: buf,
};

msg.check()?;

_send(&self.tx, msg).await?;

Ok(())
Expand Down Expand Up @@ -447,6 +450,9 @@ impl StreamReceiver {
return Err(Error::RemoteClosed);
}
let msg = _recv(&mut self.rx).await?;

msg.check()?;

let payload = match msg.header.type_ {
MESSAGE_TYPE_RESPONSE => {
debug_assert_eq!(self.kind, Kind::Client);
Expand Down
12 changes: 7 additions & 5 deletions src/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,13 @@ where
}

impl<C: Codec> Message<C> {
pub fn new_request(stream_id: u32, message: C) -> Self {
Self {
pub fn new_request(stream_id: u32, message: C) -> TtResult<Self> {
check_oversize(message.size() as usize, false)?;

Ok(Self {
header: MessageHeader::new_request(stream_id, message.size()),
payload: message,
}
})
}
}

Expand Down Expand Up @@ -454,7 +456,7 @@ mod tests {
#[test]
fn gen_message_to_message() {
let req = new_protobuf_request();
let msg = Message::new_request(3, req);
let msg = Message::new_request(3, req).unwrap();
let msg_clone = msg.clone();
let gen: GenMessage = msg.try_into().unwrap();
let dmsg = Message::<Request>::try_from(gen).unwrap();
Expand Down Expand Up @@ -545,7 +547,7 @@ mod tests {
assert_eq!(&msg.payload.metadata[0].value, "test_value1");

let req = new_protobuf_request();
let mut dmsg = Message::new_request(u32::MAX, req);
let mut dmsg = Message::new_request(u32::MAX, req).unwrap();
dmsg.header.set_stream_id(0x123456);
dmsg.header.set_flags(0xe0);
dmsg.header.add_flags(0x0f);
Expand Down

0 comments on commit 5d8be72

Please sign in to comment.