diff --git a/tests/providers/apache/cassandra/hooks/test_cassandra_hook.py b/tests/providers/apache/cassandra/hooks/test_cassandra_hook.py index 2d6be1afbffb28..38432184b9b76b 100644 --- a/tests/providers/apache/cassandra/hooks/test_cassandra_hook.py +++ b/tests/providers/apache/cassandra/hooks/test_cassandra_hook.py @@ -20,18 +20,25 @@ import unittest import mock -from cassandra.cluster import Cluster +from cassandra.cluster import Cluster, UnresolvableContactPoints from cassandra.policies import ( DCAwareRoundRobinPolicy, RoundRobinPolicy, TokenAwarePolicy, WhiteListRoundRobinPolicy, ) -from flaky import flaky from airflow.models import Connection from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook from airflow.utils import db -@flaky(max_runs=4, min_passes=1) +def cassandra_is_not_up(): + try: + Cluster(["cassandra"]) + return False + except UnresolvableContactPoints: + return True + + +@unittest.skipIf(cassandra_is_not_up(), "Cassandra is not up.") class TestCassandraHook(unittest.TestCase): def setUp(self): db.merge_conn( @@ -73,13 +80,14 @@ def test_get_conn(self): self.assertEqual(cluster.port, 9042) self.assertTrue(isinstance(cluster.load_balancing_policy, TokenAwarePolicy)) - def test_get_lb_policy(self): + def test_get_lb_policy_with_no_args(self): # test LB policies with no args self._assert_get_lb_policy('RoundRobinPolicy', {}, RoundRobinPolicy) self._assert_get_lb_policy('DCAwareRoundRobinPolicy', {}, DCAwareRoundRobinPolicy) self._assert_get_lb_policy('TokenAwarePolicy', {}, TokenAwarePolicy, expected_child_policy_type=RoundRobinPolicy) + def test_get_lb_policy_with_args(self): # test DCAwareRoundRobinPolicy with args self._assert_get_lb_policy('DCAwareRoundRobinPolicy', {'local_dc': 'foo', 'used_hosts_per_remote_dc': '3'}, @@ -101,6 +109,7 @@ def test_get_lb_policy(self): 'child_load_balancing_policy_args': {'hosts': ['host-1', 'host-2']} }, TokenAwarePolicy, expected_child_policy_type=WhiteListRoundRobinPolicy) + def test_get_lb_policy_invalid_policy(self): # test invalid policy name should default to RoundRobinPolicy self._assert_get_lb_policy('DoesNotExistPolicy', {}, RoundRobinPolicy) @@ -112,6 +121,7 @@ def test_get_lb_policy(self): TokenAwarePolicy, expected_child_policy_type=RoundRobinPolicy) + def test_get_lb_policy_no_host_for_white_list(self): # test host not specified for WhiteListRoundRobinPolicy should throw exception self._assert_get_lb_policy('WhiteListRoundRobinPolicy', {}, diff --git a/tests/providers/apache/cassandra/sensors/test_cassandra_sensor.py b/tests/providers/apache/cassandra/sensors/test_cassandra_sensor.py index 9431eacae510a2..d5a6656ef5f499 100644 --- a/tests/providers/apache/cassandra/sensors/test_cassandra_sensor.py +++ b/tests/providers/apache/cassandra/sensors/test_cassandra_sensor.py @@ -17,63 +17,36 @@ # specific language governing permissions and limitations # under the License. - import unittest from unittest.mock import patch -from flaky import flaky - -from airflow import DAG from airflow.providers.apache.cassandra.sensors.record import CassandraRecordSensor from airflow.providers.apache.cassandra.sensors.table import CassandraTableSensor -from airflow.utils import timezone -DEFAULT_DATE = timezone.datetime(2017, 1, 1) - -@flaky(max_runs=4, min_passes=1) class TestCassandraRecordSensor(unittest.TestCase): - - def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } - self.dag = DAG('test_dag_id', default_args=args) - self.sensor = CassandraRecordSensor( + @patch("airflow.providers.apache.cassandra.sensors.record.CassandraHook") + def test_poke(self, mock_hook): + sensor = CassandraRecordSensor( task_id='test_task', cassandra_conn_id='cassandra_default', - dag=self.dag, table='t', keys={'foo': 'bar'} ) + sensor.poke(None) + mock_hook.return_value.record_exists.assert_called_once_with('t', {'foo': 'bar'}) - @patch("airflow.contrib.hooks.cassandra_hook.CassandraHook.record_exists") - def test_poke(self, mock_record_exists): - self.sensor.poke(None) - mock_record_exists.assert_called_once_with('t', {'foo': 'bar'}) - -@flaky(max_runs=4, min_passes=1) class TestCassandraTableSensor(unittest.TestCase): - - def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } - self.dag = DAG('test_dag_id', default_args=args) - self.sensor = CassandraTableSensor( + @patch("airflow.providers.apache.cassandra.sensors.table.CassandraHook") + def test_poke(self, mock_hook): + sensor = CassandraTableSensor( task_id='test_task', cassandra_conn_id='cassandra_default', - dag=self.dag, table='t', ) - - @patch("airflow.contrib.hooks.cassandra_hook.CassandraHook.table_exists") - def test_poke(self, mock_table_exists): - self.sensor.poke(None) - mock_table_exists.assert_called_once_with('t') + sensor.poke(None) + mock_hook.return_value.table_exists.assert_called_once_with('t') if __name__ == '__main__':