Skip to content

Commit

Permalink
util: assert compatibility between LengthDelimitedCodec options (#6414
Browse files Browse the repository at this point in the history
)
  • Loading branch information
maminrayej authored Mar 23, 2024
1 parent 4c453e9 commit 8342e4b
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 1 deletion.
39 changes: 38 additions & 1 deletion tokio-util/src/codec/length_delimited.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,10 @@ use std::{cmp, fmt, mem};
/// `Builder` enables constructing configured length delimited codecs. Note
/// that not all configuration settings apply to both encoding and decoding. See
/// the documentation for specific methods for more detail.
///
/// Note that the if the value of [`Builder::max_frame_length`] becomes larger than
/// what can actually fit in [`Builder::length_field_length`], it will be clipped to
/// the maximum value that can fit.
#[derive(Debug, Clone, Copy)]
pub struct Builder {
// Maximum frame length
Expand Down Expand Up @@ -935,8 +939,12 @@ impl Builder {
/// # }
/// ```
pub fn new_codec(&self) -> LengthDelimitedCodec {
let mut builder = *self;

builder.adjust_max_frame_len();

LengthDelimitedCodec {
builder: *self,
builder,
state: DecodeState::Head,
}
}
Expand Down Expand Up @@ -1018,6 +1026,35 @@ impl Builder {
self.num_skip
.unwrap_or(self.length_field_offset + self.length_field_len)
}

fn adjust_max_frame_len(&mut self) {
// This function is basically `std::u64::saturating_add_signed`. Since it
// requires MSRV 1.66, its implementation is copied here.
//
// TODO: use the method from std when MSRV becomes >= 1.66
fn saturating_add_signed(num: u64, rhs: i64) -> u64 {
let (res, overflow) = num.overflowing_add(rhs as u64);
if overflow == (rhs < 0) {
res
} else if overflow {
u64::MAX
} else {
0
}
}

// Calculate the maximum number that can be represented using `length_field_len` bytes.
let max_number = match 1u64.checked_shl((8 * self.length_field_len) as u32) {
Some(shl) => shl - 1,
None => u64::MAX,
};

let max_allowed_len = saturating_add_signed(max_number, self.length_adjustment as i64);

if self.max_frame_len as u64 > max_allowed_len {
self.max_frame_len = usize::try_from(max_allowed_len).unwrap_or(usize::MAX);
}
}
}

impl Default for Builder {
Expand Down
60 changes: 60 additions & 0 deletions tokio-util/tests/length_delimited.rs
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,66 @@ fn encode_overflow() {
codec.encode(Bytes::from("hello"), &mut buf).unwrap();
}

#[test]
fn frame_does_not_fit() {
let codec = LengthDelimitedCodec::builder()
.length_field_length(1)
.max_frame_length(256)
.new_codec();

assert_eq!(codec.max_frame_length(), 255);
}

#[test]
fn neg_adjusted_frame_does_not_fit() {
let codec = LengthDelimitedCodec::builder()
.length_field_length(1)
.length_adjustment(-1)
.new_codec();

assert_eq!(codec.max_frame_length(), 254);
}

#[test]
fn pos_adjusted_frame_does_not_fit() {
let codec = LengthDelimitedCodec::builder()
.length_field_length(1)
.length_adjustment(1)
.new_codec();

assert_eq!(codec.max_frame_length(), 256);
}

#[test]
fn max_allowed_frame_fits() {
let codec = LengthDelimitedCodec::builder()
.length_field_length(std::mem::size_of::<usize>())
.max_frame_length(usize::MAX)
.new_codec();

assert_eq!(codec.max_frame_length(), usize::MAX);
}

#[test]
fn smaller_frame_len_not_adjusted() {
let codec = LengthDelimitedCodec::builder()
.max_frame_length(10)
.length_field_length(std::mem::size_of::<usize>())
.new_codec();

assert_eq!(codec.max_frame_length(), 10);
}

#[test]
fn max_allowed_length_field() {
let codec = LengthDelimitedCodec::builder()
.length_field_length(8)
.max_frame_length(usize::MAX)
.new_codec();

assert_eq!(codec.max_frame_length(), usize::MAX);
}

// ===== Test utils =====

struct Mock {
Expand Down

0 comments on commit 8342e4b

Please sign in to comment.