Skip to content

Commit

Permalink
Merge pull request #3119 from Ruadhri17/jwt-handling-unwrap
Browse files Browse the repository at this point in the history
fix: handle error when jwt token could not be retrieved
  • Loading branch information
reubenmiller authored Sep 10, 2024
2 parents e8b3405 + 49edd1a commit ea6adf6
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 20 deletions.
44 changes: 31 additions & 13 deletions crates/extensions/c8y_auth_proxy/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,21 +256,30 @@ async fn proxy_ws(
use axum::extract::ws::CloseFrame;
use tungstenite::error::Error;
let uri = format!("{}/{path}", host.ws);
let mut token = retrieve_token.not_matching(None).await;
let c8y = match connect_to_websocket(&token, &headers, &uri, &host).await {
Ok(c8y) => Ok(c8y),
Err(Error::Http(res)) if res.status() == StatusCode::UNAUTHORIZED => {
token = retrieve_token.not_matching(Some(&token)).await;
match connect_to_websocket(&token, &headers, &uri, &host).await {

let c8y = {
match retrieve_token.not_matching(None).await {
Ok(token) => match connect_to_websocket(&token, &headers, &uri, &host).await {
Ok(c8y) => Ok(c8y),
Err(e) => {
Err(anyhow::Error::from(e).context("Failed to connect to proxied websocket"))
Err(Error::Http(res)) if res.status() == StatusCode::UNAUTHORIZED => {
match retrieve_token.not_matching(Some(&token)).await {
Ok(token) => {
match connect_to_websocket(&token, &headers, &uri, &host).await {
Ok(c8y) => Ok(c8y),
Err(e) => Err(anyhow::Error::from(e)
.context("Failed to connect to proxied websocket")),
}
}
Err(e) => Err(e.context("Failed to retrieve JWT token")),
}
}
}
Err(e) => Err(anyhow::Error::from(e)),
},
Err(e) => Err(e.context("Failed to retrieve JWT token")),
}
Err(e) => Err(anyhow::Error::from(e)),
}
.context("Error connecting to proxied websocket");

let c8y = match c8y {
Err(e) => {
let _ = ws
Expand Down Expand Up @@ -413,7 +422,10 @@ async fn respond_to(
destination += query;
}

let mut token = retrieve_token.not_matching(None).await;
let mut token = retrieve_token
.not_matching(None)
.await
.with_context(|| "failed to retrieve JWT token")?;

if let Some(ws) = ws {
let path = path.to_owned();
Expand All @@ -429,7 +441,10 @@ async fn respond_to(
.await
.with_context(|| format!("making HEAD request to {destination}"))?;
if response.status() == StatusCode::UNAUTHORIZED {
token = retrieve_token.not_matching(Some(&token)).await;
token = retrieve_token
.not_matching(Some(&token))
.await
.with_context(|| "failed to retrieve JWT token")?;
}
}

Expand All @@ -448,7 +463,10 @@ async fn respond_to(
.with_context(|| format!("making proxied request to {destination}"))?;

if res.status() == StatusCode::UNAUTHORIZED {
token = retrieve_token.not_matching(Some(&token)).await;
token = retrieve_token
.not_matching(Some(&token))
.await
.with_context(|| "failed to retrieve JWT token")?;
if let Some(body) = body_clone {
res = send_request(Body::from(body), &token)
.await
Expand Down
14 changes: 7 additions & 7 deletions crates/extensions/c8y_auth_proxy/src/tokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ impl SharedTokenManager {
/// Returns a JWT that doesn't match the provided JWT
///
/// This prevents needless token refreshes if multiple requests are made in parallel
pub async fn not_matching(&self, input: Option<&Arc<str>>) -> Arc<str> {
pub async fn not_matching(&self, input: Option<&Arc<str>>) -> Result<Arc<str>, anyhow::Error> {
self.0.lock().await.not_matching(input).await
}
}
Expand All @@ -31,17 +31,17 @@ impl TokenManager {
}

impl TokenManager {
async fn not_matching(&mut self, input: Option<&Arc<str>>) -> Arc<str> {
async fn not_matching(&mut self, input: Option<&Arc<str>>) -> Result<Arc<str>, anyhow::Error> {
match (self.cached.as_mut(), input) {
(Some(token), None) => token.clone(),
(Some(token), None) => Ok(token.clone()),
// The token should have arisen from this TokenManager, so pointer equality is sufficient
(Some(token), Some(no_match)) if !Arc::ptr_eq(token, no_match) => token.clone(),
(Some(token), Some(no_match)) if !Arc::ptr_eq(token, no_match) => Ok(token.clone()),
_ => self.refresh().await,
}
}

async fn refresh(&mut self) -> Arc<str> {
self.cached = Some(self.recv.await_response(()).await.unwrap().unwrap().into());
self.cached.as_ref().unwrap().clone()
async fn refresh(&mut self) -> Result<Arc<str>, anyhow::Error> {
self.cached = Some(self.recv.await_response(()).await??.into());
Ok(self.cached.as_ref().unwrap().clone())
}
}

0 comments on commit ea6adf6

Please sign in to comment.