Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Aug 25, 2022
1 parent b2e8c66 commit f77246c
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 50 deletions.
4 changes: 2 additions & 2 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,8 +1083,8 @@ def available(self) -> Response:
"preferred": engine_spec.engine_name in preferred_databases,
}

if hasattr(engine_spec, "default_driver"):
payload["default_driver"] = engine_spec.default_driver # type: ignore
if engine_spec.default_driver:
payload["default_driver"] = engine_spec.default_driver

# show configuration parameters for DBs that support it
if (
Expand Down
6 changes: 6 additions & 0 deletions superset/db_engine_specs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngi
if engine_spec.supports_backend(backend, driver):
return engine_spec

# check ignoring the driver, in order to support new drivers; this will return a
# random DB engine spec that supports the engine
for engine_spec in engine_specs:
if engine_spec.supports_backend(backend):
return engine_spec

# default to the generic DB engine spec
return BaseEngineSpec

Expand Down
17 changes: 10 additions & 7 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,18 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods

engine_name: Optional[str] = None # for user messages, overridden in child classes

# Associate the DB engine spec to one or more SQLAlchemy dialects/drivers. For
# example, if a given DB engine spec has:
# These attributes map the DB engine spec to one or more SQLAlchemy dialects/drivers.
# For example, if a given DB engine spec has:
#
# class PostgresDBEngineSpec:
# engine = 'postgresql'
# engine_aliases = 'postgres'
# drivers = {'psycopg2', 'asyncpg'}
# engine = "postgresql"
# engine_aliases = "postgres"
# drivers = {
# "psycopg2": "The default Postgres driver",
# "asyncpg": "An asynchronous Postgres driver",
# }
#
# It would be used for all the following SQLALchemy URIs:
# It would be used for all the following SQLAlchemy URIs:
#
# - postgres://user:password@host/db
# - postgresql://user:password@host/db
Expand Down Expand Up @@ -450,7 +453,7 @@ def get_allow_cost_estimate( # pylint: disable=unused-argument
@classmethod
def get_text_clause(cls, clause: str) -> TextClause:
"""
SQLALchemy wrapper to ensure text clauses are escaped properly
SQLAlchemy wrapper to ensure text clauses are escaped properly
:param clause: string clause with potentially unescaped characters
:return: text clause with escaped characters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest

from superset.connectors.sqla.models import TableColumn
from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs import load_engine_specs
from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
Expand Down Expand Up @@ -195,7 +195,7 @@ class DummyEngineSpec(BaseEngineSpec):
def test_engine_time_grain_validity(self):
time_grains = set(builtin_time_grains.keys())
# loop over all subclasses of BaseEngineSpec
for engine in get_engine_specs().values():
for engine in load_engine_specs():
if engine is not BaseEngineSpec:
# make sure time grain functions have been defined
self.assertGreater(len(engine.get_time_grain_expressions()), 0)
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/db_engine_specs/postgres_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from sqlalchemy import column, literal_column
from sqlalchemy.dialects import postgresql

from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs import load_engine_specs
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_engine_alias_name(self):
"""
DB Eng Specs (postgres): Test "postgres" in engine spec
"""
self.assertIn("postgres", get_engine_specs())
self.assertIn("postgres", [engine.engine for engine in load_engine_specs()])

def test_extras_without_ssl(self):
db = mock.Mock()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,59 @@
# specific language governing permissions and limitations
# under the License.

from unittest import mock
# pylint: disable=import-outside-toplevel, invalid-name, unused-argument, redefined-outer-name

from typing import TYPE_CHECKING

import pytest
from marshmallow import fields, Schema, ValidationError
from pytest_mock import MockFixture

if TYPE_CHECKING:
from superset.databases.schemas import DatabaseParametersSchemaMixin
from superset.db_engine_specs.base import BasicParametersMixin

from superset.databases.schemas import DatabaseParametersSchemaMixin
from superset.db_engine_specs.base import BasicParametersMixin
from superset.models.core import ConfigurationMethod

# pylint: disable=too-few-public-methods
class InvalidEngine:
"""
An invalid DB engine spec.
"""

class DummySchema(Schema, DatabaseParametersSchemaMixin):
sqlalchemy_uri = fields.String()

@pytest.fixture
def dummy_schema() -> "DatabaseParametersSchemaMixin":
"""
Fixture providing a dummy schema.
"""
from superset.databases.schemas import DatabaseParametersSchemaMixin

class DummyEngine(BasicParametersMixin):
engine = "dummy"
default_driver = "dummy"
class DummySchema(Schema, DatabaseParametersSchemaMixin):
sqlalchemy_uri = fields.String()

return DummySchema()


@pytest.fixture
def dummy_engine(mocker: MockFixture) -> None:
"""
Fixture proving a dummy DB engine spec.
"""
from superset.db_engine_specs.base import BasicParametersMixin

class DummyEngine(BasicParametersMixin):
engine = "dummy"
default_driver = "dummy"

mocker.patch("superset.databases.schemas.get_engine_spec", return_value=DummyEngine)

class InvalidEngine:
pass

def test_database_parameters_schema_mixin(
dummy_engine: None,
dummy_schema: "Schema",
) -> None:
from superset.models.core import ConfigurationMethod

@mock.patch("superset.databases.schemas.get_engine_specs")
def test_database_parameters_schema_mixin(get_engine_specs):
get_engine_specs.return_value = {"dummy_engine": DummyEngine}
payload = {
"engine": "dummy_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
Expand All @@ -51,39 +79,47 @@ def test_database_parameters_schema_mixin(get_engine_specs):
"database": "dbname",
},
}
schema = DummySchema()
result = schema.load(payload)
result = dummy_schema.load(payload)
assert result == {
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
"sqlalchemy_uri": "dummy+dummy://username:password@localhost:12345/dbname",
}


def test_database_parameters_schema_mixin_no_engine():
def test_database_parameters_schema_mixin_no_engine(
dummy_schema: "Schema",
) -> None:
from superset.models.core import ConfigurationMethod

payload = {
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
"parameters": {
"username": "username",
"password": "password",
"host": "localhost",
"port": 12345,
"dbname": "dbname",
"database": "dbname",
},
}
schema = DummySchema()
try:
schema.load(payload)
dummy_schema.load(payload)
except ValidationError as err:
assert err.messages == {
"_schema": [
"An engine must be specified when passing individual parameters to a database."
(
"An engine must be specified when passing individual parameters to "
"a database."
),
]
}


@mock.patch("superset.databases.schemas.get_engine_specs")
def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs):
get_engine_specs.return_value = {}
def test_database_parameters_schema_mixin_invalid_engine(
dummy_engine: None,
dummy_schema: "Schema",
) -> None:
from superset.models.core import ConfigurationMethod

payload = {
"engine": "dummy_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
Expand All @@ -92,21 +128,24 @@ def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs):
"password": "password",
"host": "localhost",
"port": 12345,
"dbname": "dbname",
"database": "dbname",
},
}
schema = DummySchema()
try:
schema.load(payload)
dummy_schema.load(payload)
except ValidationError as err:
print(err.messages)
assert err.messages == {
"_schema": ['Engine "dummy_engine" is not a valid engine.']
}


@mock.patch("superset.databases.schemas.get_engine_specs")
def test_database_parameters_schema_no_mixin(get_engine_specs):
get_engine_specs.return_value = {"invalid_engine": InvalidEngine}
def test_database_parameters_schema_no_mixin(
dummy_engine: None,
dummy_schema: "Schema",
) -> None:
from superset.models.core import ConfigurationMethod

payload = {
"engine": "invalid_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
Expand All @@ -118,9 +157,8 @@ def test_database_parameters_schema_no_mixin(get_engine_specs):
"database": "dbname",
},
}
schema = DummySchema()
try:
schema.load(payload)
dummy_schema.load(payload)
except ValidationError as err:
assert err.messages == {
"_schema": [
Expand All @@ -132,9 +170,12 @@ def test_database_parameters_schema_no_mixin(get_engine_specs):
}


@mock.patch("superset.databases.schemas.get_engine_specs")
def test_database_parameters_schema_mixin_invalid_type(get_engine_specs):
get_engine_specs.return_value = {"dummy_engine": DummyEngine}
def test_database_parameters_schema_mixin_invalid_type(
dummy_engine: None,
dummy_schema: "Schema",
) -> None:
from superset.models.core import ConfigurationMethod

payload = {
"engine": "dummy_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
Expand All @@ -146,8 +187,7 @@ def test_database_parameters_schema_mixin_invalid_type(get_engine_specs):
"database": "dbname",
},
}
schema = DummySchema()
try:
schema.load(payload)
dummy_schema.load(payload)
except ValidationError as err:
assert err.messages == {"port": ["Not a valid integer."]}
12 changes: 12 additions & 0 deletions tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ class OldDBEngineSpec(BaseEngineSpec):
).db_engine_spec
== PostgresDBEngineSpec
)
assert (
Database(
database_name="db", sqlalchemy_uri="postgresql+fancynewdriver://"
).db_engine_spec
== PostgresDBEngineSpec
)
assert (
Database(database_name="db", sqlalchemy_uri="mysql://").db_engine_spec
== OldDBEngineSpec
Expand All @@ -133,3 +139,9 @@ class OldDBEngineSpec(BaseEngineSpec):
).db_engine_spec
== OldDBEngineSpec
)
assert (
Database(
database_name="db", sqlalchemy_uri="mysql+fancynewdriver://"
).db_engine_spec
== OldDBEngineSpec
)

0 comments on commit f77246c

Please sign in to comment.