Skip to content

Commit

Permalink
add connection pool for tunnel
Browse files Browse the repository at this point in the history
  • Loading branch information
cssivision committed Apr 28, 2024
1 parent 794d584 commit 56de24e
Showing 1 changed file with 99 additions and 35 deletions.
134 changes: 99 additions & 35 deletions src/bin/http2socks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::future::poll_fn;
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::pin::pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::{ready, Poll};
Expand All @@ -16,7 +17,7 @@ use base64::engine::general_purpose::PAD;
use base64::engine::GeneralPurpose;
use base64::Engine;
use bytes::Bytes;
use futures_util::io::{AsyncReadExt, AsyncWriteExt};
use futures_util::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use futures_util::lock::Mutex;
use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full};
use hyper::client::conn::http1::SendRequest;
Expand Down Expand Up @@ -120,6 +121,7 @@ fn main() -> io::Result<()> {
.to_vec();

let socks_client = SocksClient::new(server_addr);
let socks_stream_client = SocksStreamClient::new(server_addr);

awak::block_on(async move {
let listener = TcpListener::bind(&local_addr).await?;
Expand All @@ -131,6 +133,7 @@ fn main() -> io::Result<()> {
let io = HyperIo::new(stream);
let authorization = authorization.clone();
let socks_client = socks_client.clone();
let socks_stream_client = socks_stream_client.clone();

awak::spawn(async move {
if let Err(err) = http1::Builder::new()
Expand All @@ -139,10 +142,25 @@ fn main() -> io::Result<()> {
.timer(HyperTimer::new())
.serve_connection(
io,
service_fn(|req| {
service_fn(|mut req| {
let authorization = authorization.clone();
let socks_client = socks_client.clone();
async move { proxy(socks_client, req, server_addr, authorization).await }
let socks_stream_client = socks_stream_client.clone();
async move {
log::debug!("req: {:?}", req);
if !proxy_authorization(
&authorization,
req.headers().get(PROXY_AUTHORIZATION),
) {
log::error!("authorization fail");
let mut resp = Response::new(empty());
*resp.status_mut() =
http::StatusCode::PROXY_AUTHENTICATION_REQUIRED;
return Ok(resp);
}
let _ = req.headers_mut().remove(PROXY_AUTHORIZATION);
proxy(socks_client, socks_stream_client, req).await
}
}),
)
.with_upgrades()
Expand Down Expand Up @@ -185,20 +203,10 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
}

async fn proxy(
socks_client: SocksClient,
mut req: Request<hyper::body::Incoming>,
server_addr: SocketAddr,
authorization: Vec<u8>,
client: SocksClient,
stream_client: SocksStreamClient,
req: Request<hyper::body::Incoming>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
log::debug!("req: {:?}", req);
if !proxy_authorization(&authorization, req.headers().get(PROXY_AUTHORIZATION)) {
log::error!("authorization fail");
let mut resp = Response::new(empty());
*resp.status_mut() = http::StatusCode::PROXY_AUTHENTICATION_REQUIRED;
return Ok(resp);
}
let _ = req.headers_mut().remove(PROXY_AUTHORIZATION);

if Method::CONNECT == req.method() {
// Received an HTTP request like:
// ```
Expand All @@ -219,7 +227,7 @@ async fn proxy(
awak::spawn(async move {
match hyper::upgrade::on(req).await {
Ok(upgraded) => {
if let Err(e) = tunnel(upgraded, host, port, server_addr).await {
if let Err(e) = stream_client.tunnel(upgraded, host, port).await {
log::error!("tunnel io error: {}", e);
};
}
Expand All @@ -235,7 +243,7 @@ async fn proxy(
Ok(resp)
}
} else {
match socks_client.send_request(req).await {
match client.send_request(req).await {
Ok(res) => Ok(Response::new(res.boxed())),
Err(e) => {
let mut resp = Response::new(full(format!("proxy server interval error {:?}", e)));
Expand All @@ -246,20 +254,75 @@ async fn proxy(
}
}

// Create a TCP connection to host:port, build a tunnel between the connection and
// the upgraded connection
async fn tunnel(
upgraded: Upgraded,
struct SocksStreamConnector {
server_addr: SocketAddr,
host: String,
port: u16,
}

impl ManageConnection for SocksStreamConnector {
/// The connection type this manager deals with.
type Connection = TcpStream;

/// Attempts to create a new connection.
async fn connect(&self) -> io::Result<Self::Connection> {
let mut stream = timeout(CONNECT_TIMEOUT, TcpStream::connect(self.server_addr)).await??;
handshake(&mut stream, CONNECT_TIMEOUT, self.host.clone(), self.port).await?;
Ok(stream)
}

/// Check if the connection is still valid, check background every `check_interval`.
///
/// A standard implementation would check if a simple query like `PING` succee,
/// if the `Connection` is broken, error should return.
async fn check(&self, mut conn: &mut Self::Connection) -> io::Result<()> {
poll_fn(|cx| {
if matches!(
ready!(pin!(&mut conn).poll_read(cx, &mut [0u8])),
Err(err) if err.kind() == io::ErrorKind::WouldBlock
) {
Poll::Ready(Ok(()))
} else {
Poll::Ready(Err(other("check fail")))
}
})
.await
}
}

#[derive(Clone)]
struct SocksStreamClient {
inner: Arc<Mutex<HashMap<String, Pool<SocksStreamConnector>>>>,
server_addr: SocketAddr,
) -> io::Result<()> {
let mut upgraded = HyperIo::new(upgraded);
let mut server = timeout(CONNECT_TIMEOUT, TcpStream::connect(server_addr)).await??;
handshake(&mut server, CONNECT_TIMEOUT, host, port).await?;
let (n1, n2) = copy_bidirectional(&mut upgraded, &mut server).await?;
log::debug!("client wrote {} bytes and received {} bytes", n1, n2);
Ok(())
}

impl SocksStreamClient {
fn new(server_addr: SocketAddr) -> Self {
Self {
inner: Arc::new(Mutex::new(HashMap::new())),
server_addr,
}
}

// Create a TCP connection to host:port, build a tunnel between the connection and
// the upgraded connection
async fn tunnel(&self, upgraded: Upgraded, host: String, port: u16) -> io::Result<()> {
let mut upgraded = HyperIo::new(upgraded);
let mut inner = self.inner.lock().await;
let pool = inner
.entry(format!("{}:{}", host, port))
.or_insert_with(|| {
Builder::new().build(SocksStreamConnector {
server_addr: self.server_addr,
host,
port,
})
});
let mut stream = pool.get().await?;
let (n1, n2) = copy_bidirectional(&mut upgraded, &mut *stream).await?;
log::debug!("client wrote {} bytes and received {} bytes", n1, n2);
Ok(())
}
}

struct SocksConnector {
Expand Down Expand Up @@ -328,15 +391,16 @@ impl SocksClient {
let port = req.uri().port_u16().unwrap_or(80);
log::debug!("proxy {}:{} to {:?}", host, port, self.server_addr);

let conn = SocksConnector {
server_addr: self.server_addr,
host: host.clone(),
port,
};
let mut inner = self.inner.lock().await;
let pool = inner
.entry(format!("{}:{}", host, port))
.or_insert_with(|| Builder::new().build(conn));
.or_insert_with(|| {
Builder::new().build(SocksConnector {
server_addr: self.server_addr,
host,
port,
})
});

let mut request_sender = pool.get().await?;
request_sender
Expand Down

0 comments on commit 56de24e

Please sign in to comment.