From 8f31a8cff6530b9f5291aef5cd08d4ff3afd32e5 Mon Sep 17 00:00:00 2001 From: Daniel Standish Date: Sat, 26 Oct 2019 21:16:00 -0700 Subject: [PATCH] [AIRFLOW-5768] GCP cloud sql don't store ephemeral connection in db * add optional param `connection` to postgres and mysql hooks * instead of storing ephemeral connection to db, just pass directly to hook * remove obsoleted tests --- airflow/gcp/hooks/cloud_sql.py | 57 +--- airflow/gcp/operators/cloud_sql.py | 13 +- airflow/hooks/mysql_hook.py | 3 +- airflow/hooks/postgres_hook.py | 4 +- tests/gcp/hooks/test_cloud_sql.py | 359 ++++++++------------------ tests/gcp/operators/test_cloud_sql.py | 34 --- tests/hooks/test_mysql_hook.py | 18 ++ tests/hooks/test_postgres_hook.py | 18 ++ 8 files changed, 167 insertions(+), 339 deletions(-) diff --git a/airflow/gcp/hooks/cloud_sql.py b/airflow/gcp/hooks/cloud_sql.py index 74b689a4b0cfd..5c9a7b478e02a 100644 --- a/airflow/gcp/hooks/cloud_sql.py +++ b/airflow/gcp/hooks/cloud_sql.py @@ -934,57 +934,16 @@ def _get_sqlproxy_instance_specification(self) -> str: instance_specification += "=tcp:" + str(self.sql_proxy_tcp_port) return instance_specification - @provide_session - def create_connection(self, session: Optional[Session] = None) -> None: + def create_connection(self) -> Connection: """ - Create connection in the Connection table, according to whether it uses - proxy, TCP, UNIX sockets, SSL. Connection ID will be randomly generated. - - :param session: Session of the SQL Alchemy ORM (automatically generated with - decorator). + Create Connection object, according to whether it uses proxy, TCP, UNIX sockets, SSL. + Connection ID will be randomly generated. """ - assert session is not None connection = Connection(conn_id=self.db_conn_id) uri = self._generate_connection_uri() self.log.info("Creating connection %s", self.db_conn_id) connection.parse_from_uri(uri) - session.add(connection) - session.commit() - - @provide_session - def retrieve_connection(self, session: Optional[Session] = None) -> Optional[Connection]: - """ - Retrieves the dynamically created connection from the Connection table. - - :param session: Session of the SQL Alchemy ORM (automatically generated with - decorator). - """ - assert session is not None - self.log.info("Retrieving connection %s", self.db_conn_id) - connections = session.query(Connection).filter( - Connection.conn_id == self.db_conn_id) - if connections.count(): - return connections[0] - return None - - @provide_session - def delete_connection(self, session: Optional[Session] = None) -> None: - """ - Delete the dynamically created connection from the Connection table. - - :param session: Session of the SQL Alchemy ORM (automatically generated with - decorator). - """ - assert session is not None - self.log.info("Deleting connection %s", self.db_conn_id) - connections = session.query(Connection).filter( - Connection.conn_id == self.db_conn_id) - if connections.count(): - connection = connections[0] - session.delete(connection) - session.commit() - else: - self.log.info("Connection was already deleted!") + return connection def get_sqlproxy_runner(self) -> CloudSqlProxyRunner: """ @@ -1006,17 +965,15 @@ def get_sqlproxy_runner(self) -> CloudSqlProxyRunner: gcp_conn_id=self.gcp_conn_id ) - def get_database_hook(self) -> Union[PostgresHook, MySqlHook]: + def get_database_hook(self, connection: Connection) -> Union[PostgresHook, MySqlHook]: """ Retrieve database hook. This is the actual Postgres or MySQL database hook that uses proxy or connects directly to the Google Cloud SQL database. """ if self.database_type == 'postgres': - self.db_hook = PostgresHook(postgres_conn_id=self.db_conn_id, - schema=self.database) + self.db_hook = PostgresHook(connection=connection, schema=self.database) else: - self.db_hook = MySqlHook(mysql_conn_id=self.db_conn_id, - schema=self.database) + self.db_hook = MySqlHook(connection=connection, schema=self.database) return self.db_hook def cleanup_database_hook(self) -> None: diff --git a/airflow/gcp/operators/cloud_sql.py b/airflow/gcp/operators/cloud_sql.py index 6b9e6ab905a58..ac4130fc49462 100644 --- a/airflow/gcp/operators/cloud_sql.py +++ b/airflow/gcp/operators/cloud_sql.py @@ -835,13 +835,10 @@ def execute(self, context): 'extra__google_cloud_platform__project') ) hook.validate_ssl_certs() - hook.create_connection() + connection = hook.create_connection() + hook.validate_socket_path_length() + database_hook = hook.get_database_hook(connection=connection) try: - hook.validate_socket_path_length() - database_hook = hook.get_database_hook() - try: - self._execute_query(hook, database_hook) - finally: - hook.cleanup_database_hook() + self._execute_query(hook, database_hook) finally: - hook.delete_connection() + hook.cleanup_database_hook() diff --git a/airflow/hooks/mysql_hook.py b/airflow/hooks/mysql_hook.py index fe79ba1725053..b5ea04c4d2195 100644 --- a/airflow/hooks/mysql_hook.py +++ b/airflow/hooks/mysql_hook.py @@ -47,6 +47,7 @@ class MySqlHook(DbApiHook): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.schema = kwargs.pop("schema", None) + self.connection = kwargs.pop("connection", None) def set_autocommit(self, conn, autocommit): """ @@ -69,7 +70,7 @@ def get_conn(self): """ Returns a mysql connection object """ - conn = self.get_connection(self.mysql_conn_id) + conn = self.connection or self.get_connection(self.mysql_conn_id) conn_config = { "user": conn.login, diff --git a/airflow/hooks/postgres_hook.py b/airflow/hooks/postgres_hook.py index 624bba2dbedf7..1a9957578e6c7 100644 --- a/airflow/hooks/postgres_hook.py +++ b/airflow/hooks/postgres_hook.py @@ -56,6 +56,7 @@ class PostgresHook(DbApiHook): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.schema = kwargs.pop("schema", None) + self.connection = kwargs.pop("connection", None) def _get_cursor(self, raw_cursor): _cursor = raw_cursor.lower() @@ -68,8 +69,9 @@ def _get_cursor(self, raw_cursor): raise ValueError('Invalid cursor passed {}'.format(_cursor)) def get_conn(self): + conn_id = getattr(self, self.conn_name_attr) - conn = self.get_connection(conn_id) + conn = self.connection or self.get_connection(conn_id) # check for authentication via AWS IAM if conn.extra_dejson.get('iam', False): diff --git a/tests/gcp/hooks/test_cloud_sql.py b/tests/gcp/hooks/test_cloud_sql.py index 236dc37a31c13..49963faba6789 100644 --- a/tests/gcp/hooks/test_cloud_sql.py +++ b/tests/gcp/hooks/test_cloud_sql.py @@ -27,7 +27,6 @@ from airflow.exceptions import AirflowException from airflow.gcp.hooks.cloud_sql import CloudSqlDatabaseHook, CloudSqlHook -from airflow.hooks.base_hook import BaseHook from airflow.models import Connection from tests.compat import PropertyMock, mock from tests.gcp.utils.base_gcp_mock import ( @@ -1066,23 +1065,6 @@ def test_cloudsql_database_hook_create_connection_missing_fields(self, uri, get_ err = cm.exception self.assertIn("needs to be set in connection", str(err)) - @mock.patch('airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection') - def test_cloudsql_database_hook_create_delete_connection(self, get_connection): - connection = Connection() - connection.parse_from_uri("http://user:password@host:80/database") - connection.set_extra(json.dumps({ - "location": "test", - "instance": "instance", - "database_type": "postgres" - })) - get_connection.return_value = connection - hook = CloudSqlDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection', - default_gcp_project_id='google_connection') - hook.create_connection() - self.assertIsNotNone(hook.retrieve_connection()) - hook.delete_connection() - self.assertIsNone(hook.retrieve_connection()) - @mock.patch('airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection') def test_cloudsql_database_hook_get_sqlproxy_runner_no_proxy(self, get_connection): connection = Connection() @@ -1095,14 +1077,10 @@ def test_cloudsql_database_hook_get_sqlproxy_runner_no_proxy(self, get_connectio get_connection.return_value = connection hook = CloudSqlDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection') - hook.create_connection() - try: - with self.assertRaises(AirflowException) as cm: - hook.get_sqlproxy_runner() - err = cm.exception - self.assertIn('Proxy runner can only be retrieved in case of use_proxy = True', str(err)) - finally: - hook.delete_connection() + with self.assertRaises(AirflowException) as cm: + hook.get_sqlproxy_runner() + err = cm.exception + self.assertIn('Proxy runner can only be retrieved in case of use_proxy = True', str(err)) @mock.patch('airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection') def test_cloudsql_database_hook_get_sqlproxy_runner(self, get_connection): @@ -1119,11 +1097,8 @@ def test_cloudsql_database_hook_get_sqlproxy_runner(self, get_connection): hook = CloudSqlDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection') hook.create_connection() - try: - proxy_runner = hook.get_sqlproxy_runner() - self.assertIsNotNone(proxy_runner) - finally: - hook.delete_connection() + proxy_runner = hook.get_sqlproxy_runner() + self.assertIsNotNone(proxy_runner) @mock.patch('airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection') def test_cloudsql_database_hook_get_database_hook(self, get_connection): @@ -1137,28 +1112,13 @@ def test_cloudsql_database_hook_get_database_hook(self, get_connection): get_connection.return_value = connection hook = CloudSqlDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection') - hook.create_connection() - try: - db_hook = hook.get_database_hook() - self.assertIsNotNone(db_hook) - finally: - hook.delete_connection() + connection = hook.create_connection() + db_hook = hook.get_database_hook(connection=connection) + self.assertIsNotNone(db_hook) class TestCloudSqlDatabaseHook(unittest.TestCase): - @staticmethod - def _setup_connections(get_connections, uri): - gcp_connection = mock.MagicMock() - gcp_connection.extra_dejson = mock.MagicMock() - gcp_connection.extra_dejson.get.return_value = 'empty_project' - cloudsql_connection = Connection() - cloudsql_connection.parse_from_uri(uri) - cloudsql_connection2 = Connection() - cloudsql_connection2.parse_from_uri(uri) - get_connections.side_effect = [[gcp_connection], [cloudsql_connection], - [cloudsql_connection2]] - @mock.patch('airflow.contrib.hooks.gcp_sql_hook.CloudSqlDatabaseHook.get_connection') def setUp(self, m): super().setUp() @@ -1213,235 +1173,144 @@ def test_get_sqlproxy_runner(self): ) self.assertEqual(sqlproxy_runner.instance_specification, instance_spec) - @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections") - def test_hook_with_not_too_long_unix_socket_path(self, get_connections): + @mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection") + def test_hook_with_not_too_long_unix_socket_path(self, get_connection): uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \ "project_id=example-project&location=europe-west1&" \ "instance=" \ "test_db_with_longname_but_with_limit_of_UNIX_socket&" \ "use_proxy=True&sql_proxy_use_tcp=False" - self._setup_connections(get_connections, uri) - gcp_conn_id = 'google_cloud_default' - hook = CloudSqlDatabaseHook( - default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get( - 'extra__google_cloud_platform__project') - ) - hook.create_connection() - try: - db_hook = hook.get_database_hook() - conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0] # pylint: disable=no-member - finally: - hook.delete_connection() - self.assertEqual('postgres', conn.conn_type) - self.assertEqual('testdb', conn.schema) - - @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections") - def test_hook_with_correct_parameters_postgres(self, get_connections): + get_connection.side_effect = [Connection(uri=uri)] + hook = CloudSqlDatabaseHook() + connection = hook.create_connection() + self.assertEqual('postgres', connection.conn_type) + self.assertEqual('testdb', connection.schema) + + @mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection") + def test_hook_with_correct_parameters_postgres(self, get_connection): uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \ "project_id=example-project&location=europe-west1&instance=testdb&" \ "use_proxy=False&use_ssl=False" - self._setup_connections(get_connections, uri) - gcp_conn_id = 'google_cloud_default' - hook = CloudSqlDatabaseHook( - default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get( - 'extra__google_cloud_platform__project') - ) - hook.create_connection() - try: - db_hook = hook.get_database_hook() - conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0] # pylint: disable=no-member - finally: - hook.delete_connection() - self.assertEqual('postgres', conn.conn_type) - self.assertEqual('127.0.0.1', conn.host) - self.assertEqual(3200, conn.port) - self.assertEqual('testdb', conn.schema) - - @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections") - def test_hook_with_correct_parameters_postgres_ssl(self, get_connections): + get_connection.side_effect = [Connection(uri=uri)] + hook = CloudSqlDatabaseHook() + connection = hook.create_connection() + self.assertEqual('postgres', connection.conn_type) + self.assertEqual('127.0.0.1', connection.host) + self.assertEqual(3200, connection.port) + self.assertEqual('testdb', connection.schema) + + @mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection") + def test_hook_with_correct_parameters_postgres_ssl(self, get_connection): uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \ "project_id=example-project&location=europe-west1&instance=testdb&" \ "use_proxy=False&use_ssl=True&sslcert=/bin/bash&" \ "sslkey=/bin/bash&sslrootcert=/bin/bash" - self._setup_connections(get_connections, uri) - gcp_conn_id = 'google_cloud_default' - hook = CloudSqlDatabaseHook( - default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get( - 'extra__google_cloud_platform__project') - ) - hook.create_connection() - try: - db_hook = hook.get_database_hook() - conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0] # pylint: disable=no-member - finally: - hook.delete_connection() - self.assertEqual('postgres', conn.conn_type) - self.assertEqual('127.0.0.1', conn.host) - self.assertEqual(3200, conn.port) - self.assertEqual('testdb', conn.schema) - self.assertEqual('/bin/bash', conn.extra_dejson['sslkey']) - self.assertEqual('/bin/bash', conn.extra_dejson['sslcert']) - self.assertEqual('/bin/bash', conn.extra_dejson['sslrootcert']) - - @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections") - def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connections): + get_connection.side_effect = [Connection(uri=uri)] + hook = CloudSqlDatabaseHook() + connection = hook.create_connection() + self.assertEqual('postgres', connection.conn_type) + self.assertEqual('127.0.0.1', connection.host) + self.assertEqual(3200, connection.port) + self.assertEqual('testdb', connection.schema) + self.assertEqual('/bin/bash', connection.extra_dejson['sslkey']) + self.assertEqual('/bin/bash', connection.extra_dejson['sslcert']) + self.assertEqual('/bin/bash', connection.extra_dejson['sslrootcert']) + + @mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection") + def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connection): uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \ "project_id=example-project&location=europe-west1&instance=testdb&" \ "use_proxy=True&sql_proxy_use_tcp=False" - self._setup_connections(get_connections, uri) - gcp_conn_id = 'google_cloud_default' - hook = CloudSqlDatabaseHook( - default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get( - 'extra__google_cloud_platform__project') - ) - hook.create_connection() - try: - db_hook = hook.get_database_hook() - conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0] # pylint: disable=no-member - finally: - hook.delete_connection() - self.assertEqual('postgres', conn.conn_type) - self.assertIn('/tmp', conn.host) - self.assertIn('example-project:europe-west1:testdb', conn.host) - self.assertIsNone(conn.port) - self.assertEqual('testdb', conn.schema) - - @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections") - def test_hook_with_correct_parameters_project_id_missing(self, get_connections): + get_connection.side_effect = [Connection(uri=uri)] + hook = CloudSqlDatabaseHook() + connection = hook.create_connection() + self.assertEqual('postgres', connection.conn_type) + self.assertIn('/tmp', connection.host) + self.assertIn('example-project:europe-west1:testdb', connection.host) + self.assertIsNone(connection.port) + self.assertEqual('testdb', connection.schema) + + @mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection") + def test_hook_with_correct_parameters_project_id_missing(self, get_connection): uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \ "location=europe-west1&instance=testdb&" \ "use_proxy=False&use_ssl=False" - self._setup_connections(get_connections, uri) - gcp_conn_id = 'google_cloud_default' - hook = CloudSqlDatabaseHook( - default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get( - 'extra__google_cloud_platform__project') - ) - hook.create_connection() - try: - db_hook = hook.get_database_hook() - conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0] # pylint: disable=no-member - finally: - hook.delete_connection() - self.assertEqual('mysql', conn.conn_type) - self.assertEqual('127.0.0.1', conn.host) - self.assertEqual(3200, conn.port) - self.assertEqual('testdb', conn.schema) - - @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections") - def test_hook_with_correct_parameters_postgres_proxy_tcp(self, get_connections): + get_connection.side_effect = [Connection(uri=uri)] + hook = CloudSqlDatabaseHook() + connection = hook.create_connection() + self.assertEqual('mysql', connection.conn_type) + self.assertEqual('127.0.0.1', connection.host) + self.assertEqual(3200, connection.port) + self.assertEqual('testdb', connection.schema) + + @mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection") + def test_hook_with_correct_parameters_postgres_proxy_tcp(self, get_connection): uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \ "project_id=example-project&location=europe-west1&instance=testdb&" \ "use_proxy=True&sql_proxy_use_tcp=True" - self._setup_connections(get_connections, uri) - gcp_conn_id = 'google_cloud_default' - hook = CloudSqlDatabaseHook( - default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get( - 'extra__google_cloud_platform__project') - ) - hook.create_connection() - try: - db_hook = hook.get_database_hook() - conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0] # pylint: disable=no-member - finally: - hook.delete_connection() - self.assertEqual('postgres', conn.conn_type) - self.assertEqual('127.0.0.1', conn.host) - self.assertNotEqual(3200, conn.port) - self.assertEqual('testdb', conn.schema) - - @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections") - def test_hook_with_correct_parameters_mysql(self, get_connections): + get_connection.side_effect = [Connection(uri=uri)] + hook = CloudSqlDatabaseHook() + connection = hook.create_connection() + self.assertEqual('postgres', connection.conn_type) + self.assertEqual('127.0.0.1', connection.host) + self.assertNotEqual(3200, connection.port) + self.assertEqual('testdb', connection.schema) + + @mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection") + def test_hook_with_correct_parameters_mysql(self, get_connection): uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \ "project_id=example-project&location=europe-west1&instance=testdb&" \ "use_proxy=False&use_ssl=False" - self._setup_connections(get_connections, uri) - gcp_conn_id = 'google_cloud_default' - hook = CloudSqlDatabaseHook( - default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get( - 'extra__google_cloud_platform__project') - ) - hook.create_connection() - try: - db_hook = hook.get_database_hook() - conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0] # pylint: disable=no-member - finally: - hook.delete_connection() - self.assertEqual('mysql', conn.conn_type) - self.assertEqual('127.0.0.1', conn.host) - self.assertEqual(3200, conn.port) - self.assertEqual('testdb', conn.schema) - - @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections") - def test_hook_with_correct_parameters_mysql_ssl(self, get_connections): + get_connection.side_effect = [Connection(uri=uri)] + hook = CloudSqlDatabaseHook() + connection = hook.create_connection() + self.assertEqual('mysql', connection.conn_type) + self.assertEqual('127.0.0.1', connection.host) + self.assertEqual(3200, connection.port) + self.assertEqual('testdb', connection.schema) + + @mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection") + def test_hook_with_correct_parameters_mysql_ssl(self, get_connection): uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \ "project_id=example-project&location=europe-west1&instance=testdb&" \ "use_proxy=False&use_ssl=True&sslcert=/bin/bash&" \ "sslkey=/bin/bash&sslrootcert=/bin/bash" - self._setup_connections(get_connections, uri) - - gcp_conn_id = 'google_cloud_default' - hook = CloudSqlDatabaseHook( - default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get( - 'extra__google_cloud_platform__project') - ) - hook.create_connection() - try: - db_hook = hook.get_database_hook() - conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0] # pylint: disable=no-member - finally: - hook.delete_connection() - self.assertEqual('mysql', conn.conn_type) - self.assertEqual('127.0.0.1', conn.host) - self.assertEqual(3200, conn.port) - self.assertEqual('testdb', conn.schema) - self.assertEqual('/bin/bash', json.loads(conn.extra_dejson['ssl'])['cert']) - self.assertEqual('/bin/bash', json.loads(conn.extra_dejson['ssl'])['key']) - self.assertEqual('/bin/bash', json.loads(conn.extra_dejson['ssl'])['ca']) - - @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections") - def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connections): + get_connection.side_effect = [Connection(uri=uri)] + hook = CloudSqlDatabaseHook() + connection = hook.create_connection() + self.assertEqual('mysql', connection.conn_type) + self.assertEqual('127.0.0.1', connection.host) + self.assertEqual(3200, connection.port) + self.assertEqual('testdb', connection.schema) + self.assertEqual('/bin/bash', json.loads(connection.extra_dejson['ssl'])['cert']) + self.assertEqual('/bin/bash', json.loads(connection.extra_dejson['ssl'])['key']) + self.assertEqual('/bin/bash', json.loads(connection.extra_dejson['ssl'])['ca']) + + @mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection") + def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connection): uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \ "project_id=example-project&location=europe-west1&instance=testdb&" \ "use_proxy=True&sql_proxy_use_tcp=False" - self._setup_connections(get_connections, uri) - gcp_conn_id = 'google_cloud_default' - hook = CloudSqlDatabaseHook( - default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get( - 'extra__google_cloud_platform__project') - ) - hook.create_connection() - try: - db_hook = hook.get_database_hook() - conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0] # pylint: disable=no-member - finally: - hook.delete_connection() - self.assertEqual('mysql', conn.conn_type) - self.assertEqual('localhost', conn.host) - self.assertIn('/tmp', conn.extra_dejson['unix_socket']) + get_connection.side_effect = [Connection(uri=uri)] + hook = CloudSqlDatabaseHook() + connection = hook.create_connection() + self.assertEqual('mysql', connection.conn_type) + self.assertEqual('localhost', connection.host) + self.assertIn('/tmp', connection.extra_dejson['unix_socket']) self.assertIn('example-project:europe-west1:testdb', - conn.extra_dejson['unix_socket']) - self.assertIsNone(conn.port) - self.assertEqual('testdb', conn.schema) + connection.extra_dejson['unix_socket']) + self.assertIsNone(connection.port) + self.assertEqual('testdb', connection.schema) - @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections") - def test_hook_with_correct_parameters_mysql_tcp(self, get_connections): + @mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection") + def test_hook_with_correct_parameters_mysql_tcp(self, get_connection): uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \ "project_id=example-project&location=europe-west1&instance=testdb&" \ "use_proxy=True&sql_proxy_use_tcp=True" - self._setup_connections(get_connections, uri) - gcp_conn_id = 'google_cloud_default' - hook = CloudSqlDatabaseHook( - default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get( - 'extra__google_cloud_platform__project') - ) - hook.create_connection() - try: - db_hook = hook.get_database_hook() - conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0] # pylint: disable=no-member - finally: - hook.delete_connection() - self.assertEqual('mysql', conn.conn_type) - self.assertEqual('127.0.0.1', conn.host) - self.assertNotEqual(3200, conn.port) - self.assertEqual('testdb', conn.schema) + get_connection.side_effect = [Connection(uri=uri)] + hook = CloudSqlDatabaseHook() + connection = hook.create_connection() + self.assertEqual('mysql', connection.conn_type) + self.assertEqual('127.0.0.1', connection.host) + self.assertNotEqual(3200, connection.port) + self.assertEqual('testdb', connection.schema) diff --git a/tests/gcp/operators/test_cloud_sql.py b/tests/gcp/operators/test_cloud_sql.py index 52ab9bf4fac59..a81822f801f75 100644 --- a/tests/gcp/operators/test_cloud_sql.py +++ b/tests/gcp/operators/test_cloud_sql.py @@ -19,7 +19,6 @@ # pylint: disable=too-many-lines -import json import os import unittest @@ -787,36 +786,3 @@ def test_create_operator_with_too_long_unix_socket_path(self, get_connections): operator.execute(None) err = cm.exception self.assertIn("The UNIX socket path length cannot exceed", str(err)) - - @mock.patch("airflow.contrib.hooks.gcp_sql_hook.CloudSqlDatabaseHook." - "delete_connection") - @mock.patch("airflow.contrib.hooks.gcp_sql_hook.CloudSqlDatabaseHook." - "get_connection") - @mock.patch("airflow.hooks.mysql_hook.MySqlHook.run") - @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections") - def test_cloudsql_hook_delete_connection_on_exception( - self, get_connections, run, get_connection, delete_connection): - connection = Connection() - connection.parse_from_uri( - "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" - "project_id=example-project&location=europe-west1&instance=testdb&" - "use_proxy=False") - get_connection.return_value = connection - - db_connection = Connection() - db_connection.host = "127.0.0.1" - db_connection.set_extra(json.dumps({"project_id": "example-project", - "location": "europe-west1", - "instance": "testdb", - "database_type": "mysql"})) - get_connections.return_value = [db_connection] - run.side_effect = Exception("Exception when running a query") - operator = CloudSqlQueryOperator( - sql=['SELECT * FROM TABLE'], - task_id='task_id' - ) - with self.assertRaises(Exception) as cm: - operator.execute(None) - err = cm.exception - self.assertEqual("Exception when running a query", str(err)) - delete_connection.assert_called_once_with() diff --git a/tests/hooks/test_mysql_hook.py b/tests/hooks/test_mysql_hook.py index 0af32910f0827..42ef3f7754728 100644 --- a/tests/hooks/test_mysql_hook.py +++ b/tests/hooks/test_mysql_hook.py @@ -61,6 +61,24 @@ def test_get_conn(self, mock_connect): self.assertEqual(kwargs['host'], 'host') self.assertEqual(kwargs['db'], 'schema') + @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect') + def test_get_conn_from_connection(self, mock_connect): + conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema') + hook = MySqlHook(connection=conn) + hook.get_conn() + mock_connect.assert_called_once_with( + user='login-conn', passwd='password-conn', host='host', db='schema', port=3306 + ) + + @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect') + def test_get_conn_from_connection_with_schema(self, mock_connect): + conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema') + hook = MySqlHook(connection=conn, schema='schema-override') + hook.get_conn() + mock_connect.assert_called_once_with( + user='login-conn', passwd='password-conn', host='host', db='schema-override', port=3306 + ) + @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect') def test_get_conn_port(self, mock_connect): self.connection.port = 3307 diff --git a/tests/hooks/test_postgres_hook.py b/tests/hooks/test_postgres_hook.py index 1ad592c490372..ea0120bbb1d15 100644 --- a/tests/hooks/test_postgres_hook.py +++ b/tests/hooks/test_postgres_hook.py @@ -76,6 +76,24 @@ def test_get_conn_with_invalid_cursor(self, mock_connect): with self.assertRaises(ValueError): self.db_hook.get_conn() + @mock.patch('airflow.hooks.postgres_hook.psycopg2.connect') + def test_get_conn_from_connection(self, mock_connect): + conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema') + hook = PostgresHook(connection=conn) + hook.get_conn() + mock_connect.assert_called_once_with( + user='login-conn', password='password-conn', host='host', dbname='schema', port=None + ) + + @mock.patch('airflow.hooks.postgres_hook.psycopg2.connect') + def test_get_conn_from_connection_with_schema(self, mock_connect): + conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema') + hook = PostgresHook(connection=conn, schema='schema-override') + hook.get_conn() + mock_connect.assert_called_once_with( + user='login-conn', password='password-conn', host='host', dbname='schema-override', port=None + ) + @mock.patch('airflow.hooks.postgres_hook.psycopg2.connect') @mock.patch('airflow.contrib.hooks.aws_hook.AwsHook.get_client_type') def test_get_conn_rds_iam_postgres(self, mock_client, mock_connect):