Skip to content

Commit

Permalink
runtimes/core: propagate auth error from auth handler (#1485)
Browse files Browse the repository at this point in the history
Propagate auth errors from the auth handler
  • Loading branch information
fredr authored Oct 16, 2024
1 parent 38f7756 commit 86c4df8
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 35 deletions.
12 changes: 8 additions & 4 deletions runtimes/core/src/api/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ pub enum AuthResponse {
auth_uid: String,
auth_data: serde_json::Map<String, serde_json::Value>,
},
Unauthenticated,
Unauthenticated {
error: api::Error,
},
}

/// A trait for handlers that accept auth parameters and return an auth result.
Expand Down Expand Up @@ -86,7 +88,9 @@ impl Authenticator {
meta: CallMeta,
) -> APIResult<AuthResponse> {
if !self.contains_auth_params(req) {
return Ok(AuthResponse::Unauthenticated);
return Ok(AuthResponse::Unauthenticated {
error: api::Error::unauthenticated(),
});
}

let auth_req = self.build_auth_request(req, meta);
Expand All @@ -96,8 +100,8 @@ impl Authenticator {
};
match resp {
Ok(resp) => Ok(resp),
Err(err) if err.code == api::ErrCode::Unauthenticated => {
Ok(AuthResponse::Unauthenticated)
Err(error) if error.code == api::ErrCode::Unauthenticated => {
Ok(AuthResponse::Unauthenticated { error })
}
Err(err) => Err(err),
}
Expand Down
8 changes: 5 additions & 3 deletions runtimes/core/src/api/auth/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,15 @@ impl RemoteAuthHandler {
auth_data: data,
})
} else {
Ok(AuthResponse::Unauthenticated)
Ok(AuthResponse::Unauthenticated {
error: api::Error::unauthenticated(),
})
}
}

// Map the unauthenticated error code to the unauthenticated result.
Err(err) if err.code == api::ErrCode::Unauthenticated => {
Ok(AuthResponse::Unauthenticated)
Err(error) if error.code == api::ErrCode::Unauthenticated => {
Ok(AuthResponse::Unauthenticated { error })
}

Err(err) => Err(err),
Expand Down
9 changes: 9 additions & 0 deletions runtimes/core/src/api/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ impl Error {
stack: None,
}
}

pub fn unauthenticated() -> Self {
Self {
code: ErrCode::Unauthenticated,
message: ErrCode::Unauthenticated.default_public_message().into(),
internal_message: None,
stack: None,
}
}
}

impl From<WebSocketUpgradeRejection> for Error {
Expand Down
45 changes: 28 additions & 17 deletions runtimes/core/src/api/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use pingora::server::configuration::{Opt, ServerConf};
use pingora::services::Service;
use pingora::upstreams::peer::HttpPeer;
use pingora::{Error, ErrorSource, ErrorType, OkOrErr, OrErr};
use router::Target;
use tokio::sync::watch;
use url::Url;

Expand Down Expand Up @@ -49,6 +50,7 @@ pub struct GatewayCtx {
upstream_service_name: EncoreName,
upstream_base_path: String,
upstream_host: Option<String>,
upstream_require_auth: bool,
}

impl GatewayCtx {
Expand Down Expand Up @@ -199,14 +201,18 @@ impl ProxyHttp for Gateway {
// Check if this is a pubsub push request and if we need to proxy it to another service
let push_proxy_svc = path
.strip_prefix("/__encore/pubsub/push/")
.and_then(|sub_id| self.inner.proxied_push_subs.get(sub_id));
.and_then(|sub_id| self.inner.proxied_push_subs.get(sub_id))
.map(|svc| Target {
service_name: svc.clone(),
requires_auth: false,
});

if let Some(own_api_addr) = &self.inner.own_api_address {
if push_proxy_svc.is_none() && path.starts_with("/__encore/") {
return Ok(Box::new(HttpPeer::new(own_api_addr, false, "".to_string())));
}
}
let service_name = push_proxy_svc
let target = push_proxy_svc
.map_or_else(
|| {
// Find which service handles the path route
Expand All @@ -222,14 +228,16 @@ impl ProxyHttp for Gateway {
.route_to_service(method, path)
.context("couldn't find upstream")
})
.cloned()
},
Ok,
)
.or_err(ErrorType::InternalError, "couldn't find upstream")?;

let upstream = self
.inner
.service_registry
.service_base_url(service_name)
.service_base_url(&target.service_name)
.or_err(ErrorType::InternalError, "couldn't find upstream")?;

let upstream_url: Url = upstream
Expand Down Expand Up @@ -259,7 +267,8 @@ impl ProxyHttp for Gateway {
ctx.replace(GatewayCtx {
upstream_base_path: upstream_url.path().to_string(),
upstream_host: host,
upstream_service_name: service_name.clone(),
upstream_service_name: target.service_name.clone(),
upstream_require_auth: target.requires_auth,
});

Ok(Box::new(peer))
Expand Down Expand Up @@ -356,14 +365,20 @@ impl ProxyHttp for Gateway {
.await
.or_err(ErrorType::InternalError, "couldn't authenticate request")?;

if let auth::AuthResponse::Authenticated {
auth_uid,
auth_data,
} = auth_response
{
desc.auth_user_id = Some(Cow::Owned(auth_uid));
desc.auth_data = Some(auth_data);
}
match auth_response {
auth::AuthResponse::Authenticated {
auth_uid,
auth_data,
} => {
desc.auth_user_id = Some(Cow::Owned(auth_uid));
desc.auth_data = Some(auth_data);
}
auth::AuthResponse::Unauthenticated { error } => {
if gateway_ctx.upstream_require_auth {
return Err(error.into());
}
}
};
}

desc.add_meta(upstream_request)
Expand Down Expand Up @@ -441,11 +456,7 @@ impl ProxyHttp for Gateway {
}

fn as_api_error(err: &pingora::Error) -> Option<&api::Error> {
if let Some(cause) = &err.cause {
cause.downcast_ref::<api::Error>()
} else {
None
}
err.root_cause().downcast_ref::<api::Error>()
}

fn api_error_response(err: &api::Error) -> (ResponseHeader, bytes::Bytes) {
Expand Down
31 changes: 20 additions & 11 deletions runtimes/core/src/api/gateway/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ impl Router {
::log::error!(method = method.as_str(), path = path; "tried to register same route twice, skipping");
continue;
}
dst.replace(service.clone());
dst.replace(Target {
service_name: service.clone(),
requires_auth: endpoint.requires_auth,
});
}
}

Expand Down Expand Up @@ -79,7 +82,7 @@ impl Router {
&self,
method: api::schema::Method,
path: &str,
) -> Result<&EncoreName, api::Error> {
) -> Result<&Target, api::Error> {
let mut found_path_match = false;
for router in [&self.main, &self.fallback] {
if let Ok(service) = router.at(path) {
Expand Down Expand Up @@ -110,20 +113,26 @@ impl Router {
}
}

#[derive(Clone, Debug)]
pub struct Target {
pub service_name: EncoreName,
pub requires_auth: bool,
}

#[derive(Clone, Default)]
pub struct MethodRoute {
get: Option<EncoreName>,
head: Option<EncoreName>,
post: Option<EncoreName>,
put: Option<EncoreName>,
delete: Option<EncoreName>,
option: Option<EncoreName>,
trace: Option<EncoreName>,
patch: Option<EncoreName>,
get: Option<Target>,
head: Option<Target>,
post: Option<Target>,
put: Option<Target>,
delete: Option<Target>,
option: Option<Target>,
trace: Option<Target>,
patch: Option<Target>,
}

impl MethodRoute {
fn for_method(&self, method: api::schema::Method) -> Option<&EncoreName> {
fn for_method(&self, method: api::schema::Method) -> Option<&Target> {
match method {
Method::GET => self.get.as_ref(),
Method::HEAD => self.head.as_ref(),
Expand Down

0 comments on commit 86c4df8

Please sign in to comment.