From f77246c1338f0561e2765b2d801d86a21ffa7bf9 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 25 Aug 2022 17:58:01 -0300 Subject: [PATCH] Fix tests --- superset/databases/api.py | 4 +- superset/db_engine_specs/__init__.py | 6 + superset/db_engine_specs/base.py | 17 +-- .../db_engine_specs/base_engine_spec_tests.py | 4 +- .../db_engine_specs/postgres_tests.py | 4 +- .../databases/schema_tests.py | 114 ++++++++++++------ tests/unit_tests/models/core_test.py | 12 ++ 7 files changed, 111 insertions(+), 50 deletions(-) rename tests/{integration_tests => unit_tests}/databases/schema_tests.py (57%) diff --git a/superset/databases/api.py b/superset/databases/api.py index a6160b0d2f1fe..4c617eb720519 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -1083,8 +1083,8 @@ def available(self) -> Response: "preferred": engine_spec.engine_name in preferred_databases, } - if hasattr(engine_spec, "default_driver"): - payload["default_driver"] = engine_spec.default_driver # type: ignore + if engine_spec.default_driver: + payload["default_driver"] = engine_spec.default_driver # show configuration parameters for DBs that support it if ( diff --git a/superset/db_engine_specs/__init__.py b/superset/db_engine_specs/__init__.py index 257f6481bccba..fa015e8fed703 100644 --- a/superset/db_engine_specs/__init__.py +++ b/superset/db_engine_specs/__init__.py @@ -95,6 +95,12 @@ def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngi if engine_spec.supports_backend(backend, driver): return engine_spec + # check ignoring the driver, in order to support new drivers; this will return a + # random DB engine spec that supports the engine + for engine_spec in engine_specs: + if engine_spec.supports_backend(backend): + return engine_spec + # default to the generic DB engine spec return BaseEngineSpec diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 1ce802781999a..bd6573cb06bed 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -185,15 +185,18 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods engine_name: Optional[str] = None # for user messages, overridden in child classes - # Associate the DB engine spec to one or more SQLAlchemy dialects/drivers. For - # example, if a given DB engine spec has: + # These attributes map the DB engine spec to one or more SQLAlchemy dialects/drivers. + # For example, if a given DB engine spec has: # # class PostgresDBEngineSpec: - # engine = 'postgresql' - # engine_aliases = 'postgres' - # drivers = {'psycopg2', 'asyncpg'} + # engine = "postgresql" + # engine_aliases = "postgres" + # drivers = { + # "psycopg2": "The default Postgres driver", + # "asyncpg": "An asynchronous Postgres driver", + # } # - # It would be used for all the following SQLALchemy URIs: + # It would be used for all the following SQLAlchemy URIs: # # - postgres://user:password@host/db # - postgresql://user:password@host/db @@ -450,7 +453,7 @@ def get_allow_cost_estimate( # pylint: disable=unused-argument @classmethod def get_text_clause(cls, clause: str) -> TextClause: """ - SQLALchemy wrapper to ensure text clauses are escaped properly + SQLAlchemy wrapper to ensure text clauses are escaped properly :param clause: string clause with potentially unescaped characters :return: text clause with escaped characters diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py index 07f9bfcf318dc..f998444f31895 100644 --- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py @@ -20,7 +20,7 @@ import pytest from superset.connectors.sqla.models import TableColumn -from superset.db_engine_specs import get_engine_specs +from superset.db_engine_specs import load_engine_specs from superset.db_engine_specs.base import ( BaseEngineSpec, BasicParametersMixin, @@ -195,7 +195,7 @@ class DummyEngineSpec(BaseEngineSpec): def test_engine_time_grain_validity(self): time_grains = set(builtin_time_grains.keys()) # loop over all subclasses of BaseEngineSpec - for engine in get_engine_specs().values(): + for engine in load_engine_specs(): if engine is not BaseEngineSpec: # make sure time grain functions have been defined self.assertGreater(len(engine.get_time_grain_expressions()), 0) diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index e6eb4fc1d13ea..17df25000b364 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -20,7 +20,7 @@ from sqlalchemy import column, literal_column from sqlalchemy.dialects import postgresql -from superset.db_engine_specs import get_engine_specs +from superset.db_engine_specs import load_engine_specs from superset.db_engine_specs.postgres import PostgresEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.sql_lab import Query @@ -137,7 +137,7 @@ def test_engine_alias_name(self): """ DB Eng Specs (postgres): Test "postgres" in engine spec """ - self.assertIn("postgres", get_engine_specs()) + self.assertIn("postgres", [engine.engine for engine in load_engine_specs()]) def test_extras_without_ssl(self): db = mock.Mock() diff --git a/tests/integration_tests/databases/schema_tests.py b/tests/unit_tests/databases/schema_tests.py similarity index 57% rename from tests/integration_tests/databases/schema_tests.py rename to tests/unit_tests/databases/schema_tests.py index 1f8ca067f6b0d..58a1f6389d4c1 100644 --- a/tests/integration_tests/databases/schema_tests.py +++ b/tests/unit_tests/databases/schema_tests.py @@ -15,31 +15,59 @@ # specific language governing permissions and limitations # under the License. -from unittest import mock +# pylint: disable=import-outside-toplevel, invalid-name, unused-argument, redefined-outer-name +from typing import TYPE_CHECKING + +import pytest from marshmallow import fields, Schema, ValidationError +from pytest_mock import MockFixture + +if TYPE_CHECKING: + from superset.databases.schemas import DatabaseParametersSchemaMixin + from superset.db_engine_specs.base import BasicParametersMixin -from superset.databases.schemas import DatabaseParametersSchemaMixin -from superset.db_engine_specs.base import BasicParametersMixin -from superset.models.core import ConfigurationMethod +# pylint: disable=too-few-public-methods +class InvalidEngine: + """ + An invalid DB engine spec. + """ -class DummySchema(Schema, DatabaseParametersSchemaMixin): - sqlalchemy_uri = fields.String() +@pytest.fixture +def dummy_schema() -> "DatabaseParametersSchemaMixin": + """ + Fixture providing a dummy schema. + """ + from superset.databases.schemas import DatabaseParametersSchemaMixin -class DummyEngine(BasicParametersMixin): - engine = "dummy" - default_driver = "dummy" + class DummySchema(Schema, DatabaseParametersSchemaMixin): + sqlalchemy_uri = fields.String() + return DummySchema() + + +@pytest.fixture +def dummy_engine(mocker: MockFixture) -> None: + """ + Fixture proving a dummy DB engine spec. + """ + from superset.db_engine_specs.base import BasicParametersMixin + + class DummyEngine(BasicParametersMixin): + engine = "dummy" + default_driver = "dummy" + + mocker.patch("superset.databases.schemas.get_engine_spec", return_value=DummyEngine) -class InvalidEngine: - pass +def test_database_parameters_schema_mixin( + dummy_engine: None, + dummy_schema: "Schema", +) -> None: + from superset.models.core import ConfigurationMethod -@mock.patch("superset.databases.schemas.get_engine_specs") -def test_database_parameters_schema_mixin(get_engine_specs): - get_engine_specs.return_value = {"dummy_engine": DummyEngine} payload = { "engine": "dummy_engine", "configuration_method": ConfigurationMethod.DYNAMIC_FORM, @@ -51,15 +79,18 @@ def test_database_parameters_schema_mixin(get_engine_specs): "database": "dbname", }, } - schema = DummySchema() - result = schema.load(payload) + result = dummy_schema.load(payload) assert result == { "configuration_method": ConfigurationMethod.DYNAMIC_FORM, "sqlalchemy_uri": "dummy+dummy://username:password@localhost:12345/dbname", } -def test_database_parameters_schema_mixin_no_engine(): +def test_database_parameters_schema_mixin_no_engine( + dummy_schema: "Schema", +) -> None: + from superset.models.core import ConfigurationMethod + payload = { "configuration_method": ConfigurationMethod.DYNAMIC_FORM, "parameters": { @@ -67,23 +98,28 @@ def test_database_parameters_schema_mixin_no_engine(): "password": "password", "host": "localhost", "port": 12345, - "dbname": "dbname", + "database": "dbname", }, } - schema = DummySchema() try: - schema.load(payload) + dummy_schema.load(payload) except ValidationError as err: assert err.messages == { "_schema": [ - "An engine must be specified when passing individual parameters to a database." + ( + "An engine must be specified when passing individual parameters to " + "a database." + ), ] } -@mock.patch("superset.databases.schemas.get_engine_specs") -def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs): - get_engine_specs.return_value = {} +def test_database_parameters_schema_mixin_invalid_engine( + dummy_engine: None, + dummy_schema: "Schema", +) -> None: + from superset.models.core import ConfigurationMethod + payload = { "engine": "dummy_engine", "configuration_method": ConfigurationMethod.DYNAMIC_FORM, @@ -92,21 +128,24 @@ def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs): "password": "password", "host": "localhost", "port": 12345, - "dbname": "dbname", + "database": "dbname", }, } - schema = DummySchema() try: - schema.load(payload) + dummy_schema.load(payload) except ValidationError as err: + print(err.messages) assert err.messages == { "_schema": ['Engine "dummy_engine" is not a valid engine.'] } -@mock.patch("superset.databases.schemas.get_engine_specs") -def test_database_parameters_schema_no_mixin(get_engine_specs): - get_engine_specs.return_value = {"invalid_engine": InvalidEngine} +def test_database_parameters_schema_no_mixin( + dummy_engine: None, + dummy_schema: "Schema", +) -> None: + from superset.models.core import ConfigurationMethod + payload = { "engine": "invalid_engine", "configuration_method": ConfigurationMethod.DYNAMIC_FORM, @@ -118,9 +157,8 @@ def test_database_parameters_schema_no_mixin(get_engine_specs): "database": "dbname", }, } - schema = DummySchema() try: - schema.load(payload) + dummy_schema.load(payload) except ValidationError as err: assert err.messages == { "_schema": [ @@ -132,9 +170,12 @@ def test_database_parameters_schema_no_mixin(get_engine_specs): } -@mock.patch("superset.databases.schemas.get_engine_specs") -def test_database_parameters_schema_mixin_invalid_type(get_engine_specs): - get_engine_specs.return_value = {"dummy_engine": DummyEngine} +def test_database_parameters_schema_mixin_invalid_type( + dummy_engine: None, + dummy_schema: "Schema", +) -> None: + from superset.models.core import ConfigurationMethod + payload = { "engine": "dummy_engine", "configuration_method": ConfigurationMethod.DYNAMIC_FORM, @@ -146,8 +187,7 @@ def test_database_parameters_schema_mixin_invalid_type(get_engine_specs): "database": "dbname", }, } - schema = DummySchema() try: - schema.load(payload) + dummy_schema.load(payload) except ValidationError as err: assert err.messages == {"port": ["Not a valid integer."]} diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 4ae429c677aab..5eb60dc6f93ef 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -123,6 +123,12 @@ class OldDBEngineSpec(BaseEngineSpec): ).db_engine_spec == PostgresDBEngineSpec ) + assert ( + Database( + database_name="db", sqlalchemy_uri="postgresql+fancynewdriver://" + ).db_engine_spec + == PostgresDBEngineSpec + ) assert ( Database(database_name="db", sqlalchemy_uri="mysql://").db_engine_spec == OldDBEngineSpec @@ -133,3 +139,9 @@ class OldDBEngineSpec(BaseEngineSpec): ).db_engine_spec == OldDBEngineSpec ) + assert ( + Database( + database_name="db", sqlalchemy_uri="mysql+fancynewdriver://" + ).db_engine_spec + == OldDBEngineSpec + )