Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: % replace in values_for_column #28271

Merged
merged 2 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1377,10 +1377,14 @@ def values_for_column(
qry = qry.where(self.get_fetch_values_predicate(template_processor=tp))

with self.database.get_sqla_engine() as engine:
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
sql = self._apply_cte(sql, cte)
sql = self.database.mutate_sql_based_on_config(sql)

# pylint: disable=protected-access
if engine.dialect.identifier_preparer._double_percents:
sql = sql.replace("%%", "%")

df = pd.read_sql_query(sql=sql, con=engine)
# replace NaN with None to ensure it can be serialized to JSON
df = df.replace({np.nan: None})
Expand Down
111 changes: 98 additions & 13 deletions tests/unit_tests/models/helpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,24 @@

# pylint: disable=import-outside-toplevel

from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING

import pytest
from pytest_mock import MockerFixture
from sqlalchemy import create_engine
from sqlalchemy.orm.session import Session
from sqlalchemy.pool import StaticPool

if TYPE_CHECKING:
from superset.models.core import Database

def test_values_for_column(mocker: MockerFixture, session: Session) -> None:
"""
Test the `values_for_column` method.

NULL values should be returned as `None`, not `np.nan`, since NaN cannot be
serialized to JSON.
"""
from superset.connectors.sqla.models import SqlaTable, TableColumn
@pytest.fixture()
def database(mocker: MockerFixture, session: Session) -> Database:
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database

SqlaTable.metadata.create_all(session.get_bind())
Expand All @@ -42,13 +44,12 @@ def test_values_for_column(mocker: MockerFixture, session: Session) -> None:
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)

database = Database(database_name="db", sqlalchemy_uri="sqlite://")

connection = engine.raw_connection()
connection.execute("CREATE TABLE t (c INTEGER)")
connection.execute("INSERT INTO t VALUES (1)")
connection.execute("INSERT INTO t VALUES (NULL)")
connection.execute("CREATE TABLE t (a INTEGER, b TEXT)")
connection.execute("INSERT INTO t VALUES (1, 'Alice')")
connection.execute("INSERT INTO t VALUES (NULL, 'Bob')")
connection.commit()

# since we're using an in-memory SQLite database, make sure we always
Expand All @@ -63,10 +64,94 @@ def mock_get_sqla_engine():
new=mock_get_sqla_engine,
)

return database


def test_values_for_column(database: Database) -> None:
    """
    Check that `values_for_column` yields JSON-friendly values.

    A NULL stored in the column must come back as `None` rather than `np.nan`,
    because NaN cannot be serialized to JSON.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    # plain (non-calculated) column backed by the physical column "a"
    column_a = TableColumn(column_name="a")
    dataset = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[column_a],
    )

    distinct_values = dataset.values_for_column("a")
    assert distinct_values == [1, None]


def test_values_for_column_calculated(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    Check that `values_for_column` works for calculated columns.

    The column is defined by a SQL expression containing a literal `%`
    (a `LIKE` pattern) instead of mapping to a physical column.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    # calculated column: its values come from evaluating the expression
    calculated_column = TableColumn(
        column_name="starts_with_A",
        expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END",
    )
    dataset = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[calculated_column],
    )

    distinct_values = dataset.values_for_column("starts_with_A")
    assert distinct_values == ["yes", "nope"]


def test_values_for_column_double_percents(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    Test the behavior of `double_percents`.

    With `_double_percents` enabled on the dialect's identifier preparer, the
    compiled SQL contains `%%` where the column expression has a literal `%`.
    This test checks that the SQL handed to `mutate_sql_based_on_config` still
    has the doubled percents, while the SQL finally passed to
    `pd.read_sql_query` has them collapsed back to a single `%`.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    # force the dialect to escape percent signs when compiling
    # pylint: disable=protected-access
    with database.get_sqla_engine() as engine:
        engine.dialect.identifier_preparer._double_percents = "pyformat"

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[
            TableColumn(
                column_name="starts_with_A",
                expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END",
            )
        ],
    )

    # identity side effect: record the SQL it receives without altering it
    mutate_sql_based_on_config = mocker.patch.object(
        database,
        "mutate_sql_based_on_config",
        side_effect=lambda sql: sql,
    )
    # replace pandas in the module under test so the final SQL can be inspected
    # instead of executed
    pd = mocker.patch("superset.models.helpers.pd")

    table.values_for_column("starts_with_A")

    # make sure the SQL originally had double percents
    mutate_sql_based_on_config.assert_called_with(
        "SELECT DISTINCT CASE WHEN b LIKE 'A%%' THEN 'yes' ELSE 'nope' END "
        "AS column_values \nFROM t\n LIMIT 10000 OFFSET 0"
    )
    # make sure final query has single percents
    with database.get_sqla_engine() as engine:
        pd.read_sql_query.assert_called_with(
            sql=(
                "SELECT DISTINCT CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END "
                "AS column_values \nFROM t\n LIMIT 10000 OFFSET 0"
            ),
            con=engine,
        )
Loading