diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index c10d904b6..8112a44b0 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -38,7 +38,7 @@ use iceberg::{ use self::_serde::{ CatalogConfig, ErrorResponse, ListNamespaceResponse, ListTableResponse, NamespaceSerde, - RenameTableRequest, NO_CONTENT, OK, + RenameTableRequest, TokenResponse, NO_CONTENT, OK, }; const ICEBERG_REST_SPEC_VERSION: &str = "0.14.1"; @@ -96,9 +96,13 @@ impl RestCatalogConfig { .join("/") } + fn get_token_endpoint(&self) -> String { + [&self.uri, PATH_V1, "oauth", "tokens"].join("/") + } + fn try_create_rest_client(&self) -> Result { - //TODO: We will add oauth, ssl config, sigv4 later - let headers = HeaderMap::from_iter([ + // TODO: We will add ssl config, sigv4 later + let mut headers = HeaderMap::from_iter([ ( header::CONTENT_TYPE, HeaderValue::from_static("application/json"), @@ -113,6 +117,19 @@ impl RestCatalogConfig { ), ]); + if let Some(token) = self.props.get("token") { + headers.insert( + header::AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {token}")).map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + "Invalid token received from catalog server!", + ) + .with_source(e) + })?, + ); + } + Ok(HttpClient( Client::builder().default_headers(headers).build()?, )) @@ -144,6 +161,7 @@ impl HttpClient { .with_source(e) })?) } else { + let code = resp.status(); let text = resp.bytes().await?; let e = serde_json::from_slice::(&text).map_err(|e| { Error::new( @@ -151,6 +169,7 @@ impl HttpClient { "Failed to parse response from rest catalog server!", ) .with_context("json", String::from_utf8_lossy(&text)) + .with_context("code", code.to_string()) .with_source(e) })?; Err(e.into()) @@ -497,13 +516,56 @@ impl RestCatalog { client: config.try_create_rest_client()?, config, }; - + catalog.fetch_access_token().await?; + catalog.client = catalog.config.try_create_rest_client()?; catalog.update_config().await?; catalog.client = catalog.config.try_create_rest_client()?; Ok(catalog) } + async fn fetch_access_token(&mut self) -> Result<()> { + if self.config.props.contains_key("token") { + return Ok(()); + } + if let Some(credential) = self.config.props.get("credential") { + let (client_id, client_secret) = if credential.contains(':') { + let (client_id, client_secret) = credential.split_once(':').unwrap(); + (Some(client_id), client_secret) + } else { + (None, credential.as_str()) + }; + let mut params = HashMap::with_capacity(4); + params.insert("grant_type", "client_credentials"); + if let Some(client_id) = client_id { + params.insert("client_id", client_id); + } + params.insert("client_secret", client_secret); + params.insert("scope", "catalog"); + let req = self + .client + .0 + .post(self.config.get_token_endpoint()) + .form(¶ms) + .build()?; + let res = self + .client + .query::(req) + .await + .map_err(|e| { + Error::new( + ErrorKind::Unexpected, + "Failed to fetch access token from catalog server!", + ) + .with_source(e) + })?; + let token = res.access_token; + self.config.props.insert("token".to_string(), token); + } + + Ok(()) + } + async fn update_config(&mut self) -> Result<()> { let mut request = self.client.0.get(self.config.config_endpoint()); @@ -626,6 +688,14 @@ mod _serde { } } + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct TokenResponse { + pub(super) access_token: String, + pub(super) token_type: String, + pub(super) expires_in: Option, + pub(super) issued_token_type: Option, + } + #[derive(Debug, Serialize, Deserialize)] pub(super) struct NamespaceSerde { pub(super) namespace: Vec, @@ -778,6 +848,44 @@ mod tests { .await } + async fn create_oauth_mock(server: &mut ServerGuard) -> Mock { + server + .mock("POST", "/v1/oauth/tokens") + .with_status(200) + .with_body( + r#"{ + "access_token": "ey000000000000", + "token_type": "Bearer", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "expires_in": 86400 + }"#, + ) + .create_async() + .await + } + + #[tokio::test] + async fn test_oauth() { + let mut server = Server::new_async().await; + let oauth_mock = create_oauth_mock(&mut server).await; + let config_mock = create_config_mock(&mut server).await; + + let mut props = HashMap::new(); + props.insert("credential".to_string(), "client1:secret1".to_string()); + + let _catalog = RestCatalog::new( + RestCatalogConfig::builder() + .uri(server.url()) + .props(props) + .build(), + ) + .await + .unwrap(); + + oauth_mock.assert_async().await; + config_mock.assert_async().await; + } + #[tokio::test] async fn test_list_namespace() { let mut server = Server::new_async().await; @@ -1557,7 +1665,7 @@ mod tests { "type": "NoSuchTableException", "code": 404 } -} +} "#, ) .create_async() diff --git a/crates/catalog/rest/tests/rest_catalog_test.rs b/crates/catalog/rest/tests/rest_catalog_test.rs index a4d07955b..205428d61 100644 --- a/crates/catalog/rest/tests/rest_catalog_test.rs +++ b/crates/catalog/rest/tests/rest_catalog_test.rs @@ -66,6 +66,7 @@ async fn set_test_fixture(func: &str) -> TestFixture { rest_catalog, } } + #[tokio::test] async fn test_get_non_exist_namespace() { let fixture = set_test_fixture("test_get_non_exist_namespace").await;