Skip to content

Commit

Permalink
refactor: support more header for http
Browse files Browse the repository at this point in the history
  • Loading branch information
vicanso committed Oct 16, 2024
1 parent 538ba7a commit 24267bd
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 9 deletions.
34 changes: 29 additions & 5 deletions src/http_extra/http_header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@ use snafu::{ResultExt, Snafu};
use std::str::FromStr;

pub const HOST_NAME_TAG: &[u8] = b"$hostname";
const HOST_TAG: &[u8] = b"$host";
const SCHEME_TAG: &[u8] = b"$scheme";
const REMOTE_ADDR_TAG: &[u8] = b"$remote_addr";
const REMOTE_PORT_TAG: &[u8] = b"$remote_port";
const PROXY_ADD_FORWARDED_TAG: &[u8] = b"$proxy_add_x_forwarded_for";
const HTTP_ORIGIN_TAG: &[u8] = b"$http_origin";
const UPSTREAM_ADDR_TAG: &[u8] = b"$upstream_addr";

static SCHEME_HTTPS: HeaderValue = HeaderValue::from_static("https");
static SCHEME_HTTP: HeaderValue = HeaderValue::from_static("http");

#[derive(Debug, Snafu)]
pub enum Error {
#[snafu(display("Invalid header value {source}, {value}"))]
Expand Down Expand Up @@ -67,6 +72,17 @@ pub fn convert_header_value(
) -> Option<HeaderValue> {
let buf = value.as_bytes();
match buf {
HOST_TAG => {
if let Some(value) = util::get_host(session.req_header()) {
return HeaderValue::from_str(value).ok();
}
},
SCHEME_TAG => {
if ctx.tls_version.is_some() {
return Some(SCHEME_HTTPS.clone());
}
return Some(SCHEME_HTTP.clone());
},
HOST_NAME_TAG => {
return HeaderValue::from_str(get_hostname()).ok();
},
Expand All @@ -75,6 +91,11 @@ pub fn convert_header_value(
return HeaderValue::from_str(remote_addr).ok();
}
},
REMOTE_PORT_TAG => {
if let Some(remote_port) = &ctx.remote_port {
return HeaderValue::from_str(&remote_port.to_string()).ok();
}
},
UPSTREAM_ADDR_TAG => {
if !ctx.upstream_address.is_empty() {
return HeaderValue::from_str(&ctx.upstream_address).ok();
Expand All @@ -96,11 +117,14 @@ pub fn convert_header_value(
return HeaderValue::from_str(&value).ok();
}
},
HTTP_ORIGIN_TAG => {
return session.get_header("origin").cloned();
},
_ => {
if buf.starts_with(b"$") {
let http_prefix = b"$http_";
if buf.starts_with(http_prefix) {
let key =
std::str::from_utf8(&buf[http_prefix.len()..buf.len()])
.unwrap_or_default();
return session.get_header(key).cloned();
} else if buf.starts_with(b"$") {
if let Ok(value) = std::env::var(
std::str::from_utf8(&buf[1..buf.len()]).unwrap_or_default(),
) {
Expand Down
6 changes: 5 additions & 1 deletion src/proxy/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,11 @@ impl ProxyHttp for Server {

ctx.processing = self.processing.fetch_add(1, Ordering::Relaxed) + 1;
ctx.accepted = self.accepted.fetch_add(1, Ordering::Relaxed) + 1;
ctx.remote_addr = util::get_remote_addr(session);
if let Some((remote_addr, remote_port)) = util::get_remote_addr(session)
{
ctx.remote_addr = Some(remote_addr);
ctx.remote_port = Some(remote_port);
}

let header = session.req_header_mut();
let host = util::get_host(header).unwrap_or_default();
Expand Down
1 change: 1 addition & 0 deletions src/state/ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ pub struct State {
// the upstream address
pub upstream_address: String,
pub client_ip: Option<String>,
pub remote_port: Option<u16>,
pub remote_addr: Option<String>,
pub guard: Option<Guard>,
pub request_id: Option<String>,
Expand Down
6 changes: 3 additions & 3 deletions src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ pub static HTTP_HEADER_X_FORWARDED_FOR: Lazy<http::HeaderName> =
pub static HTTP_HEADER_X_REAL_IP: Lazy<http::HeaderName> =
Lazy::new(|| HeaderName::from_str("X-Real-Ip").unwrap());

pub fn get_remote_addr(session: &Session) -> Option<String> {
pub fn get_remote_addr(session: &Session) -> Option<(String, u16)> {
if let Some(addr) = session.client_addr() {
if let Some(addr) = addr.as_inet() {
return Some(addr.ip().to_string());
return Some((addr.ip().to_string(), addr.port()));
}
}
None
Expand All @@ -95,7 +95,7 @@ pub fn get_client_ip(session: &Session) -> String {
if let Some(value) = session.get_header(HTTP_HEADER_X_REAL_IP.clone()) {
return value.to_str().unwrap_or_default().to_string();
}
if let Some(addr) = get_remote_addr(session) {
if let Some((addr, _)) = get_remote_addr(session) {
return addr;
}
"".to_string()
Expand Down

0 comments on commit 24267bd

Please sign in to comment.