Skip to content

Commit

Permalink
async ttrpc: Add size checks where packets are created
Browse files Browse the repository at this point in the history
server: check before sending to client
client: check before sending to server

Signed-off-by: Tim Zhang <[email protected]>
  • Loading branch information
Tim-Zhang committed Jul 12, 2023
1 parent 6d649a2 commit 5342f91
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/asynchronous/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl Client {
let timeout_nano = req.timeout_nano;
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);

let msg: GenMessage = Message::new_request(stream_id, req)
let msg: GenMessage = Message::new_request(stream_id, req)?
.try_into()
.map_err(|e: protobuf::Error| Error::Others(e.to_string()))?;

Expand Down Expand Up @@ -119,7 +119,7 @@ impl Client {
) -> Result<StreamInner> {
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);

let mut msg: GenMessage = Message::new_request(stream_id, req)
let mut msg: GenMessage = Message::new_request(stream_id, req)?
.try_into()
.map_err(|e: protobuf::Error| Error::Others(e.to_string()))?;

Expand Down
12 changes: 9 additions & 3 deletions src/asynchronous/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use async_trait::async_trait;
use futures::stream::Stream;
use futures::StreamExt as _;
use nix::unistd;
use protobuf::Message as _;
use tokio::{
self,
io::{AsyncRead, AsyncWrite},
Expand All @@ -31,7 +32,7 @@ use tokio::{
use tokio_vsock::VsockListener;

use crate::asynchronous::unix_incoming::UnixIncoming;
use crate::common::{self, Domain};
use crate::common::{self, check_oversize, Domain};
use crate::context;
use crate::error::{get_status, Error, Result};
use crate::proto::{
Expand Down Expand Up @@ -426,8 +427,13 @@ impl HandlerContext {
match msg.header.type_ {
MESSAGE_TYPE_REQUEST => match self.handle_request(msg).await {
Ok(opt_msg) => match opt_msg {
Some(msg) => {
Self::respond(self.tx.clone(), stream_id, msg)
Some(mut resp) => {
// Server: check size before sending to client
if let Err(e) = check_oversize(resp.compute_size() as usize, true) {
resp = e.into();
}

Self::respond(self.tx.clone(), stream_id, resp)
.await
.map_err(|e| {
error!("respond got error {:?}", e);
Expand Down
6 changes: 6 additions & 0 deletions src/asynchronous/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,9 @@ impl StreamSender {
header,
payload: buf,
};

msg.check()?;

_send(&self.tx, msg).await?;

Ok(())
Expand Down Expand Up @@ -447,6 +450,9 @@ impl StreamReceiver {
return Err(Error::RemoteClosed);
}
let msg = _recv(&mut self.rx).await?;

msg.check()?;

let payload = match msg.header.type_ {
MESSAGE_TYPE_RESPONSE => {
debug_assert_eq!(self.kind, Kind::Client);
Expand Down
11 changes: 7 additions & 4 deletions src/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ use protobuf::{CodedInputStream, CodedOutputStream};

#[cfg(feature = "async")]
use crate::common::{check_oversize, DEFAULT_PAGE_SIZE};
use crate::error::Result as TtResult;
#[cfg(feature = "async")]
use crate::error::{get_rpc_status, Error, Result as TtResult};
use crate::error::{get_rpc_status, Error};

pub const MESSAGE_HEADER_LENGTH: usize = 10;
pub const MESSAGE_LENGTH_MAX: usize = 4 << 20;
Expand Down Expand Up @@ -305,11 +306,13 @@ where
}

impl<C: Codec> Message<C> {
pub fn new_request(stream_id: u32, message: C) -> Self {
Self {
pub fn new_request(stream_id: u32, message: C) -> TtResult<Self> {
check_oversize(message.size() as usize, false)?;

Ok(Self {
header: MessageHeader::new_request(stream_id, message.size()),
payload: message,
}
})
}
}

Expand Down

0 comments on commit 5342f91

Please sign in to comment.