Skip to content

Commit

Permalink
Add unit tests to Kafka base hook
Browse files Browse the repository at this point in the history
  • Loading branch information
shahar1 committed Jun 8, 2024
1 parent 1a61eb3 commit 7b08147
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
3 changes: 1 addition & 2 deletions airflow/providers/apache/kafka/hooks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def __init__(self, kafka_config_id=default_conn_name, *args, **kwargs):
"""Initialize our Base."""
super().__init__()
self.kafka_config_id = kafka_config_id
self.get_conn

@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
Expand Down Expand Up @@ -74,6 +73,6 @@ def test_connection(self) -> tuple[bool, str]:
if t:
return True, "Connection successful."
except Exception as e:
False, str(e)
return False, str(e)

return False, "Failed to establish connection."
62 changes: 62 additions & 0 deletions tests/providers/apache/kafka/hooks/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from unittest import mock
from unittest.mock import MagicMock

import pytest

from airflow.providers.apache.kafka.hooks.base import KafkaBaseHook


class SomeKafkaHook(KafkaBaseHook):
def _get_client(self, config):
return config


@pytest.fixture
def hook():
return SomeKafkaHook()


TIMEOUT = 10


class TestKafkaBaseHook:

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_get_conn(self, mock_get_connection, hook):
config = {"bootstrap.servers": MagicMock()}
mock_get_connection.return_value.extra_dejson = config
assert hook.get_conn == config

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_get_conn_value_error(self, mock_get_connection, hook):
mock_get_connection.return_value.extra_dejson = {}
with pytest.raises(ValueError, match="must be provided"):
hook.get_conn()

@mock.patch("airflow.providers.apache.kafka.hooks.base.AdminClient")
@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_test_connection(self, mock_get_connection, admin_client, hook):
config = {"bootstrap.servers": MagicMock()}
mock_get_connection.return_value.extra_dejson = config
connection = hook.test_connection()
admin_client.assert_called_once_with(config, timeout=10)
assert connection == (True, "Connection successful.")

@mock.patch("airflow.providers.apache.kafka.hooks.base.AdminClient",
return_value=MagicMock(list_topics=MagicMock(return_value=[])))
@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_test_connection_no_topics(self, mock_get_connection, admin_client, hook):
config = {"bootstrap.servers": MagicMock()}
mock_get_connection.return_value.extra_dejson = config
connection = hook.test_connection()
admin_client.assert_called_once_with(config, timeout=TIMEOUT)
assert connection == (False, "Failed to establish connection.")

@mock.patch("airflow.providers.apache.kafka.hooks.base.AdminClient")
@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_test_connection_exception(self, mock_get_connection, admin_client, hook):
config = {"bootstrap.servers": MagicMock()}
mock_get_connection.return_value.extra_dejson = config
admin_client.return_value.list_topics.side_effect = [ValueError("some error")]
connection = hook.test_connection()
assert connection == (False, "some error")

0 comments on commit 7b08147

Please sign in to comment.