diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index d3ec81261..437e97134 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -3,9 +3,11 @@ from __future__ import annotations import logging +import warnings +from contextlib import contextmanager from datetime import datetime from functools import lru_cache -from typing import Any, Iterable, cast +from typing import Any, Iterable, Iterator, cast import sqlalchemy from sqlalchemy.engine import Engine @@ -34,6 +36,7 @@ class SQLConnector: allow_column_alter: bool = False # Whether altering column types is supported. allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported. allow_temp_tables: bool = True # Whether temp tables are supported. + _cached_engine: Engine | None = None def __init__( self, config: dict | None = None, sqlalchemy_url: str | None = None @@ -46,7 +49,6 @@ def __init__( """ self._config: dict[str, Any] = config or {} self._sqlalchemy_url: str | None = sqlalchemy_url or None - self._connection: sqlalchemy.engine.Connection | None = None @property def config(self) -> dict: @@ -66,8 +68,17 @@ def logger(self) -> logging.Logger: """ return logging.getLogger("sqlconnector") + @contextmanager + def _connect(self) -> Iterator[sqlalchemy.engine.Connection]: + with self._engine.connect().execution_options(stream_results=True) as conn: + yield conn + def create_sqlalchemy_connection(self) -> sqlalchemy.engine.Connection: - """Return a new SQLAlchemy connection using the provided config. + """(DEPRECATED) Return a new SQLAlchemy connection using the provided config. + + Do not use the SQLConnector's connection directly. Instead, if you need + to execute something that isn't available on the connector currently, + make a child class and add a method on that connector. By default this will create using the sqlalchemy `stream_results=True` option described here: @@ -81,14 +92,17 @@ def create_sqlalchemy_connection(self) -> sqlalchemy.engine.Connection: Returns: A newly created SQLAlchemy engine object. """ - return ( - self.create_sqlalchemy_engine() - .connect() - .execution_options(stream_results=True) + warnings.warn( + "`SQLConnector.create_sqlalchemy_connection` is deprecated. " + "If you need to execute something that isn't available " + "on the connector currently, make a child class and " + "add your required method on that connector.", + DeprecationWarning, ) + return self._engine.connect().execution_options(stream_results=True) - def create_sqlalchemy_engine(self) -> sqlalchemy.engine.Engine: - """Return a new SQLAlchemy engine using the provided config. + def create_sqlalchemy_engine(self) -> Engine: + """(DEPRECATED) Return a new SQLAlchemy engine using the provided config. Developers can generally override just one of the following: `sqlalchemy_engine`, sqlalchemy_url`. @@ -96,19 +110,31 @@ def create_sqlalchemy_engine(self) -> sqlalchemy.engine.Engine: Returns: A newly created SQLAlchemy engine object. """ - return sqlalchemy.create_engine(self.sqlalchemy_url, echo=False) + warnings.warn( + "`SQLConnector.create_sqlalchemy_engine` is deprecated. Override" + "`_engine` or sqlalchemy_url` instead.", + DeprecationWarning, + ) + return self._engine @property def connection(self) -> sqlalchemy.engine.Connection: - """Return or set the SQLAlchemy connection object. + """(DEPRECATED) Return or set the SQLAlchemy connection object. + + Do not use the SQLConnector's connection directly. Instead, if you need + to execute something that isn't available on the connector currently, + make a child class and add a method on that connector. Returns: The active SQLAlchemy connection object. """ - if not self._connection: - self._connection = self.create_sqlalchemy_connection() - - return self._connection + warnings.warn( + "`SQLConnector.connection` is deprecated. If you need to execute something " + "that isn't available on the connector currently, make a child " + "class and add your required method on that connector.", + DeprecationWarning, + ) + return self.create_sqlalchemy_connection() @property def sqlalchemy_url(self) -> str: @@ -249,16 +275,37 @@ def _dialect(self) -> sqlalchemy.engine.Dialect: Returns: The dialect object. """ - return cast(sqlalchemy.engine.Dialect, self.connection.engine.dialect) + return cast(sqlalchemy.engine.Dialect, self._engine.dialect) @property - def _engine(self) -> sqlalchemy.engine.Engine: - """Return the dialect object. + def _engine(self) -> Engine: + """Return the engine object. + + This is the correct way to access the Connector's engine, if needed + (e.g. to inspect tables). Returns: - The dialect object. + The SQLAlchemy Engine that's attached to this SQLConnector instance. + """ + if not self._cached_engine: + self._cached_engine = self.create_engine() + return cast(Engine, self._cached_engine) + + def create_engine(self) -> Engine: + """Creates and returns a new engine. Do not call outside of _engine. + + NOTE: Do not call this method. The only place that this method should + be called is inside the self._engine method. If you'd like to access + the engine on a connector, use self._engine. + + This method exists solely so that tap/target developers can override it + on their subclass of SQLConnector to perform custom engine creation + logic. + + Returns: + A new SQLAlchemy Engine. """ - return cast(sqlalchemy.engine.Engine, self.connection.engine) + return sqlalchemy.create_engine(self.sqlalchemy_url, echo=False) def quote(self, name: str) -> str: """Quote a name if it needs quoting, using '.' as a name-part delimiter. @@ -421,7 +468,7 @@ def discover_catalog_entries(self) -> list[dict]: The discovered catalog entries as a list. """ result: list[dict] = [] - engine = self.create_sqlalchemy_engine() + engine = self._engine inspected = sqlalchemy.inspect(engine) for schema_name in self.get_schema_names(engine, inspected): # Iterate through each table and view @@ -562,7 +609,8 @@ def create_schema(self, schema_name: str) -> None: Args: schema_name: The target schema to create. """ - self._engine.execute(sqlalchemy.schema.CreateSchema(schema_name)) + with self._connect() as conn: + conn.execute(sqlalchemy.schema.CreateSchema(schema_name)) def create_empty_table( self, @@ -635,7 +683,8 @@ def _create_empty_column( column_add_ddl = self.get_column_add_ddl( table_name=full_table_name, column_name=column_name, column_type=sql_type ) - self.connection.execute(column_add_ddl) + with self._connect() as conn: + conn.execute(column_add_ddl) def prepare_schema(self, schema_name: str) -> None: """Create the target database schema. @@ -723,7 +772,8 @@ def rename_column(self, full_table_name: str, old_name: str, new_name: str) -> N column_rename_ddl = self.get_column_rename_ddl( table_name=full_table_name, column_name=old_name, new_column_name=new_name ) - self.connection.execute(column_rename_ddl) + with self._connect() as conn: + conn.execute(column_rename_ddl) def merge_sql_types( self, sql_types: list[sqlalchemy.types.TypeEngine] @@ -1027,4 +1077,5 @@ def _adapt_column_type( column_name=column_name, column_type=compatible_sql_type, ) - self.connection.execute(alter_column_ddl) + with self._connect() as conn: + conn.execute(alter_column_ddl) diff --git a/tests/core/test_connector_sql.py b/tests/core/test_connector_sql.py index 8ad7bf9a1..89a7d46a2 100644 --- a/tests/core/test_connector_sql.py +++ b/tests/core/test_connector_sql.py @@ -1,8 +1,11 @@ +from unittest import mock + import pytest import sqlalchemy from sqlalchemy.dialects import sqlite from singer_sdk.connectors import SQLConnector +from singer_sdk.exceptions import ConfigValidationError def stringify(in_dict): @@ -14,7 +17,7 @@ class TestConnectorSQL: @pytest.fixture() def connector(self): - return SQLConnector() + return SQLConnector(config={"sqlalchemy_url": "sqlite:///"}) @pytest.mark.parametrize( "method_name,kwargs,context,unrendered_statement,rendered_statement", @@ -130,3 +133,59 @@ def test_update_collation_non_text_type(self): assert not hasattr(compatible_type, "collation") # Check that we get the same type we put in assert str(compatible_type) == "INTEGER" + + def test_create_engine_returns_new_engine(self, connector): + engine1 = connector.create_engine() + engine2 = connector.create_engine() + assert engine1 is not engine2 + + def test_engine_creates_and_returns_cached_engine(self, connector): + assert not connector._cached_engine + engine1 = connector._engine + engine2 = connector._cached_engine + assert engine1 is engine2 + + def test_deprecated_functions_warn(self, connector): + with pytest.deprecated_call(): + connector.create_sqlalchemy_engine() + with pytest.deprecated_call(): + connector.create_sqlalchemy_connection() + with pytest.deprecated_call(): + connector.connection + + def test_connect_calls_engine(self, connector): + with mock.patch.object(SQLConnector, "_engine") as mock_engine: + with connector._connect() as conn: + mock_engine.connect.assert_called_once() + + def test_connect_calls_engine(self, connector): + attached_engine = connector._engine + with mock.patch.object(attached_engine, "connect") as mock_conn: + with connector._connect() as conn: + mock_conn.assert_called_once() + + def test_connect_raises_on_operational_failure(self, connector): + with pytest.raises(sqlalchemy.exc.OperationalError) as e: + with connector._connect() as conn: + conn.execute("SELECT * FROM fake_table") + + def test_rename_column_uses_connect_correctly(self, connector): + attached_engine = connector._engine + # Ends up using the attached engine + with mock.patch.object(attached_engine, "connect") as mock_conn: + connector.rename_column("fake_table", "old_name", "new_name") + mock_conn.assert_called_once() + # Uses the _connect method + with mock.patch.object(connector, "_connect") as mock_connect_method: + connector.rename_column("fake_table", "old_name", "new_name") + mock_connect_method.assert_called_once() + + def test_get_slalchemy_url_raises_if_not_in_config(self, connector): + with pytest.raises(ConfigValidationError): + connector.get_sqlalchemy_url({}) + + def test_dialect_uses_engine(self, connector): + attached_engine = connector._engine + with mock.patch.object(attached_engine, "dialect") as mock_dialect: + res = connector._dialect + assert res == attached_engine.dialect