Skip to content

Commit

Permalink
[#184] Send IV/Salt with the first payload packet
Browse files Browse the repository at this point in the history
TFO on macOS (seems) the second send call must wait until the first recv
has called. Don't know why.
  • Loading branch information
zonyitoo committed Jan 2, 2020
1 parent 353e7fc commit a9fdfc5
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 65 deletions.
20 changes: 16 additions & 4 deletions src/relay/tcprelay/aead.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use std::{
};

use byteorder::{BigEndian, ByteOrder};
use bytes::{BufMut, BytesMut};
use bytes::{BufMut, Bytes, BytesMut};
use futures::ready;
use tokio::prelude::*;

Expand Down Expand Up @@ -217,15 +217,17 @@ pub struct EncryptedWriter {
cipher: BoxAeadEncryptor,
tag_size: usize,
steps: EncryptWriteStep,
nonce_opt: Option<Bytes>,
}

impl EncryptedWriter {
/// Creates a new EncryptedWriter
pub fn new(t: CipherType, key: &[u8], nonce: &[u8]) -> EncryptedWriter {
pub fn new(t: CipherType, key: &[u8], nonce: Bytes) -> EncryptedWriter {
EncryptedWriter {
cipher: crypto::new_aead_encryptor(t, key, nonce),
cipher: crypto::new_aead_encryptor(t, key, &nonce),
tag_size: t.tag_size(),
steps: EncryptWriteStep::Nothing,
nonce_opt: Some(nonce),
}
}

Expand Down Expand Up @@ -253,7 +255,17 @@ impl EncryptedWriter {
let output_length = self.buffer_size(data);
let data_length = data.len() as u16;

let mut buf = BytesMut::with_capacity(output_length);
// First packet is IV
let iv_len = match self.nonce_opt {
Some(ref v) => v.len(),
None => 0,
};

let mut buf = BytesMut::with_capacity(iv_len + output_length);

if let Some(iv) = self.nonce_opt.take() {
buf.extend(iv);
}

let mut data_len_buf = [0u8; 2];
BigEndian::write_u16(&mut data_len_buf, data_length);
Expand Down
58 changes: 15 additions & 43 deletions src/relay/tcprelay/crypto_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use std::{
};

use byte_string::ByteStr;
use bytes::Bytes;
use futures::ready;
use log::trace;
use tokio::prelude::*;
Expand All @@ -36,18 +35,12 @@ enum ReadStatus {
Established,
}

enum WriteStatus {
SendIv(Bytes, usize),
Established,
}

pub struct CryptoStream<S> {
stream: S,
dec: Option<DecryptedReader>,
enc: Option<EncryptedWriter>,
enc: EncryptedWriter,
svr_cfg: Arc<ServerConfig>,
read_status: ReadStatus,
write_status: WriteStatus,
}

impl<S: Unpin> Unpin for CryptoStream<S> {}
Expand All @@ -73,13 +66,24 @@ impl<S> CryptoStream<S> {
}
};

let method = svr_cfg.method();
let enc = match method.category() {
CipherCategory::Stream => {
trace!("Sent Stream cipher IV {:?}", ByteStr::new(&local_iv));
EncryptedWriter::Stream(StreamEncryptedWriter::new(method, svr_cfg.key(), local_iv))
}
CipherCategory::Aead => {
trace!("Sent AEAD cipher salt {:?}", ByteStr::new(&local_iv));
EncryptedWriter::Aead(AeadEncryptedWriter::new(method, svr_cfg.key(), local_iv))
}
};

CryptoStream {
stream,
dec: None,
enc: None,
enc,
svr_cfg,
read_status: ReadStatus::WaitIv(vec![0u8; prev_len], 0usize),
write_status: WriteStatus::SendIv(local_iv, 0usize),
}
}
}
Expand Down Expand Up @@ -133,41 +137,9 @@ impl<S> CryptoStream<S>
where
S: AsyncWrite + Unpin,
{
fn poll_write_handshake(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if let WriteStatus::SendIv(ref iv, ref mut pos) = self.write_status {
while *pos < iv.len() {
let n = ready!(Pin::new(&mut self.stream).poll_write(cx, &iv[*pos..]))?;
if n == 0 {
use std::io::ErrorKind;
return Poll::Ready(Err(ErrorKind::UnexpectedEof.into()));
}
*pos += n;
}

let method = self.svr_cfg.method();
let enc = match method.category() {
CipherCategory::Stream => {
trace!("Sent Stream cipher IV {:?}", ByteStr::new(&iv));
EncryptedWriter::Stream(StreamEncryptedWriter::new(method, self.svr_cfg.key(), &iv))
}
CipherCategory::Aead => {
trace!("Sent AEAD cipher salt {:?}", ByteStr::new(&iv));
EncryptedWriter::Aead(AeadEncryptedWriter::new(method, self.svr_cfg.key(), &iv))
}
};

self.enc = Some(enc);
self.write_status = WriteStatus::Established;
}

Poll::Ready(Ok(()))
}

fn priv_poll_write(mut self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
ready!(self.poll_write_handshake(ctx))?;

let stream = unsafe { &mut *(&mut self.stream as *mut _) };
match *self.enc.as_mut().unwrap() {
match self.enc {
EncryptedWriter::Aead(ref mut w) => w.poll_write_encrypted(ctx, stream, buf),
EncryptedWriter::Stream(ref mut w) => w.poll_write_encrypted(ctx, stream, buf),
}
Expand Down
19 changes: 15 additions & 4 deletions src/relay/tcprelay/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{
};

use crate::crypto::{new_stream, BoxStreamCipher, CipherType, CryptoMode};
use bytes::{BufMut, BytesMut};
use bytes::{BufMut, Bytes, BytesMut};
use futures::ready;
use tokio::prelude::*;

Expand Down Expand Up @@ -94,14 +94,16 @@ enum EncryptWriteStep {
pub struct EncryptedWriter {
cipher: BoxStreamCipher,
steps: EncryptWriteStep,
iv_opt: Option<Bytes>,
}

impl EncryptedWriter {
/// Creates a new EncryptedWriter
pub fn new(t: CipherType, key: &[u8], iv: &[u8]) -> EncryptedWriter {
pub fn new(t: CipherType, key: &[u8], iv: Bytes) -> EncryptedWriter {
EncryptedWriter {
cipher: new_stream(t, key, iv, CryptoMode::Encrypt),
cipher: new_stream(t, key, &iv, CryptoMode::Encrypt),
steps: EncryptWriteStep::Nothing,
iv_opt: Some(iv),
}
}

Expand All @@ -122,7 +124,16 @@ impl EncryptedWriter {
loop {
match self.steps {
EncryptWriteStep::Nothing => {
let mut buf = BytesMut::with_capacity(self.buffer_size(data));
let iv_len = match self.iv_opt {
Some(ref iv) => iv.len(),
None => 0,
};

let mut buf = BytesMut::with_capacity(iv_len + self.buffer_size(data));
if let Some(iv) = self.iv_opt.take() {
buf.extend(iv);
}

self.cipher_update(data, &mut buf)?;

self.steps = EncryptWriteStep::Writing(buf, 0);
Expand Down
36 changes: 22 additions & 14 deletions src/relay/tcprelay/utils/split.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ pub struct ReadHalf<'a> {
}

impl<'a> ReadHalf<'a> {
fn stream(&self) -> &'a mut TcpStream {
fn stream(&self) -> &'a TcpStream {
unsafe { &mut *self.stream }
}

fn stream_mut(&mut self) -> &'a mut TcpStream {
unsafe { &mut *self.stream }
}
}
Expand All @@ -35,16 +39,20 @@ impl AsyncRead for ReadHalf<'_> {
self.stream().prepare_uninitialized_buffer(buf)
}

fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> task::Poll<io::Result<usize>> {
Pin::new(self.stream()).poll_read(cx, buf)
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> task::Poll<io::Result<usize>> {
Pin::new(self.stream_mut()).poll_read(cx, buf)
}

fn poll_read_buf<B: BufMut>(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> task::Poll<io::Result<usize>> {
Pin::new(self.stream()).poll_read_buf(cx, buf)
Pin::new(self.stream_mut()).poll_read_buf(cx, buf)
}
}

Expand All @@ -61,7 +69,7 @@ pub struct WriteHalf<'a> {
}

impl<'a> WriteHalf<'a> {
fn stream(&self) -> &'a mut TcpStream {
fn stream_mut(&mut self) -> &'a mut TcpStream {
unsafe { &mut *self.stream }
}
}
Expand All @@ -74,27 +82,27 @@ impl AsRef<TcpStream> for WriteHalf<'_> {

impl AsyncWrite for WriteHalf<'_> {
fn poll_write(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<Result<usize, io::Error>> {
Pin::new(self.stream()).poll_write(cx, buf)
Pin::new(self.stream_mut()).poll_write(cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Result<(), io::Error>> {
Pin::new(self.stream()).poll_flush(cx)
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Result<(), io::Error>> {
Pin::new(self.stream_mut()).poll_flush(cx)
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Result<(), io::Error>> {
Pin::new(self.stream()).poll_shutdown(cx)
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Result<(), io::Error>> {
Pin::new(self.stream_mut()).poll_shutdown(cx)
}

fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> task::Poll<Result<usize, io::Error>> {
Pin::new(self.stream()).poll_write_buf(cx, buf)
Pin::new(self.stream_mut()).poll_write_buf(cx, buf)
}
}

Expand Down

0 comments on commit a9fdfc5

Please sign in to comment.