Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: replace risky dma_buffer_as_vec implementations #16829

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/common/base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ databend-common-exception = { workspace = true }
async-backtrace = { workspace = true }
async-trait = { workspace = true }
borsh = { workspace = true }
bytes = { workspace = true }
bytesize = { workspace = true }
chrono = { workspace = true }
ctrlc = { workspace = true }
Expand Down
69 changes: 56 additions & 13 deletions src/common/base/src/base/dma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use std::path::Path;
use std::ptr;
use std::ptr::NonNull;

use bytes::Bytes;
use rustix::fs::OFlags;
use tokio::fs::File;
use tokio::io::AsyncSeekExt;
Expand Down Expand Up @@ -116,10 +117,6 @@ impl DmaAllocator {
Layout::from_size_align(layout.size(), self.0.as_usize()).unwrap()
}
}

/// Rounds `cap` up using the `align_up` method of the alignment stored in this allocator.
fn real_cap(&self, cap: usize) -> usize {
    let alignment = &self.0;
    alignment.align_up(cap)
}
}

unsafe impl Allocator for DmaAllocator {
Expand All @@ -131,6 +128,10 @@ unsafe impl Allocator for DmaAllocator {
Global {}.allocate_zeroed(self.real_layout(layout))
}

/// Frees `ptr` through the global allocator.
///
/// The caller-supplied `layout` is first passed through `real_layout`, and that
/// adjusted layout is the one handed to `Global`.
unsafe fn deallocate(&self, ptr: std::ptr::NonNull<u8>, layout: Layout) {
    let adjusted = self.real_layout(layout);
    Global {}.deallocate(ptr, adjusted)
}

unsafe fn grow(
&self,
ptr: NonNull<u8>,
Expand All @@ -157,20 +158,38 @@ unsafe impl Allocator for DmaAllocator {
)
}

unsafe fn deallocate(&self, ptr: std::ptr::NonNull<u8>, layout: Layout) {
Global {}.deallocate(ptr, self.real_layout(layout))
/// Shrinks the block at `ptr` by delegating to the global allocator.
///
/// Both the old and the new layout are passed through `real_layout` before
/// being forwarded to `Global::shrink`.
unsafe fn shrink(
    &self,
    ptr: NonNull<u8>,
    old_layout: Layout,
    new_layout: Layout,
) -> Result<NonNull<[u8]>, AllocError> {
    let current = self.real_layout(old_layout);
    let target = self.real_layout(new_layout);
    Global {}.shrink(ptr, current, target)
}
}

/// A byte vector whose storage is obtained through [`DmaAllocator`].
type DmaBuffer = Vec<u8, DmaAllocator>;

pub fn dma_buffer_as_vec(mut buf: DmaBuffer) -> Vec<u8> {
let ptr = buf.as_mut_ptr();
let len = buf.len();
let cap = buf.allocator().real_cap(buf.capacity());
std::mem::forget(buf);

unsafe { Vec::from_raw_parts(ptr, len, cap) }
/// Converts a [`DmaBuffer`] into [`Bytes`].
///
/// A non-empty buffer is re-fitted with `Global::shrink` from its DMA layout
/// (`cap` bytes at the allocator's alignment) to a plain byte layout of exactly
/// `len` bytes. The resulting block is then owned by an ordinary `Vec<u8>`, so
/// it is later freed with the layout the global allocator expects.
///
/// An empty buffer yields `Bytes::new()`; the buffer itself is dropped normally.
///
/// Allocation failure during the re-fit is reported through
/// `std::alloc::handle_alloc_error`.
pub fn dma_buffer_to_bytes(buf: DmaBuffer) -> Bytes {
    if buf.is_empty() {
        return Bytes::new();
    }

    let (ptr, len, cap, alloc) = buf.into_raw_parts_with_alloc();

    // Layout the block currently has: `cap` bytes at the DMA alignment.
    let old_layout = Layout::from_size_align(cap, alloc.0.as_usize())
        .expect("layout of an existing DMA allocation is valid");
    // Layout a plain `Vec<u8>` of `len` bytes expects.
    let new_layout = Layout::from_size_align(len, std::mem::align_of::<u8>())
        .expect("byte layout no larger than an existing allocation is valid");
    let ptr = NonNull::new(ptr).expect("pointer of a non-empty Vec is never null");

    // SAFETY: `ptr` denotes a live block described by `old_layout` (see above), and
    // `new_layout.size() == len <= cap == old_layout.size()`, as `shrink` requires.
    let fitted = unsafe { Global {}.shrink(ptr, old_layout, new_layout) }
        .unwrap_or_else(|_| std::alloc::handle_alloc_error(new_layout));

    let fitted_cap = fitted.len();
    // Take the raw pointer directly from the slice pointer. Going through
    // `as_mut()` would create a `&mut u8` covering a single byte and derive the
    // Vec's pointer from that reference.
    let data_ptr = fitted.cast::<u8>().as_ptr();

    // SAFETY: `data_ptr` was just returned by `Global` for a byte-aligned block of
    // `fitted_cap` bytes, and the first `len` bytes were carried over by `shrink`.
    let data = unsafe { Vec::from_raw_parts(data_ptr, len, fitted_cap) };
    Bytes::from(data)
}

/// A `DmaFile` is similar to a `File`, but it is opened with the `O_DIRECT` file in order to
Expand Down Expand Up @@ -697,4 +716,28 @@ mod tests {

let _ = std::fs::remove_file(filename);
}

/// Checks that a DMA buffer keeps its alignment and content across shrink and
/// grow, and that `dma_buffer_to_bytes` preserves the content.
#[test]
fn test_dma_buffer_to_bytes() {
    const ALIGN: usize = 4096;

    let want = (0..10_u8).collect::<Vec<_>>();
    let alloc = DmaAllocator::new(Alignment::new(ALIGN).unwrap());
    let mut buf = DmaBuffer::with_capacity_in(3000, alloc);
    buf.extend_from_slice(&want);
    assert_eq!(buf.as_ptr() as usize % ALIGN, 0);

    // Exercise the allocator's shrink path.
    buf.shrink_to_fit();
    assert_eq!(buf.as_ptr() as usize % ALIGN, 0);
    assert_eq!(buf.as_slice(), want.as_slice());

    // Exercise the allocator's grow path.
    buf.reserve(3000 - buf.capacity());
    assert_eq!(buf.as_ptr() as usize % ALIGN, 0);
    assert_eq!(buf.as_slice(), want.as_slice());

    let got = dma_buffer_to_bytes(buf);
    assert_eq!(&want, &got);

    // Copying out of the `Bytes` must still yield the original content.
    assert_eq!(got.to_vec(), want);

    // An empty buffer converts to an empty `Bytes`.
    let empty = DmaBuffer::new_in(DmaAllocator::new(Alignment::new(ALIGN).unwrap()));
    assert!(dma_buffer_to_bytes(empty).is_empty());
}
}
2 changes: 1 addition & 1 deletion src/common/base/src/base/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ mod take_mut;
mod uniq_id;
mod watch_notify;

pub use dma::dma_buffer_as_vec;
pub use dma::dma_buffer_to_bytes;
pub use dma::dma_read_file;
pub use dma::dma_read_file_range;
pub use dma::dma_write_file_vectored;
Expand Down
1 change: 1 addition & 0 deletions src/common/base/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#![feature(slice_swap_unchecked)]
#![feature(variant_count)]
#![feature(ptr_alignment_type)]
#![feature(vec_into_raw_parts)]

pub mod base;
pub mod containers;
Expand Down
11 changes: 5 additions & 6 deletions src/query/service/src/spillers/spiller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ use std::ops::Range;
use std::sync::Arc;
use std::time::Instant;

use bytes::Bytes;
use databend_common_base::base::dma_buffer_as_vec;
use databend_common_base::base::dma_buffer_to_bytes;
use databend_common_base::base::dma_read_file_range;
use databend_common_base::base::Alignment;
use databend_common_base::base::DmaWriteBuf;
Expand Down Expand Up @@ -277,7 +276,7 @@ impl Spiller {
None => {
let file_size = path.size();
let (buf, range) = dma_read_file_range(path, 0..file_size as u64).await?;
Buffer::from(dma_buffer_as_vec(buf)).slice(range)
Buffer::from(dma_buffer_to_bytes(buf)).slice(range)
}
}
}
Expand Down Expand Up @@ -330,7 +329,7 @@ impl Spiller {
);

let (buf, range) = dma_read_file_range(path, 0..file_size as u64).await?;
Buffer::from(dma_buffer_as_vec(buf)).slice(range)
Buffer::from(dma_buffer_to_bytes(buf)).slice(range)
}
(Location::Local(path), Some(ref local)) => {
local
Expand Down Expand Up @@ -371,7 +370,7 @@ impl Spiller {
}
None => {
let (buf, range) = dma_read_file_range(path, data_range).await?;
Buffer::from(dma_buffer_as_vec(buf)).slice(range)
Buffer::from(dma_buffer_to_bytes(buf)).slice(range)
}
},
Location::Remote(loc) => self.operator.read_with(loc).range(data_range).await?,
Expand Down Expand Up @@ -410,7 +409,7 @@ impl Spiller {
let buf = buf
.into_data()
.into_iter()
.map(|x| Bytes::from(dma_buffer_as_vec(x)))
.map(dma_buffer_to_bytes)
.collect::<Buffer>();
let written = buf.len();
writer.write(buf).await?;
Expand Down
Loading