Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DbAPiHook: Don't log a warning message if placeholder is None and make sure warning message is formatted correctly #39690

Merged
merged 8 commits into from
May 18, 2024
17 changes: 9 additions & 8 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,15 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa
def placeholder(self):
conn = self.get_connection(getattr(self, self.conn_name_attr))
placeholder = conn.extra_dejson.get("placeholder")
if placeholder in SQL_PLACEHOLDERS:
return placeholder
self.log.warning(
"Placeholder defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
"and got ignored. Falling back to the default placeholder '%s'.",
placeholder,
self._placeholder,
)
if placeholder:
if placeholder in SQL_PLACEHOLDERS:
return placeholder
self.log.warning(
"Placeholder defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
"and got ignored. Falling back to the default placeholder '%s'.",
self.conn_name_attr,
self._placeholder,
)
return self._placeholder

def get_conn(self):
Expand Down
83 changes: 69 additions & 14 deletions tests/providers/common/sql/hooks/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def setup_method(self, **kwargs):
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
self.conn.schema.return_value = "test_schema"
self.conn.extra_dejson = {}
conn = self.conn

class DbApiHookMock(DbApiHook):
conn_name_attr = "test_conn_id"
log = mock.MagicMock(spec=logging.Logger)

@classmethod
def get_connection(cls, conn_id: str) -> Connection:
Expand All @@ -63,6 +63,7 @@ def get_conn(self):
self.db_hook_no_log_sql = DbApiHookMock(log_sql=False)
self.db_hook_schema_override = DbApiHookMock(schema="schema-override")
self.db_hook.supports_executemany = False
self.db_hook.log.setLevel(logging.DEBUG)

def test_get_records(self):
statement = "SQL"
Expand Down Expand Up @@ -193,39 +194,47 @@ def test_insert_rows_replace_executemany_hana_dialect(self):
sql = f"UPSERT {table} VALUES (%s) WITH PRIMARY KEY"
self.cur.executemany.assert_any_call(sql, rows)

def test_insert_rows_as_generator(self):
def test_insert_rows_as_generator(self, caplog):
table = "table"
rows = [("What's",), ("up",), ("world",)]

self.db_hook.insert_rows(table, iter(rows))
with caplog.at_level(logging.DEBUG):
self.db_hook.insert_rows(table, iter(rows))

assert self.conn.close.call_count == 1
assert self.cur.close.call_count == 1
assert self.conn.commit.call_count == 2

sql = f"INSERT INTO {table} VALUES (%s)"

self.db_hook.log.debug.assert_called_with("Generated sql: %s", sql)
self.db_hook.log.info.assert_called_with("Done loading. Loaded a total of %s rows into %s", 3, table)
assert any(f"Generated sql: {sql}" in message for message in caplog.messages)
assert any(
f"Done loading. Loaded a total of 3 rows into {table}" in message for message in caplog.messages
)

for row in rows:
self.cur.execute.assert_any_call(sql, row)

def test_insert_rows_as_generator_supports_executemany(self):
def test_insert_rows_as_generator_supports_executemany(self, caplog):
table = "table"
rows = [("What's",), ("up",), ("world",)]

self.db_hook.supports_executemany = True
self.db_hook.insert_rows(table, iter(rows))
with caplog.at_level(logging.DEBUG):
self.db_hook.supports_executemany = True
self.db_hook.insert_rows(table, iter(rows))

assert self.conn.close.call_count == 1
assert self.cur.close.call_count == 1
assert self.conn.commit.call_count == 2

sql = f"INSERT INTO {table} VALUES (%s)"

self.db_hook.log.debug.assert_called_with("Generated sql: %s", sql)
self.db_hook.log.info.assert_called_with("Done loading. Loaded a total of %s rows into %s", 3, table)
assert any(f"Generated sql: {sql}" in message for message in caplog.messages)
assert any(f"Loaded 3 rows into {table} so far" in message for message in caplog.messages)
assert any(
f"Done loading. Loaded a total of 3 rows into {table}" in message for message in caplog.messages
)

self.cur.executemany.assert_any_call(sql, rows)

def test_get_uri_schema_not_none(self):
Expand Down Expand Up @@ -421,15 +430,61 @@ def test_get_uri_without_auth_and_empty_host(self):
)
assert self.db_hook.get_uri() == "conn-type://@:3306/schema?charset=utf-8"

def test_run_log(self):
def test_placeholder(self, caplog):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn-type",
login=None,
password=None,
schema="schema",
port=3306,
)
)
assert self.db_hook.placeholder == "%s"
assert not caplog.messages

def test_placeholder_with_valid_placeholder_in_extra(self, caplog):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn-type",
login=None,
password=None,
schema="schema",
port=3306,
extra=json.dumps({"placeholder": "?"}),
)
)
assert self.db_hook.placeholder == "?"
assert not caplog.messages

def test_placeholder_with_invalid_placeholder_in_extra(self, caplog):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn-type",
login=None,
password=None,
schema="schema",
port=3306,
extra=json.dumps({"placeholder": "!"}),
)
)

assert self.db_hook.placeholder == "%s"
assert any(
"Placeholder defined in Connection 'test_conn_id' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
"and got ignored. Falling back to the default placeholder '%s'." in message
for message in caplog.messages
)

def test_run_log(self, caplog):
statement = "SQL"
self.db_hook.run(statement)
assert self.db_hook.log.info.call_count == 2
assert len(caplog.messages) == 2

def test_run_no_log(self):
def test_run_no_log(self, caplog):
statement = "SQL"
self.db_hook_no_log_sql.run(statement)
assert self.db_hook_no_log_sql.log.info.call_count == 1
assert len(caplog.messages) == 1

def test_run_with_handler(self):
sql = "SQL"
Expand Down