Optimize tun read/write performance for Windows platform #78

Closed
wants to merge 12 commits
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,3 +3,5 @@
target/
**/*.rs.bk
Cargo.lock
.cargo/
wintun.dll
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "tun"
version = "0.6.1"
version = "0.6.2"
edition = "2021"

authors = ["meh. <[email protected]>"]
4 changes: 2 additions & 2 deletions README.md
@@ -8,7 +8,7 @@ First, add the following to your `Cargo.toml`:

```toml
[dependencies]
tun = "0.6.1"
tun = "0.6.2"
```

Next, add this to your crate root:
@@ -21,7 +21,7 @@ If you want to use the TUN interface with mio/tokio, you need to enable the `async` feature:

```toml
[dependencies]
tun = { version = "0.6.1", features = ["async"] }
tun = { version = "0.6.2", features = ["async"] }
```
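
For instance, a minimal sketch of asynchronous reads, assuming a Tokio runtime and the crate's `create_as_async` constructor (the addresses are illustrative):

```rust
use tokio::io::AsyncReadExt;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let mut config = tun::Configuration::default();
    config
        .address((10, 1, 0, 2))
        .netmask((255, 255, 255, 0))
        .up();

    // The returned AsyncDevice implements tokio's AsyncRead/AsyncWrite.
    let mut dev = tun::create_as_async(&config).expect("failed to create TUN device");
    let mut buf = vec![0u8; 1504];
    loop {
        let n = dev.read(&mut buf).await?;
        println!("read {n} bytes");
    }
}
```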

Example
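
A minimal synchronous sketch, assuming the crate's `Configuration` builder (the addresses are illustrative):

```rust
use std::io::Read;

fn main() {
    let mut config = tun::Configuration::default();
    config
        .address((10, 1, 0, 2))
        .netmask((255, 255, 255, 0))
        .up();

    let mut dev = tun::create(&config).expect("failed to create TUN device");
    let mut buf = [0u8; 4096];
    loop {
        let n = dev.read(&mut buf).expect("read error");
        println!("packet: {:?}", &buf[..n]);
    }
}
```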
9 changes: 3 additions & 6 deletions src/async/win/device.rs
@@ -58,8 +58,8 @@ impl AsyncRead for AsyncDevice {
) -> Poll<io::Result<()>> {
let rbuf = buf.initialize_unfilled();
match Pin::new(&mut self.inner).poll_read(cx, rbuf) {
Poll::Ready(Ok(n)) => {
buf.advance(n);
Poll::Ready(Ok(size)) => {
buf.advance(size);
Poll::Ready(Ok(()))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
@@ -126,10 +126,7 @@ impl AsyncRead for AsyncQueue {
) -> Poll<io::Result<()>> {
let rbuf = buf.initialize_unfilled();
match Pin::new(&mut self.inner).poll_read(cx, rbuf) {
Poll::Ready(Ok(n)) => {
buf.advance(n);
Poll::Ready(Ok(()))
}
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
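
Tokio's `AsyncRead` reports progress through `ReadBuf`: after copying `n` bytes into the slice returned by `initialize_unfilled()`, an implementation must call `buf.advance(n)`, otherwise the caller sees a zero-length read. A self-contained sketch of the bridging pattern used above, with a stubbed inner reader (`Inner` and `Wrapper` are illustrative names, not part of this crate):

```rust
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};

// Hypothetical inner reader: fills a plain byte slice and reports the count.
struct Inner;

impl Inner {
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // Stub: pretend one zeroed byte arrived.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        buf[0] = 0;
        Poll::Ready(Ok(1))
    }
}

struct Wrapper {
    inner: Inner,
}

impl AsyncRead for Wrapper {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let rbuf = buf.initialize_unfilled();
        match Pin::new(&mut self.inner).poll_read(cx, rbuf) {
            Poll::Ready(Ok(n)) => {
                // Mark the bytes the inner reader produced as filled.
                buf.advance(n);
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Pending => Poll::Pending,
        }
    }
}
```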
222 changes: 123 additions & 99 deletions src/platform/windows/device.rs
@@ -14,11 +14,11 @@

use std::io::{self, Read, Write};
use std::net::{IpAddr, Ipv4Addr};
#[cfg(feature = "async")]
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::sync::Arc;
#[cfg(feature = "async")]
use std::thread;
use std::vec::Vec;

use wintun::Session;

@@ -32,30 +32,52 @@ pub struct Device {
mtu: usize,
}

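// With the `async` feature, a dedicated OS thread drains the wintun session via
// `receive_blocking` and forwards each packet through an unbounded channel, so
// `poll_read` can await incoming packets without blocking the async runtime.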
#[cfg(feature = "async")]
fn create_queue(session: Session) -> Queue {
let session = Arc::new(session);
let (receiver_tx, receiver_rx) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();
let session_reader = session.clone();
let task = thread::spawn(move || {
while let Ok(packet) = session_reader.receive_blocking() {
let bytes = packet.bytes().to_vec();
            if receiver_tx.send(bytes).is_err() {
                break; // the receiver was dropped; stop the reader thread
            }
}
});
Queue {
session,
receiver: receiver_rx,
_task: task,
}
}

#[cfg(not(feature = "async"))]
fn create_queue(session: Session) -> Queue {
Queue {
session: Arc::new(session),
}
}
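
The `create_queue` split above captures a general pattern: a blocking producer runs on its own OS thread and feeds async consumers through a channel. A minimal self-contained sketch of the idea (`spawn_blocking_reader` is an illustrative helper, not part of this crate):

```rust
use std::thread;
use tokio::sync::mpsc;

// Spawn a thread that repeatedly performs a blocking read and forwards the
// result; the async side awaits packets without ever blocking the runtime.
fn spawn_blocking_reader<F>(mut read_blocking: F) -> mpsc::UnboundedReceiver<Vec<u8>>
where
    F: FnMut() -> Option<Vec<u8>> + Send + 'static,
{
    let (tx, rx) = mpsc::unbounded_channel();
    thread::spawn(move || {
        while let Some(packet) = read_blocking() {
            if tx.send(packet).is_err() {
                break; // receiver dropped; stop reading
            }
        }
    });
    rx
}

#[tokio::main]
async fn main() {
    let mut packets = vec![vec![1u8, 2, 3], vec![4, 5]].into_iter();
    let mut rx = spawn_blocking_reader(move || packets.next());
    while let Some(packet) = rx.recv().await {
        println!("got {} bytes", packet.len());
    }
}
```

An unbounded channel means the reader thread never blocks on a slow consumer, at the cost of unbounded buffering if packets arrive faster than they are drained.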

impl Device {
/// Create a new `Device` for the given `Configuration`.
pub fn new(config: &Configuration) -> Result<Self> {
let wintun = unsafe { wintun::load()? };
let tun_name = config.name.as_deref().unwrap_or("wintun");
let guid = Some(9099482345783245345345_u128);
let adapter = match wintun::Adapter::open(&wintun, tun_name) {
Ok(a) => a,
Err(_) => wintun::Adapter::create(&wintun, tun_name, tun_name, guid)?,
Err(_) => wintun::Adapter::create(&wintun, tun_name, tun_name, None)?,
};

let address = config.address.unwrap_or(Ipv4Addr::new(10, 1, 0, 2));
let mask = config.netmask.unwrap_or(Ipv4Addr::new(255, 255, 255, 0));
let address = config.address.ok_or(Error::InvalidConfig)?;
let mask = config.netmask.ok_or(Error::InvalidConfig)?;
let gateway = config.destination.map(IpAddr::from);
adapter.set_network_addresses_tuple(IpAddr::V4(address), IpAddr::V4(mask), gateway)?;
let mtu = config.mtu.unwrap_or(1500) as usize;

let session = adapter.start_session(wintun::MAX_RING_CAPACITY)?;

let mut device = Device {
queue: Queue {
session: Arc::new(session),
cached: Arc::new(Mutex::new(Vec::with_capacity(mtu))),
},
queue: create_queue(session),
mtu,
};

@@ -64,37 +86,54 @@ impl Device {

Ok(device)
}
}

#[cfg(feature = "async")]
impl Device {
pub fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
) -> std::task::Poll<std::io::Result<usize>> {
Pin::new(&mut self.queue).poll_read(cx, buf)
}

pub fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
Pin::new(&mut self.queue).poll_write(cx, buf)
}

pub fn poll_flush(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
std::task::Poll::Ready(Ok(()))
}

pub fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
std::task::Poll::Ready(Ok(()))
}
}

impl Read for Device {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.queue.read(buf)
}

fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
self.queue.read_vectored(bufs)
}
}

impl Write for Device {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.queue.write(buf)
}

fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
self.queue.write_vectored(bufs)
}

fn flush(&mut self) -> io::Result<()> {
self.queue.flush()
Ok(())
}
}

@@ -183,7 +222,6 @@ impl D for Device {

fn set_mtu(&mut self, value: i32) -> Result<()> {
self.mtu = value as usize;
self.queue.cached = Arc::new(Mutex::new(Vec::with_capacity(self.mtu)));
Ok(())
}

@@ -194,99 +232,85 @@ impl D for Device {

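/// Wraps the wintun `Session`; with the `async` feature it also owns the
/// receiving end of the channel fed by the background reader thread.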
pub struct Queue {
session: Arc<Session>,
cached: Arc<Mutex<Vec<u8>>>,
#[cfg(feature = "async")]
receiver: tokio::sync::mpsc::UnboundedReceiver<Vec<u8>>,
#[cfg(feature = "async")]
_task: thread::JoinHandle<()>,
}

#[cfg(feature = "async")]
impl Queue {
pub fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut buf: &mut [u8],
) -> Poll<io::Result<usize>> {
{
let mut cached = self
.cached
.lock()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
if cached.len() > 0 {
let res = match io::copy(&mut cached.as_slice(), &mut buf) {
Ok(n) => Poll::Ready(Ok(n as usize)),
Err(e) => Poll::Ready(Err(e)),
};
cached.clear();
return res;
}
}
let reader_session = self.session.clone();
match reader_session.try_receive() {
Err(e) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
Ok(Some(packet)) => match io::copy(&mut packet.bytes(), &mut buf) {
Ok(n) => Poll::Ready(Ok(n as usize)),
Err(e) => Poll::Ready(Err(e)),
},
Ok(None) => {
let waker = cx.waker().clone();
let cached = self.cached.clone();
thread::spawn(move || {
match reader_session.receive_blocking() {
Ok(packet) => {
if let Ok(mut cached) = cached.lock() {
cached.extend_from_slice(packet.bytes());
} else {
log::error!("cached lock error in wintun reciever thread, packet will be dropped");
}
}
Err(e) => log::error!("receive_blocking error: {:?}", e),
}
waker.wake()
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> std::task::Poll<std::io::Result<usize>> {
match std::task::ready!(self.receiver.poll_recv(cx)) {
Some(bytes) => {
                if bytes.len() > buf.len() {
                    return std::task::Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "buffer too small for the received packet",
                    )));
                }
                buf[..bytes.len()].copy_from_slice(&bytes);
});
Poll::Pending
std::task::Poll::Ready(Ok(bytes.len()))
}
None => std::task::Poll::Ready(Ok(0)),
}
}

#[allow(dead_code)]
fn try_read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
let reader_session = self.session.clone();
match reader_session.try_receive() {
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
Ok(op) => match op {
None => Ok(0),
Some(packet) => match io::copy(&mut packet.bytes(), &mut buf) {
Ok(s) => Ok(s as usize),
Err(e) => Err(e),
},
},
}
pub fn poll_write(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let mut write_pack = self.session.allocate_send_packet(buf.len() as u16)?;
write_pack.bytes_mut().copy_from_slice(buf.as_ref());
self.session.send_packet(write_pack);
std::task::Poll::Ready(Ok(buf.len()))
}

pub fn poll_flush(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
std::task::Poll::Ready(Ok(()))
}

pub fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
std::task::Poll::Ready(Ok(()))
}
}

impl Read for Queue {
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
let reader_session = self.session.clone();
match reader_session.receive_blocking() {
Ok(pkt) => match io::copy(&mut pkt.bytes(), &mut buf) {
Ok(n) => Ok(n as usize),
Err(e) => Err(e),
},
Err(e) => Err(io::Error::new(io::ErrorKind::ConnectionAborted, e)),
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.session.receive_blocking() {
Ok(pkt) => {
let bytes = pkt.bytes();
let len = bytes.len();
if len <= buf.len() {
buf[..len].clone_from_slice(bytes);
Ok(len)
} else {
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"no large enough storage to save data",
))
}
}
Err(e) => Err(std::io::Error::new(std::io::ErrorKind::NotConnected, e)),
}
}
}

impl Write for Queue {
fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
let size = buf.len();
match self.session.allocate_send_packet(size as u16) {
Err(e) => Err(io::Error::new(io::ErrorKind::OutOfMemory, e)),
Ok(mut packet) => match io::copy(&mut buf, &mut packet.bytes_mut()) {
Ok(s) => {
self.session.send_packet(packet);
Ok(s as usize)
}
Err(e) => Err(e),
},
}
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let len = buf.len();
let mut write_pack = self.session.allocate_send_packet(len as u16)?;
write_pack.bytes_mut().copy_from_slice(buf.as_ref());
self.session.send_packet(write_pack);
Ok(len)
}

fn flush(&mut self) -> io::Result<()> {