diff --git a/dlt/helpers/ibis_helper.py b/dlt/helpers/ibis_helper.py
new file mode 100644
index 0000000000..9c5231c5f0
--- /dev/null
+++ b/dlt/helpers/ibis_helper.py
@@ -0,0 +1,39 @@
+import ibis
+from typing import cast
+from typing import Iterator
+from dlt import Pipeline
+from contextlib import contextmanager
+from ibis import BaseBackend
+from importlib import import_module
+
+# dlt destination types whose ibis backend module has a different name
+IBIS_DESTINATION_MAP = {"synapse": "mssql", "redshift": "postgres"}
+
+
+@contextmanager
+def ibis_helper(p: Pipeline) -> Iterator[BaseBackend]:
+    """Expose the destination of a pipeline as an ibis backend.
+
+    Enters the pipeline's sql client context and yields an ibis backend created
+    from it. The sql client context is exited when this context exits.
+
+    Args:
+        p: Pipeline whose destination should be queried with ibis.
+
+    Yields:
+        The ibis backend returned by `from_connection` for the sql client.
+    """
+
+    destination_type = p.destination_client().config.destination_type
+
+    # apply destination map
+    destination_type = IBIS_DESTINATION_MAP.get(destination_type, destination_type)
+
+    # get the right ibis module
+    ibis_module = import_module(f"ibis.backends.{destination_type}")
+    ibis_backend = cast(BaseBackend, ibis_module.Backend())
+
+    with p.sql_client() as c:
+        # NOTE(review): `c` is the dlt sql client wrapper; confirm ibis accepts it
+        # here or whether the underlying native connection must be passed instead.
+        yield ibis_backend.from_connection(c)
diff --git a/tests/load/test_ibis_helper.py b/tests/load/test_ibis_helper.py
new file mode 100644
index 0000000000..ca251ac7c3
--- /dev/null
+++ b/tests/load/test_ibis_helper.py
@@ -0,0 +1,61 @@
+import pytest
+import ibis
+import dlt
+
+from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration
+from dlt.helpers.ibis_helper import ibis_helper
+
+
+@pytest.mark.essential
+@pytest.mark.parametrize(
+    "destination_config",
+    destinations_configs(
+        default_sql_configs=True, exclude=["athena", "dremio", "redshift", "databricks", "synapse"]
+    ),
+    ids=lambda x: x.name,
+)
+def test_ibis_helper(destination_config: DestinationTestConfiguration) -> None:
+    """Load a table with a child table and check that ibis can list and read both."""
+    pipeline = destination_config.setup_pipeline(
+        "ibis_pipeline", dataset_name="ibis_test", dev_mode=True
+    )
+    pipeline.run(
+        [{"id": i + 10, "children": [{"id": i + 100}, {"id": i + 1000}]} for i in range(5)],
+        table_name="ibis_items",
+    )
+
+    with ibis_helper(pipeline) as ibis_backend:
+        # check we can read table names
+        assert {tname.lower() for tname in ibis_backend.list_tables()} >= {
+            "_dlt_loads",
+            "_dlt_pipeline_state",
+            "_dlt_version",
+            "ibis_items",
+            "ibis_items__children",
+        }
+
+        # on snowflake the result column is looked up under its upper-cased name
+        id_identifier = "id"
+        if destination_config.destination == "snowflake":
+            id_identifier = id_identifier.upper()
+
+        # check we can read data; sort the rows because a SELECT without ORDER BY
+        # does not guarantee any row order
+        item_ids = ibis_backend.sql("SELECT id FROM ibis_items").to_pandas()[id_identifier]
+        assert sorted(item_ids.tolist()) == [10, 11, 12, 13, 14]
+
+        child_ids = ibis_backend.sql("SELECT id FROM ibis_items__children").to_pandas()[
+            id_identifier
+        ]
+        assert sorted(child_ids.tolist()) == [
+            100,
+            101,
+            102,
+            103,
+            104,
+            1000,
+            1001,
+            1002,
+            1003,
+            1004,
+        ]