Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make "placeholder" of ODBC configurable in UI #36000

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,6 @@ class DbApiHook(BaseHook):
connector: ConnectorProtocol | None = None
# Override with db-specific query to check connection
_test_connection_sql = "select 1"
# Override with the db-specific value used for placeholders
placeholder: str = "%s"

def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwargs):
super().__init__()
Expand All @@ -163,6 +161,11 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa
self.__schema = schema
self.log_sql = log_sql
self.descriptions: list[Sequence[Sequence] | None] = []
self._placeholder: str = "%s"

@property
def placeholder(self) -> str:
return self._placeholder

def get_conn(self):
"""Return a connection object."""
Expand Down Expand Up @@ -463,8 +466,7 @@ def get_cursor(self):
"""Return a cursor."""
return self.get_conn().cursor()

@classmethod
def _generate_insert_sql(cls, table, values, target_fields, replace, **kwargs) -> str:
def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs) -> str:
"""
Generate the INSERT SQL statement.

Expand All @@ -477,7 +479,7 @@ def _generate_insert_sql(cls, table, values, target_fields, replace, **kwargs) -
:return: The generated INSERT or REPLACE SQL statement
"""
placeholders = [
cls.placeholder,
self.placeholder,
] * len(values)

if target_fields:
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/common/sql/hooks/sql.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class DbApiHook(BaseForDbApiHook):
placeholder: str
log_sql: Incomplete
descriptions: Incomplete
_placeholder: str
def __init__(self, *args, schema: Union[str, None] = ..., log_sql: bool = ..., **kwargs) -> None: ...
def get_conn(self): ...
def get_uri(self) -> str: ...
Expand Down
18 changes: 17 additions & 1 deletion airflow/providers/odbc/hooks/odbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.utils.helpers import merge_dicts

DEFAULT_ODBC_PLACEHOLDERS = frozenset({"%s", "?"})


class OdbcHook(DbApiHook):
"""
Expand Down Expand Up @@ -161,7 +163,7 @@ def odbc_connection_string(self):
if self.connection.port:
conn_str += f"PORT={self.connection.port};"

extra_exclude = {"driver", "dsn", "connect_kwargs", "sqlalchemy_scheme"}
extra_exclude = {"driver", "dsn", "connect_kwargs", "sqlalchemy_scheme", "placeholder"}
extra_params = {
k: v for k, v in self.connection.extra_dejson.items() if k.lower() not in extra_exclude
}
Expand Down Expand Up @@ -198,6 +200,20 @@ def get_conn(self) -> pyodbc.Connection:
conn = pyodbc.connect(self.odbc_connection_string, **self.connect_kwargs)
return conn

@property
def placeholder(self):
placeholder = self.connection.extra_dejson.get("placeholder")
if placeholder in DEFAULT_ODBC_PLACEHOLDERS:
return placeholder
else:
self.log.warning(
"Placeholder defined in Connection '%s' is not listed in 'DEFAULT_ODBC_PLACEHOLDERS' "
"and got ignored. Falling back to the default placeholder '%s'.",
placeholder,
self._placeholder,
)
return self._placeholder

def get_uri(self) -> str:
"""URI invoked in :meth:`~airflow.providers.common.sql.hooks.sql.DbApiHook.get_sqlalchemy_engine`."""
quoted_conn_str = quote_plus(self.odbc_connection_string)
Expand Down
5 changes: 2 additions & 3 deletions airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,8 @@ def get_table_primary_key(self, table: str, schema: str | None = "public") -> li
pk_columns = [row[0] for row in self.get_records(sql, (schema, table))]
return pk_columns or None

@classmethod
def _generate_insert_sql(
cls, table: str, values: tuple[str, ...], target_fields: Iterable[str], replace: bool, **kwargs
self, table: str, values: tuple[str, ...], target_fields: Iterable[str], replace: bool, **kwargs
) -> str:
"""Generate the INSERT SQL statement.

Expand All @@ -292,7 +291,7 @@ def _generate_insert_sql(
:return: The generated INSERT or REPLACE SQL statement
"""
placeholders = [
cls.placeholder,
self.placeholder,
] * len(values)
replace_index = kwargs.get("replace_index")

Expand Down
5 changes: 4 additions & 1 deletion airflow/providers/presto/hooks/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ class PrestoHook(DbApiHook):
default_conn_name = "presto_default"
conn_type = "presto"
hook_name = "Presto"
placeholder = "?"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._placeholder: str = "?"

def get_conn(self) -> Connection:
"""Returns a connection object."""
Expand Down
5 changes: 4 additions & 1 deletion airflow/providers/sqlite/hooks/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ class SqliteHook(DbApiHook):
default_conn_name = "sqlite_default"
conn_type = "sqlite"
hook_name = "Sqlite"
placeholder = "?"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._placeholder: str = "?"

def get_conn(self) -> sqlite3.dbapi2.Connection:
"""Returns a sqlite connection object."""
Expand Down
5 changes: 4 additions & 1 deletion airflow/providers/trino/hooks/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,12 @@ class TrinoHook(DbApiHook):
conn_type = "trino"
hook_name = "Trino"
query_id = ""
placeholder = "?"
_test_connection_sql = "select 1"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._placeholder: str = "?"

def get_conn(self) -> Connection:
"""Returns a connection object."""
db = self.get_connection(self.trino_conn_id) # type: ignore[attr-defined]
Expand Down
5 changes: 5 additions & 0 deletions tests/providers/odbc/hooks/test_odbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@ def test_driver_extra_raises_warning_and_returns_default_driver_by_default(self,
assert "have supplied 'driver' via connection extra but it will not be used" in caplog.text
assert driver == "Blah driver"

def test_placeholder_config_from_extra(self):
conn_params = dict(extra=json.dumps(dict(placeholder="?")))
hook = self.get_hook(conn_params=conn_params)
assert hook.placeholder == "?"

def test_database(self):
hook = self.get_hook(hook_params=dict(database="abc"))
assert hook.database == "abc"
Expand Down