Skip to content

Commit

Permalink
DbAPiHook: Don't log a warning message if placeholder is None and mak…
Browse files Browse the repository at this point in the history
…e sure warning message is formatted correctly (apache#39690)

* fix: Don't log a warning message if placeholder is None and make sure if the placeholder is invalid that the warning message is logged correctly

* refactor: Also make sure to verify that log.warning isn't invoked when placeholder is valid

* refactor: All assertions regarding the logging are now done through caplog

* refactor: Reformatted logging assertions

* refactor: Reformatted logging assertion

---------

Co-authored-by: David Blain <[email protected]>
  • Loading branch information
2 people authored and RNHTTR committed Jun 1, 2024
1 parent da9fda7 commit fafa7e0
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 22 deletions.
17 changes: 9 additions & 8 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,15 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa
def placeholder(self):
conn = self.get_connection(getattr(self, self.conn_name_attr))
placeholder = conn.extra_dejson.get("placeholder")
if placeholder in SQL_PLACEHOLDERS:
return placeholder
self.log.warning(
"Placeholder defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
"and got ignored. Falling back to the default placeholder '%s'.",
placeholder,
self._placeholder,
)
if placeholder:
if placeholder in SQL_PLACEHOLDERS:
return placeholder
self.log.warning(
"Placeholder defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
"and got ignored. Falling back to the default placeholder '%s'.",
self.conn_name_attr,
self._placeholder,
)
return self._placeholder

def get_conn(self):
Expand Down
83 changes: 69 additions & 14 deletions tests/providers/common/sql/hooks/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def setup_method(self, **kwargs):
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
self.conn.schema.return_value = "test_schema"
self.conn.extra_dejson = {}
conn = self.conn

class DbApiHookMock(DbApiHook):
conn_name_attr = "test_conn_id"
log = mock.MagicMock(spec=logging.Logger)

@classmethod
def get_connection(cls, conn_id: str) -> Connection:
Expand All @@ -63,6 +63,7 @@ def get_conn(self):
self.db_hook_no_log_sql = DbApiHookMock(log_sql=False)
self.db_hook_schema_override = DbApiHookMock(schema="schema-override")
self.db_hook.supports_executemany = False
self.db_hook.log.setLevel(logging.DEBUG)

def test_get_records(self):
statement = "SQL"
Expand Down Expand Up @@ -193,39 +194,47 @@ def test_insert_rows_replace_executemany_hana_dialect(self):
sql = f"UPSERT {table} VALUES (%s) WITH PRIMARY KEY"
self.cur.executemany.assert_any_call(sql, rows)

def test_insert_rows_as_generator(self):
def test_insert_rows_as_generator(self, caplog):
table = "table"
rows = [("What's",), ("up",), ("world",)]

self.db_hook.insert_rows(table, iter(rows))
with caplog.at_level(logging.DEBUG):
self.db_hook.insert_rows(table, iter(rows))

assert self.conn.close.call_count == 1
assert self.cur.close.call_count == 1
assert self.conn.commit.call_count == 2

sql = f"INSERT INTO {table} VALUES (%s)"

self.db_hook.log.debug.assert_called_with("Generated sql: %s", sql)
self.db_hook.log.info.assert_called_with("Done loading. Loaded a total of %s rows into %s", 3, table)
assert any(f"Generated sql: {sql}" in message for message in caplog.messages)
assert any(
f"Done loading. Loaded a total of 3 rows into {table}" in message for message in caplog.messages
)

for row in rows:
self.cur.execute.assert_any_call(sql, row)

def test_insert_rows_as_generator_supports_executemany(self):
def test_insert_rows_as_generator_supports_executemany(self, caplog):
table = "table"
rows = [("What's",), ("up",), ("world",)]

self.db_hook.supports_executemany = True
self.db_hook.insert_rows(table, iter(rows))
with caplog.at_level(logging.DEBUG):
self.db_hook.supports_executemany = True
self.db_hook.insert_rows(table, iter(rows))

assert self.conn.close.call_count == 1
assert self.cur.close.call_count == 1
assert self.conn.commit.call_count == 2

sql = f"INSERT INTO {table} VALUES (%s)"

self.db_hook.log.debug.assert_called_with("Generated sql: %s", sql)
self.db_hook.log.info.assert_called_with("Done loading. Loaded a total of %s rows into %s", 3, table)
assert any(f"Generated sql: {sql}" in message for message in caplog.messages)
assert any(f"Loaded 3 rows into {table} so far" in message for message in caplog.messages)
assert any(
f"Done loading. Loaded a total of 3 rows into {table}" in message for message in caplog.messages
)

self.cur.executemany.assert_any_call(sql, rows)

def test_get_uri_schema_not_none(self):
Expand Down Expand Up @@ -421,15 +430,61 @@ def test_get_uri_without_auth_and_empty_host(self):
)
assert self.db_hook.get_uri() == "conn-type://@:3306/schema?charset=utf-8"

def test_run_log(self):
def test_placeholder(self, caplog):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn-type",
login=None,
password=None,
schema="schema",
port=3306,
)
)
assert self.db_hook.placeholder == "%s"
assert not caplog.messages

def test_placeholder_with_valid_placeholder_in_extra(self, caplog):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn-type",
login=None,
password=None,
schema="schema",
port=3306,
extra=json.dumps({"placeholder": "?"}),
)
)
assert self.db_hook.placeholder == "?"
assert not caplog.messages

def test_placeholder_with_invalid_placeholder_in_extra(self, caplog):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn-type",
login=None,
password=None,
schema="schema",
port=3306,
extra=json.dumps({"placeholder": "!"}),
)
)

assert self.db_hook.placeholder == "%s"
assert any(
"Placeholder defined in Connection 'test_conn_id' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
"and got ignored. Falling back to the default placeholder '%s'." in message
for message in caplog.messages
)

def test_run_log(self, caplog):
statement = "SQL"
self.db_hook.run(statement)
assert self.db_hook.log.info.call_count == 2
assert len(caplog.messages) == 2

def test_run_no_log(self):
def test_run_no_log(self, caplog):
statement = "SQL"
self.db_hook_no_log_sql.run(statement)
assert self.db_hook_no_log_sql.log.info.call_count == 1
assert len(caplog.messages) == 1

def test_run_with_handler(self):
sql = "SQL"
Expand Down

0 comments on commit fafa7e0

Please sign in to comment.