diff --git a/airflow/providers/apache/kafka/hooks/base.py b/airflow/providers/apache/kafka/hooks/base.py index 2f99cb21eaa4d3..f45b773c278928 100644 --- a/airflow/providers/apache/kafka/hooks/base.py +++ b/airflow/providers/apache/kafka/hooks/base.py @@ -40,7 +40,6 @@ def __init__(self, kafka_config_id=default_conn_name, *args, **kwargs): """Initialize our Base.""" super().__init__() self.kafka_config_id = kafka_config_id - self.get_conn @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: @@ -74,6 +73,6 @@ def test_connection(self) -> tuple[bool, str]: if t: return True, "Connection successful." except Exception as e: - False, str(e) + return False, str(e) return False, "Failed to establish connection." diff --git a/tests/providers/apache/kafka/hooks/test_base.py b/tests/providers/apache/kafka/hooks/test_base.py new file mode 100644 index 00000000000000..c1ca9544b8f9b3 --- /dev/null +++ b/tests/providers/apache/kafka/hooks/test_base.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock +from unittest.mock import MagicMock + +import pytest + +from airflow.providers.apache.kafka.hooks.base import KafkaBaseHook + + +class SomeKafkaHook(KafkaBaseHook): + def _get_client(self, config): + return config + + +@pytest.fixture +def hook(): + return SomeKafkaHook() + + +TIMEOUT = 10 + + +class TestKafkaBaseHook: + @mock.patch("airflow.hooks.base.BaseHook.get_connection") + def test_get_conn(self, mock_get_connection, hook): + config = {"bootstrap.servers": MagicMock()} + mock_get_connection.return_value.extra_dejson = config + assert hook.get_conn == config + + @mock.patch("airflow.hooks.base.BaseHook.get_connection") + def test_get_conn_value_error(self, mock_get_connection, hook): + mock_get_connection.return_value.extra_dejson = {} + with pytest.raises(ValueError, match="must be provided"): + hook.get_conn() + + @mock.patch("airflow.providers.apache.kafka.hooks.base.AdminClient") + @mock.patch("airflow.hooks.base.BaseHook.get_connection") + def test_test_connection(self, mock_get_connection, admin_client, hook): + config = {"bootstrap.servers": MagicMock()} + mock_get_connection.return_value.extra_dejson = config + connection = hook.test_connection() + admin_client.assert_called_once_with(config, timeout=10) + assert connection == (True, "Connection successful.") + + @mock.patch( + "airflow.providers.apache.kafka.hooks.base.AdminClient", + return_value=MagicMock(list_topics=MagicMock(return_value=[])), + ) + @mock.patch("airflow.hooks.base.BaseHook.get_connection") + def test_test_connection_no_topics(self, mock_get_connection, admin_client, hook): + config = {"bootstrap.servers": MagicMock()} + mock_get_connection.return_value.extra_dejson = config + connection = hook.test_connection() + admin_client.assert_called_once_with(config, timeout=TIMEOUT) + assert connection == (False, "Failed to establish connection.") + + @mock.patch("airflow.providers.apache.kafka.hooks.base.AdminClient") + @mock.patch("airflow.hooks.base.BaseHook.get_connection") + def test_test_connection_exception(self, mock_get_connection, admin_client, hook): + config = {"bootstrap.servers": MagicMock()} + mock_get_connection.return_value.extra_dejson = config + admin_client.return_value.list_topics.side_effect = [ValueError("some error")] + connection = hook.test_connection() + assert connection == (False, "some error") diff --git a/tests/providers/apache/kafka/hooks/test_client.py b/tests/providers/apache/kafka/hooks/test_client.py index 16ffa5ac4d354f..3a97b9e3929c12 100644 --- a/tests/providers/apache/kafka/hooks/test_client.py +++ b/tests/providers/apache/kafka/hooks/test_client.py @@ -18,9 +18,11 @@ import json import logging +from unittest.mock import MagicMock, patch import pytest -from confluent_kafka.admin import AdminClient +from confluent_kafka import KafkaException +from confluent_kafka.admin import AdminClient, NewTopic from airflow.models import Connection from airflow.providers.apache.kafka.hooks.client import KafkaAdminClientHook @@ -31,11 +33,7 @@ log = logging.getLogger(__name__) -class TestSampleHook: - """ - Test Admin Client Hook. - """ - +class TestKafkaAdminClientHook: def setup_method(self): db.merge_conn( Connection( @@ -54,23 +52,48 @@ def setup_method(self): extra=json.dumps({"socket.timeout.ms": 10}), ) ) - - def test_init(self): - """test initialization of AdminClientHook""" - - # Standard Init - KafkaAdminClientHook(kafka_config_id="kafka_d") - - # # Not Enough Args - with pytest.raises(ValueError): - KafkaAdminClientHook(kafka_config_id="kafka_bad") + self.hook = KafkaAdminClientHook(kafka_config_id="kafka_d") def test_get_conn(self): - """test get_conn""" - - # Standard Init - k = KafkaAdminClientHook(kafka_config_id="kafka_d") - - c = k.get_conn - - assert isinstance(c, AdminClient) + assert isinstance(self.hook.get_conn, AdminClient) + + @patch( + "airflow.providers.apache.kafka.hooks.client.AdminClient", + ) + def test_create_topic(self, admin_client): + mock_f = MagicMock() + admin_client.return_value.create_topics.return_value = {"topic_name": mock_f} + self.hook.create_topic(topics=[("topic_name", 0, 1)]) + admin_client.return_value.create_topics.assert_called_with([NewTopic("topic_name", 0, 1)]) + mock_f.result.assert_called_once() + + @patch( + "airflow.providers.apache.kafka.hooks.client.AdminClient", + ) + def test_create_topic_error(self, admin_client): + mock_f = MagicMock() + kafka_exception = KafkaException() + mock_arg = MagicMock() + # mock_arg.name = "TOPIC_ALREADY_EXISTS" + kafka_exception.args = [mock_arg] + mock_f.result.side_effect = [kafka_exception] + admin_client.return_value.create_topics.return_value = {"topic_name": mock_f} + with pytest.raises(KafkaException): + self.hook.create_topic(topics=[("topic_name", 0, 1)]) + + @patch( + "airflow.providers.apache.kafka.hooks.client.AdminClient", + ) + def test_create_topic_warning(self, admin_client, caplog): + mock_f = MagicMock() + kafka_exception = KafkaException() + mock_arg = MagicMock() + mock_arg.name = "TOPIC_ALREADY_EXISTS" + kafka_exception.args = [mock_arg] + mock_f.result.side_effect = [kafka_exception] + admin_client.return_value.create_topics.return_value = {"topic_name": mock_f} + with caplog.at_level( + logging.WARNING, logger="airflow.providers.apache.kafka.hooks.client.KafkaAdminClientHook" + ): + self.hook.create_topic(topics=[("topic_name", 0, 1)]) + assert "The topic topic_name already exists" in caplog.text diff --git a/tests/providers/apache/kafka/hooks/test_consume.py b/tests/providers/apache/kafka/hooks/test_consume.py index 852d7374489c60..5d649845103bb3 100644 --- a/tests/providers/apache/kafka/hooks/test_consume.py +++ b/tests/providers/apache/kafka/hooks/test_consume.py @@ -52,13 +52,8 @@ def setup_method(self): extra=json.dumps({}), ) ) + self.hook = KafkaConsumerHook(["test_1"], kafka_config_id="kafka_d") - def test_init(self): - """test initialization of AdminClientHook""" + def test_get_consumer(self): + assert self.hook.get_consumer() == self.hook.get_conn - # Standard Init - KafkaConsumerHook(["test_1"], kafka_config_id="kafka_d") - - # Not Enough Args - with pytest.raises(ValueError): - KafkaConsumerHook(["test_1"], kafka_config_id="kafka_bad") diff --git a/tests/providers/apache/kafka/hooks/test_produce.py b/tests/providers/apache/kafka/hooks/test_produce.py index 0f5ed0e1865eff..3bcdd010ca050e 100644 --- a/tests/providers/apache/kafka/hooks/test_produce.py +++ b/tests/providers/apache/kafka/hooks/test_produce.py @@ -54,13 +54,7 @@ def setup_method(self): extra=json.dumps({}), ) ) + self.hook = KafkaProducerHook(kafka_config_id="kafka_d") - def test_init(self): - """test initialization of AdminClientHook""" - - # Standard Init - KafkaProducerHook(kafka_config_id="kafka_d") - - # Not Enough Args - with pytest.raises(ValueError): - KafkaProducerHook(kafka_config_id="kafka_bad") + def test_get_producer(self): + assert self.hook.get_producer() == self.hook.get_conn