Skip to content

Commit

Permalink
Merge pull request #203 from Tim-Zhang/fix-over-size-limit-master
Browse files Browse the repository at this point in the history
[master] Fix the bug caused by oversized packets
  • Loading branch information
teawater authored Jul 19, 2023
2 parents e5e1dbe + 3ef0e4e commit 555c412
Show file tree
Hide file tree
Showing 11 changed files with 340 additions and 193 deletions.
129 changes: 75 additions & 54 deletions src/asynchronous/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use tokio::{self, sync::mpsc, task};
use crate::common::client_connect;
use crate::error::{Error, Result};
use crate::proto::{
Code, Codec, GenMessage, Message, Request, Response, FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN,
MESSAGE_TYPE_DATA, MESSAGE_TYPE_RESPONSE,
Code, Codec, GenMessage, Message, MessageHeader, Request, Response, FLAG_REMOTE_CLOSED,
FLAG_REMOTE_OPEN, MESSAGE_TYPE_DATA, MESSAGE_TYPE_RESPONSE,
};
use crate::r#async::connection::*;
use crate::r#async::shutdown;
Expand Down Expand Up @@ -68,7 +68,7 @@ impl Client {
let timeout_nano = req.timeout_nano;
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);

let msg: GenMessage = Message::new_request(stream_id, req)
let msg: GenMessage = Message::new_request(stream_id, req)?
.try_into()
.map_err(|e: protobuf::Error| Error::Others(e.to_string()))?;

Expand Down Expand Up @@ -97,6 +97,7 @@ impl Client {
};

let msg = result?;

let res = Response::decode(msg.payload)
.map_err(err_to_others_err!(e, "Unpack response error "))?;

Expand All @@ -117,7 +118,7 @@ impl Client {
) -> Result<StreamInner> {
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);

let mut msg: GenMessage = Message::new_request(stream_id, req)
let mut msg: GenMessage = Message::new_request(stream_id, req)?
.try_into()
.map_err(|e: protobuf::Error| Error::Others(e.to_string()))?;

Expand Down Expand Up @@ -223,6 +224,58 @@ impl WriterDelegate for ClientWriter {
}
}

async fn get_resp_tx(
req_map: Arc<Mutex<HashMap<u32, ResultSender>>>,
header: &MessageHeader,
) -> Option<ResultSender> {
let resp_tx = match header.type_ {
MESSAGE_TYPE_RESPONSE => match req_map.lock().unwrap().remove(&header.stream_id) {
Some(tx) => tx,
None => {
debug!("Receiver got unknown response packet {:?}", header);
return None;
}
},
MESSAGE_TYPE_DATA => {
if (header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED {
match req_map.lock().unwrap().remove(&header.stream_id) {
Some(tx) => tx,
None => {
debug!("Receiver got unknown data packet {:?}", header);
return None;
}
}
} else {
match req_map.lock().unwrap().get(&header.stream_id) {
Some(tx) => tx.clone(),
None => {
debug!("Receiver got unknown data packet {:?}", header);
return None;
}
}
}
}
_ => {
let resp_tx = match req_map.lock().unwrap().remove(&header.stream_id) {
Some(tx) => tx,
None => {
debug!("Receiver got unknown packet {:?}", header);
return None;
}
};
resp_tx
.send(Err(Error::Others(format!(
"Receiver got malformed packet {header:?}"
))))
.await
.unwrap_or_else(|_e| error!("The request has returned"));
return None;
}
};

Some(resp_tx)
}

struct ClientReader {
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
shutdown_waiter: shutdown::Waiter,
Expand Down Expand Up @@ -252,59 +305,27 @@ impl ReaderDelegate for ClientReader {

async fn exit(&self) {}

async fn handle_err(&self, header: MessageHeader, e: Error) {
let req_map = self.streams.clone();
tokio::spawn(async move {
if let Some(resp_tx) = get_resp_tx(req_map, &header).await {
resp_tx
.send(Err(e))
.await
.unwrap_or_else(|_e| error!("The request has returned"));
}
});
}

async fn handle_msg(&self, msg: GenMessage) {
let req_map = self.streams.clone();
tokio::spawn(async move {
let resp_tx = match msg.header.type_ {
MESSAGE_TYPE_RESPONSE => {
match req_map.lock().unwrap().remove(&msg.header.stream_id) {
Some(tx) => tx,
None => {
debug!("Receiver got unknown response packet {:?}", msg);
return;
}
}
}
MESSAGE_TYPE_DATA => {
if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED {
match req_map.lock().unwrap().remove(&msg.header.stream_id) {
Some(tx) => tx.clone(),
None => {
debug!("Receiver got unknown data packet {:?}", msg);
return;
}
}
} else {
match req_map.lock().unwrap().get(&msg.header.stream_id) {
Some(tx) => tx.clone(),
None => {
debug!("Receiver got unknown data packet {:?}", msg);
return;
}
}
}
}
_ => {
let resp_tx = match req_map.lock().unwrap().remove(&msg.header.stream_id) {
Some(tx) => tx,
None => {
debug!("Receiver got unknown packet {:?}", msg);
return;
}
};
resp_tx
.send(Err(Error::Others(format!(
"Receiver got malformed packet {msg:?}"
))))
.await
.unwrap_or_else(|_e| error!("The request has returned"));
return;
}
};
resp_tx
.send(Ok(msg))
.await
.unwrap_or_else(|_e| error!("The request has returned"));
if let Some(resp_tx) = get_resp_tx(req_map, &msg.header).await {
resp_tx
.send(Ok(msg))
.await
.unwrap_or_else(|_e| error!("The request has returned"));
}
});
}
}
10 changes: 8 additions & 2 deletions src/asynchronous/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use tokio::{
};

use crate::error::Error;
use crate::proto::GenMessage;
use crate::proto::{GenMessage, GenMessageError, MessageHeader};

pub trait Builder {
type Reader;
Expand All @@ -36,6 +36,7 @@ pub trait ReaderDelegate {
async fn disconnect(&self, e: Error, task: &mut task::JoinHandle<()>);
async fn exit(&self);
async fn handle_msg(&self, msg: GenMessage);
async fn handle_err(&self, header: MessageHeader, e: Error);
}

pub struct Connection<S, B: Builder> {
Expand Down Expand Up @@ -89,7 +90,12 @@ where
trace!("Got Message {:?}", msg);
reader_delegate.handle_msg(msg).await;
}
Err(e) => {
Err(GenMessageError::ReturnError(header, e)) => {
trace!("Read msg err (can be return): {:?}", e);
reader_delegate.handle_err(header, e).await;
}

Err(GenMessageError::InternalError(e)) => {
trace!("Read msg err: {:?}", e);
reader_delegate.disconnect(e, &mut writer_task).await;
break;
Expand Down
26 changes: 22 additions & 4 deletions src/asynchronous/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use async_trait::async_trait;
use futures::stream::Stream;
use futures::StreamExt as _;
use nix::unistd;
use protobuf::Message as _;
use tokio::{
self,
io::{AsyncRead, AsyncWrite},
Expand All @@ -35,8 +36,8 @@ use crate::common::{self, Domain};
use crate::context;
use crate::error::{get_status, Error, Result};
use crate::proto::{
Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status, FLAG_NO_DATA,
FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN, MESSAGE_TYPE_DATA, MESSAGE_TYPE_REQUEST,
check_oversize, Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status,
FLAG_NO_DATA, FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN, MESSAGE_TYPE_DATA, MESSAGE_TYPE_REQUEST,
};
use crate::r#async::connection::*;
use crate::r#async::shutdown;
Expand Down Expand Up @@ -386,6 +387,10 @@ impl ReaderDelegate for ServerReader {
}
});
}

async fn handle_err(&self, header: MessageHeader, e: Error) {
self.context().handle_err(header, e).await
}
}

impl ServerReader {
Expand All @@ -410,6 +415,14 @@ struct HandlerContext {
}

impl HandlerContext {
async fn handle_err(&self, header: MessageHeader, e: Error) {
Self::respond(self.tx.clone(), header.stream_id, e.into())
.await
.map_err(|e| {
error!("respond error got error {:?}", e);
})
.ok();
}
async fn handle_msg(&self, msg: GenMessage) {
let stream_id = msg.header.stream_id;

Expand All @@ -426,8 +439,13 @@ impl HandlerContext {
match msg.header.type_ {
MESSAGE_TYPE_REQUEST => match self.handle_request(msg).await {
Ok(opt_msg) => match opt_msg {
Some(msg) => {
Self::respond(self.tx.clone(), stream_id, msg)
Some(mut resp) => {
// Server: check size before sending to client
if let Err(e) = check_oversize(resp.compute_size() as usize, true) {
resp = e.into();
}

Self::respond(self.tx.clone(), stream_id, resp)
.await
.map_err(|e| {
error!("respond got error {:?}", e);
Expand Down
4 changes: 4 additions & 0 deletions src/asynchronous/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,9 @@ impl StreamSender {
header,
payload: buf,
};

msg.check()?;

_send(&self.tx, msg).await?;

Ok(())
Expand Down Expand Up @@ -447,6 +450,7 @@ impl StreamReceiver {
return Err(Error::RemoteClosed);
}
let msg = _recv(&mut self.rx).await?;

let payload = match msg.header.type_ {
MESSAGE_TYPE_RESPONSE => {
debug_assert_eq!(self.kind, Kind::Client);
Expand Down
3 changes: 2 additions & 1 deletion src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

//! Common functions.

use crate::error::{Error, Result};
#[cfg(any(
feature = "async",
not(any(target_os = "linux", target_os = "android"))
Expand All @@ -16,6 +15,8 @@ use nix::fcntl::{fcntl, FcntlArg, OFlag};
use nix::sys::socket::*;
use std::os::unix::io::RawFd;

use crate::error::{Error, Result};

#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum Domain {
Unix,
Expand Down
16 changes: 15 additions & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

//! Error and Result of ttrpc and relevant functions, macros.

use crate::proto::{Code, Status};
use crate::proto::{Code, Response, Status};
use std::result;
use thiserror::Error;

Expand Down Expand Up @@ -48,6 +48,20 @@ pub enum Error {
Others(String),
}

impl From<Error> for Response {
fn from(e: Error) -> Self {
let status = if let Error::RpcStatus(stat) = e {
stat
} else {
get_status(Code::UNKNOWN, e)
};

let mut res = Response::new();
res.set_status(status);
res
}
}

/// A specialized Result type for ttrpc.
pub type Result<T> = result::Result<T, Error>;

Expand Down
Loading

0 comments on commit 555c412

Please sign in to comment.