Skip to content

Commit

Permalink
fix: read buf set to the wrong place (#169)
Browse files Browse the repository at this point in the history
  • Loading branch information
PureWhiteWu authored May 25, 2023
1 parent c68ec2b commit 652fcad
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 52 deletions.
46 changes: 32 additions & 14 deletions pilota-build/src/codegen/thrift/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,21 +347,39 @@ impl ThriftBackend {
let read_list_begin = helper.codegen_read_list_begin();
let read_list_end = helper.codegen_read_list_end();
let read_el = self.codegen_decode_ty(helper, ty);
format! {
r#"
unsafe {{
let list_ident = {read_list_begin};
let mut val = Vec::with_capacity(list_ident.size);
for i in 0..list_ident.size {{
*val.get_unchecked_mut(i) = {read_el};
}};
val.set_len(list_ident.size);
{read_list_end};
val
}}
"#
let ty_rust_name = self.codegen_item_ty(ty.kind.clone());
if !helper.is_async {
format! {
r#"
unsafe {{
let list_ident = {read_list_begin};
let mut val: Vec<{ty_rust_name}> = Vec::with_capacity(list_ident.size);
for i in 0..list_ident.size {{
val.as_mut_ptr().offset(i as isize).write({read_el});
}};
val.set_len(list_ident.size);
{read_list_end};
val
}}
"#
}
.into()
} else {
format! {
r#"
{{
let list_ident = {read_list_begin};
let mut val = Vec::with_capacity(list_ident.size);
for _ in 0..list_ident.size {{
val.push({read_el});
}};
{read_list_end};
val
}}
"#
}
.into()
}
.into()
}
ty::Set(ty) => {
let read_set_begin = helper.codegen_read_set_begin();
Expand Down
54 changes: 31 additions & 23 deletions pilota-build/test_data/thrift/wrapper_arc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -618,20 +618,26 @@ pub mod wrapper_arc {
Some(2) if field_ident.field_type == ::pilota::thrift::TType::List => {
name2 = Some(unsafe {
let list_ident = protocol.read_list_begin()?;
let mut val = Vec::with_capacity(list_ident.size);
let mut val: Vec<::std::vec::Vec<::std::sync::Arc<A>>> =
Vec::with_capacity(list_ident.size);
for i in 0..list_ident.size {
*val.get_unchecked_mut(i) = unsafe {
val.as_mut_ptr().offset(i as isize).write(unsafe {
let list_ident = protocol.read_list_begin()?;
let mut val = Vec::with_capacity(list_ident.size);
let mut val: Vec<::std::sync::Arc<A>> =
Vec::with_capacity(list_ident.size);
for i in 0..list_ident.size {
*val.get_unchecked_mut(i) = ::std::sync::Arc::new(
::pilota::thrift::Message::decode(protocol)?,
val.as_mut_ptr().offset(i as isize).write(
::std::sync::Arc::new(
::pilota::thrift::Message::decode(
protocol,
)?,
),
);
}
val.set_len(list_ident.size);
protocol.read_list_end()?;
val
};
});
}
val.set_len(list_ident.size);
protocol.read_list_end()?;
Expand All @@ -646,10 +652,15 @@ pub mod wrapper_arc {
for _ in 0..map_ident.size {
val.insert(protocol.read_i32()?, unsafe {
let list_ident = protocol.read_list_begin()?;
let mut val = Vec::with_capacity(list_ident.size);
let mut val: Vec<::std::sync::Arc<A>> =
Vec::with_capacity(list_ident.size);
for i in 0..list_ident.size {
*val.get_unchecked_mut(i) = ::std::sync::Arc::new(
::pilota::thrift::Message::decode(protocol)?,
val.as_mut_ptr().offset(i as isize).write(
::std::sync::Arc::new(
::pilota::thrift::Message::decode(
protocol,
)?,
),
);
}
val.set_len(list_ident.size);
Expand Down Expand Up @@ -739,27 +750,25 @@ pub mod wrapper_arc {
id = Some(protocol.read_faststr().await?);
}
Some(2) if field_ident.field_type == ::pilota::thrift::TType::List => {
name2 = Some(unsafe {
name2 = Some({
let list_ident = protocol.read_list_begin().await?;
let mut val = Vec::with_capacity(list_ident.size);
for i in 0..list_ident.size {
*val.get_unchecked_mut(i) = unsafe {
for _ in 0..list_ident.size {
val.push({
let list_ident = protocol.read_list_begin().await?;
let mut val = Vec::with_capacity(list_ident.size);
for i in 0..list_ident.size {
*val.get_unchecked_mut(i) = ::std::sync::Arc::new(
for _ in 0..list_ident.size {
val.push(::std::sync::Arc::new(
::pilota::thrift::Message::decode_async(
protocol,
)
.await?,
);
));
}
val.set_len(list_ident.size);
protocol.read_list_end().await?;
val
};
});
}
val.set_len(list_ident.size);
protocol.read_list_end().await?;
val
});
Expand All @@ -770,18 +779,17 @@ pub mod wrapper_arc {
let mut val =
::std::collections::HashMap::with_capacity(map_ident.size);
for _ in 0..map_ident.size {
val.insert(protocol.read_i32().await?, unsafe {
val.insert(protocol.read_i32().await?, {
let list_ident = protocol.read_list_begin().await?;
let mut val = Vec::with_capacity(list_ident.size);
for i in 0..list_ident.size {
*val.get_unchecked_mut(i) = ::std::sync::Arc::new(
for _ in 0..list_ident.size {
val.push(::std::sync::Arc::new(
::pilota::thrift::Message::decode_async(
protocol,
)
.await?,
);
));
}
val.set_len(list_ident.size);
protocol.read_list_end().await?;
val
});
Expand Down
11 changes: 3 additions & 8 deletions pilota/benches/thrift_binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,8 @@ fn binary_bench(c: &mut criterion::Criterion) {
drop(p);
assert_eq!(buf_le.len(), 8 * size);

let b = buf_le.clone();
let mut v2: Vec<i64> = Vec::with_capacity(size);
let src = b.as_ptr();
let dst = v2.as_mut_ptr();
unsafe {
std::ptr::copy_nonoverlapping(src, dst as *mut u8, size * 8);
v2.set_len(size);
}
let b = buf.clone();
let v2 = read_be_unsafe_vec(b, size);
assert_eq!(v, v2);

group.bench_function("big endian decode vec i64", |b| {
Expand Down Expand Up @@ -165,6 +159,7 @@ fn read_be_unsafe_vec(mut b: BytesMut, size: usize) -> Vec<i64> {
for i in 0..size {
*v.get_unchecked_mut(i) = p.read_i64().unwrap();
}
v.set_len(size);
v
}
}
Expand Down
8 changes: 1 addition & 7 deletions pilota/src/thrift/binary_unsafe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1043,13 +1043,7 @@ impl TInputProtocol for TBinaryProtocol<&mut BytesMut> {
self.index = 0;
// split and freeze it
let val = self.trans.split_to(len as usize).freeze();
self.buf = unsafe {
let l = self.trans.len();
slice::from_raw_parts_mut(
self.trans.as_mut_ptr().offset(l as isize),
self.trans.capacity() - l,
)
};
self.buf = unsafe { slice::from_raw_parts_mut(self.trans.as_mut_ptr(), self.trans.len()) };
Ok(val)
}

Expand Down

0 comments on commit 652fcad

Please sign in to comment.