diff --git a/daft/io/_delta_lake.py b/daft/io/_delta_lake.py
index a3c60e8c8e..6930d8bafc 100644
--- a/daft/io/_delta_lake.py
+++ b/daft/io/_delta_lake.py
@@ -8,6 +8,7 @@
 from daft.daft import IOConfig, NativeStorageConfig, ScanOperatorHandle, StorageConfig
 from daft.dataframe import DataFrame
 from daft.io.catalog import DataCatalogTable
+from daft.io.unity_catalog import UnityCatalogTable
 from daft.logical.builder import LogicalPlanBuilder
 
 
@@ -25,7 +26,7 @@ def read_delta_lake(
 
 @PublicAPI
 def read_deltalake(
-    table: Union[str, DataCatalogTable],
+    table: Union[str, DataCatalogTable, UnityCatalogTable],
     io_config: Optional["IOConfig"] = None,
     _multithreaded_io: Optional[bool] = None,
 ) -> DataFrame:
@@ -56,20 +57,27 @@ def read_deltalake(
     """
     from daft.delta_lake.delta_lake_scan import DeltaLakeScanOperator
 
-    io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
-
     # If running on Ray, we want to limit the amount of concurrency and requests being made.
     # This is because each Ray worker process receives its own pool of thread workers and connections
     multithreaded_io = not context.get_context().is_ray_runner if _multithreaded_io is None else _multithreaded_io
+
+    io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
     storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config))
 
     if isinstance(table, str):
         table_uri = table
     elif isinstance(table, DataCatalogTable):
         table_uri = table.table_uri(io_config)
+    elif isinstance(table, UnityCatalogTable):
+        table_uri = table.table_uri
+
+        # Override the storage_config with the one provided by Unity catalog
+        table_io_config = table.io_config
+        if table_io_config is not None:
+            storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, table_io_config))
     else:
         raise ValueError(
-            f"table argument must be a table URI string or a DataCatalogTable instance, but got: {type(table)}, {table}"
+            f"table argument must be a table URI string, DataCatalogTable or UnityCatalogTable instance, but got: {type(table)}, {table}"
         )
 
     delta_lake_operator = DeltaLakeScanOperator(table_uri, storage_config=storage_config)
diff --git a/daft/io/unity_catalog.py b/daft/io/unity_catalog.py
index e69de29bb2..fb790f96c6 100644
--- a/daft/io/unity_catalog.py
+++ b/daft/io/unity_catalog.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import dataclasses
+
+import requests
+
+from daft.io import IOConfig, S3Config
+
+
+@dataclasses.dataclass(frozen=True)
+class UnityCatalogTable:
+    table_uri: str
+    io_config: IOConfig | None
+
+
+class UnityCatalog:
+    def __init__(self, endpoint: str, token: str | None = None):
+        self._endpoint = endpoint
+        self._token_header = {"Authorization": f"Bearer {token}"} if token else {}
+
+    def list_schemas(self):
+        raise NotImplementedError("Listing schemas not yet implemented.")
+
+    def list_tables(self, schema: str):
+        raise NotImplementedError("Listing tables not yet implemented.")
+
+    def load_table(self, name: str) -> UnityCatalogTable:
+        # Load the table ID
+        table_info = requests.get(
+            self._endpoint + f"/api/2.1/unity-catalog/tables/{name}", headers=self._token_header
+        ).json()
+        table_id = table_info["table_id"]
+        table_uri = table_info["storage_location"]
+
+        # Grab credentials from Unity catalog and place it into the Table
+        temp_table_cred_endpoint = self._endpoint + "/api/2.1/unity-catalog/temporary-table-credentials"
+        response = requests.post(
+            temp_table_cred_endpoint, json={"table_id": table_id, "operation": "READ"}, headers=self._token_header
+        )
+
+        aws_temp_credentials = response.json().get("aws_temp_credentials")
+        io_config = (
+            IOConfig(
+                s3=S3Config(
+                    key_id=aws_temp_credentials.get("access_key_id"),
+                    access_key=aws_temp_credentials.get("secret_access_key"),
+                    session_token=aws_temp_credentials.get("session_token"),
+                )
+            )
+            if aws_temp_credentials is not None
+            else None
+        )
+
+        return UnityCatalogTable(
+            table_uri=table_uri,
+            io_config=io_config,
+        )