From 7b0814700c0e70603a76bf776c5649a4afe12513 Mon Sep 17 00:00:00 2001 From: Shahar Epstein Date: Sat, 8 Jun 2024 16:44:45 +0300 Subject: [PATCH] Add unit tests to Kafka base hook --- airflow/providers/apache/kafka/hooks/base.py | 3 +- .../providers/apache/kafka/hooks/test_base.py | 62 +++++++++++++++++++ 2 files changed, 63 insertions(+), 2 deletions(-) create mode 100644 tests/providers/apache/kafka/hooks/test_base.py diff --git a/airflow/providers/apache/kafka/hooks/base.py b/airflow/providers/apache/kafka/hooks/base.py index 2f99cb21eaa4d3..f45b773c278928 100644 --- a/airflow/providers/apache/kafka/hooks/base.py +++ b/airflow/providers/apache/kafka/hooks/base.py @@ -40,7 +40,6 @@ def __init__(self, kafka_config_id=default_conn_name, *args, **kwargs): """Initialize our Base.""" super().__init__() self.kafka_config_id = kafka_config_id - self.get_conn @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: @@ -74,6 +73,6 @@ def test_connection(self) -> tuple[bool, str]: if t: return True, "Connection successful." except Exception as e: - False, str(e) + return False, str(e) return False, "Failed to establish connection." diff --git a/tests/providers/apache/kafka/hooks/test_base.py b/tests/providers/apache/kafka/hooks/test_base.py new file mode 100644 index 00000000000000..75b17396188af2 --- /dev/null +++ b/tests/providers/apache/kafka/hooks/test_base.py @@ -0,0 +1,62 @@ +from unittest import mock +from unittest.mock import MagicMock + +import pytest + +from airflow.providers.apache.kafka.hooks.base import KafkaBaseHook + + +class SomeKafkaHook(KafkaBaseHook): + def _get_client(self, config): + return config + + +@pytest.fixture +def hook(): + return SomeKafkaHook() + + +TIMEOUT = 10 + + +class TestKafkaBaseHook: + + @mock.patch("airflow.hooks.base.BaseHook.get_connection") + def test_get_conn(self, mock_get_connection, hook): + config = {"bootstrap.servers": MagicMock()} + mock_get_connection.return_value.extra_dejson = config + assert hook.get_conn == config + + @mock.patch("airflow.hooks.base.BaseHook.get_connection") + def test_get_conn_value_error(self, mock_get_connection, hook): + mock_get_connection.return_value.extra_dejson = {} + with pytest.raises(ValueError, match="must be provided"): + hook.get_conn() + + @mock.patch("airflow.providers.apache.kafka.hooks.base.AdminClient") + @mock.patch("airflow.hooks.base.BaseHook.get_connection") + def test_test_connection(self, mock_get_connection, admin_client, hook): + config = {"bootstrap.servers": MagicMock()} + mock_get_connection.return_value.extra_dejson = config + connection = hook.test_connection() + admin_client.assert_called_once_with(config, timeout=10) + assert connection == (True, "Connection successful.") + + @mock.patch("airflow.providers.apache.kafka.hooks.base.AdminClient", + return_value=MagicMock(list_topics=MagicMock(return_value=[]))) + @mock.patch("airflow.hooks.base.BaseHook.get_connection") + def test_test_connection_no_topics(self, mock_get_connection, admin_client, hook): + config = {"bootstrap.servers": MagicMock()} + mock_get_connection.return_value.extra_dejson = config + connection = hook.test_connection() + admin_client.assert_called_once_with(config, timeout=TIMEOUT) + assert connection == (False, "Failed to establish connection.") + + @mock.patch("airflow.providers.apache.kafka.hooks.base.AdminClient") + @mock.patch("airflow.hooks.base.BaseHook.get_connection") + def test_test_connection_exception(self, mock_get_connection, admin_client, hook): + config = {"bootstrap.servers": MagicMock()} + mock_get_connection.return_value.extra_dejson = config + admin_client.return_value.list_topics.side_effect = [ValueError("some error")] + connection = hook.test_connection() + assert connection == (False, "some error")