Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue #1312: Type (typehint) error when calling db.Model subclass constructor with parameters #1321

Open
wants to merge 10 commits into
base: stable
Choose a base branch
from
196 changes: 158 additions & 38 deletions src/flask_sqlalchemy/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
import sqlalchemy.event as sa_event
import sqlalchemy.exc as sa_exc
import sqlalchemy.orm as sa_orm
import typing_extensions as te
from flask import abort
from flask import current_app
from flask import Flask
from flask import has_app_context
from sqlalchemy.util import typing as compat_typing

from .model import _QueryProperty
from .model import BindMixin
Expand All @@ -32,22 +34,124 @@


# Type accepted for model_class argument
_FSA_MCT = t.TypeVar(
"_FSA_MCT",
bound=t.Union[
t.Type[Model],
sa_orm.DeclarativeMeta,
t.Type[sa_orm.DeclarativeBase],
t.Type[sa_orm.DeclarativeBaseNoMeta],
],
)
_FSA_MCT = t.Union[
t.Type[Model],
sa_orm.DeclarativeMeta,
t.Type[sa_orm.DeclarativeBase],
t.Type[sa_orm.DeclarativeBaseNoMeta],
t.Type[sa_orm.MappedAsDataclass],
]
_FSA_MCT_T = t.TypeVar("_FSA_MCT_T", bound=_FSA_MCT, covariant=True)


# Type returned by make_declarative_base
class _FSAModel(Model):
metadata: sa.MetaData


if t.TYPE_CHECKING:

class _FSAModel_KW(_FSAModel):
def __init__(self, **kw: t.Any) -> None: ...

else:
# To minimize side effects, the type hint only works for static type checker.
# At run time, `_FSAModel_KW` falls back to `_FSAModel`
_FSAModel_KW = _FSAModel


if t.TYPE_CHECKING:

@compat_typing.dataclass_transform(
field_specifiers=(
sa_orm.MappedColumn,
sa_orm.RelationshipProperty,
sa_orm.Composite,
sa_orm.Synonym,
sa_orm.mapped_column,
sa_orm.relationship,
sa_orm.composite,
sa_orm.synonym,
sa_orm.deferred,
),
)
class _FSAModel_DataClass(_FSAModel): ...

else:
# To minimize side effects, the type hint only works for static type checker.
# At run time, `_FSAModel_DataClass` falls back to `_FSAModel`
_FSAModel_DataClass = _FSAModel


class ModelGetter:
"""Model getter for the ``SQLAlchemy().Model`` property.

This getter is used for determining the correct type of ``SQLAlchemy().Model``.

When ``SQLAlchemy`` is initialized by

.. code-block:: python

db = SQLAlchemy(model_class=MappedAsDataclass)

the ``db.Model`` property needs to be a class decorated by ``dataclass_transform``.

Otherwise, the ``db.Model`` property needs to provide a synthesized initialization
method accepting unknown keyword arguments. These keyword arguments are not
annotated but limited in the range of data items. This rule is guaranteed by the
featuers of all other candidates of ``model_class``.

Calling the class property ``SQLAlchemy.Model`` will return this descriptor
directly.
"""

# This variant is at first. Its priority is highest for making SQLAlchemy[Any]
# exports a Model with type[_FSAModel_KW].
# Note that in actual using cases, users do not need to inherit Model classes.
@te.overload
def __get__(
self, obj: SQLAlchemy[type[Model]], obj_cls: t.Any = None
) -> type[_FSAModel_KW]: ...

# This variant needs to be prior than DeclarativeBase, because a class may inherit
# multiple classes. When both MappedAsDataclass and DeclarativeBase are in the MRO
# list, this configuration make type[_FSAModel_DataClass] preferred.
@te.overload
def __get__(
self, obj: SQLAlchemy[type[sa_orm.MappedAsDataclass]], obj_cls: t.Any = None
) -> type[_FSAModel_DataClass]: ...

@te.overload
def __get__(
self, obj: SQLAlchemy[type[sa_orm.DeclarativeBase]], obj_cls: t.Any = None
) -> type[_FSAModel_KW]: ...

@te.overload
def __get__(
self,
obj: SQLAlchemy[type[sa_orm.DeclarativeBaseNoMeta]],
obj_cls: t.Any = None,
) -> type[_FSAModel_KW]: ...

@te.overload
def __get__(
self, obj: SQLAlchemy[sa_orm.DeclarativeMeta], obj_cls: t.Any = None
) -> type[_FSAModel_KW]: ...

@te.overload
def __get__(
self: te.Self, obj: None, obj_cls: type[SQLAlchemy[t.Any]] | None = None
) -> type[_FSAModel]: ...

def __get__(
self: te.Self, obj: SQLAlchemy[t.Any] | None, obj_cls: t.Any = None
) -> te.Self | type[Model] | type[t.Any]:
if isinstance(obj, SQLAlchemy):
return obj._Model
else:
return self


def _get_2x_declarative_bases(
model_class: _FSA_MCT,
) -> list[type[sa_orm.DeclarativeBase | sa_orm.DeclarativeBaseNoMeta]]:
Expand All @@ -58,7 +162,7 @@ def _get_2x_declarative_bases(
]


class SQLAlchemy:
class SQLAlchemy(t.Generic[_FSA_MCT_T]):
"""Integrates SQLAlchemy with Flask. This handles setting up one or more engines,
associating tables and models with specific engines, and cleaning up connections and
sessions after each request.
Expand Down Expand Up @@ -168,7 +272,7 @@ def __init__(
metadata: sa.MetaData | None = None,
session_options: dict[str, t.Any] | None = None,
query_class: type[Query] = Query,
model_class: _FSA_MCT = Model, # type: ignore[assignment]
model_class: _FSA_MCT_T = Model, # type: ignore[assignment]
engine_options: dict[str, t.Any] | None = None,
add_models_to_shell: bool = True,
disable_autonaming: bool = False,
Expand Down Expand Up @@ -241,29 +345,17 @@ def __init__(
This is a subclass of SQLAlchemy's ``Table`` rather than a function.
"""

self.Model = self._make_declarative_base(
self._Model = self._make_declarative_base(
model_class, disable_autonaming=disable_autonaming
)
"""A SQLAlchemy declarative model class. Subclass this to define database
models.

If a model does not set ``__tablename__``, it will be generated by converting
the class name from ``CamelCase`` to ``snake_case``. It will not be generated
if the model looks like it uses single-table inheritance.

If a model or parent class sets ``__bind_key__``, it will use that metadata and
database engine. Otherwise, it will use the default :attr:`metadata` and
:attr:`engine`. This is ignored if the model sets ``metadata`` or ``__table__``.

For code using the SQLAlchemy 1.x API, customize this model by subclassing
:class:`.Model` and passing the ``model_class`` parameter to the extension.
A fully created declarative model class can be
passed as well, to use a custom metaclass.

For code using the SQLAlchemy 2.x API, customize this model by subclassing
:class:`sqlalchemy.orm.DeclarativeBase` or
:class:`sqlalchemy.orm.DeclarativeBaseNoMeta`
and passing the ``model_class`` parameter to the extension.
"""A SQLAlchemy declarative model class. This private model class is returned
by ``_make_declarative_base``.

At run time, this class is the same as ``SQLAlchemy.Model``. Accessing
``SQLAlchemy.Model`` rather than this class is more recommended because
``SQLAlchemy.Model`` can provide better type hints.

:meta private:
"""

if engine_options is None:
Expand All @@ -277,6 +369,31 @@ def __init__(
if app is not None:
self.init_app(app)

# Need to be placed after __init__ because __init__ takes a default value
# named `Model`.
Model = ModelGetter()
"""A SQLAlchemy declarative model class. Subclass this to define database
models.

If a model does not set ``__tablename__``, it will be generated by converting
the class name from ``CamelCase`` to ``snake_case``. It will not be generated
if the model looks like it uses single-table inheritance.

If a model or parent class sets ``__bind_key__``, it will use that metadata and
database engine. Otherwise, it will use the default :attr:`metadata` and
:attr:`engine`. This is ignored if the model sets ``metadata`` or ``__table__``.

For code using the SQLAlchemy 1.x API, customize this model by subclassing
:class:`.Model` and passing the ``model_class`` parameter to the extension.
A fully created declarative model class can be
passed as well, to use a custom metaclass.

For code using the SQLAlchemy 2.x API, customize this model by subclassing
:class:`sqlalchemy.orm.DeclarativeBase` or
:class:`sqlalchemy.orm.DeclarativeBaseNoMeta`
and passing the ``model_class`` parameter to the extension.
"""

def __repr__(self) -> str:
if not has_app_context():
return f"<{type(self).__name__}>"
Expand Down Expand Up @@ -534,7 +651,7 @@ def _make_declarative_base(
``model`` can be an already created declarative model class.
"""
model: type[_FSAModel]
declarative_bases = _get_2x_declarative_bases(model_class)
declarative_bases = _get_2x_declarative_bases(t.cast(t.Any, model_class))
if len(declarative_bases) > 1:
# raise error if more than one declarative base is found
raise ValueError(
Expand All @@ -547,11 +664,14 @@ def _make_declarative_base(
mixin_classes = [BindMixin, NameMixin, Model]
if disable_autonaming:
mixin_classes.remove(NameMixin)
model = types.new_class(
"FlaskSQLAlchemyBase",
(*mixin_classes, *model_class.__bases__),
{"metaclass": type(declarative_bases[0])},
lambda ns: ns.update(body),
model = t.cast(
t.Type[_FSAModel],
types.new_class(
"FlaskSQLAlchemyBase",
(*mixin_classes, *model_class.__bases__),
{"metaclass": type(declarative_bases[0])},
lambda ns: ns.update(body),
),
)
elif not isinstance(model_class, sa_orm.DeclarativeMeta):
metadata = self._make_metadata(None)
Expand Down
6 changes: 3 additions & 3 deletions src/flask_sqlalchemy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class Model:
already created declarative model class as ``model_class``.
"""

__fsa__: t.ClassVar[SQLAlchemy]
__fsa__: t.ClassVar[SQLAlchemy[t.Any]]
"""Internal reference to the extension object.

:meta private:
Expand Down Expand Up @@ -73,7 +73,7 @@ class BindMetaMixin(type):
directly on the child model.
"""

__fsa__: SQLAlchemy
__fsa__: SQLAlchemy[t.Any]
metadata: sa.MetaData

def __init__(
Expand Down Expand Up @@ -104,7 +104,7 @@ class BindMixin:
.. versionchanged:: 3.1.0
"""

__fsa__: SQLAlchemy
__fsa__: SQLAlchemy[t.Any]
metadata: sa.MetaData

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion src/flask_sqlalchemy/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Session(sa_orm.Session):
Renamed from ``SignallingSession``.
"""

def __init__(self, db: SQLAlchemy, **kwargs: t.Any) -> None:
def __init__(self, db: SQLAlchemy[t.Any], **kwargs: t.Any) -> None:
super().__init__(**kwargs)
self._db = db
self._model_changes: dict[object, tuple[t.Any, str]] = {}
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def app_ctx(app: Flask) -> t.Generator[AppContext, None, None]:


@pytest.fixture(params=test_classes)
def db(app: Flask, request: pytest.FixtureRequest) -> SQLAlchemy:
def db(app: Flask, request: pytest.FixtureRequest) -> SQLAlchemy[t.Any]:
if request.param is not Model:
return SQLAlchemy(app, model_class=types.new_class(*request.param))
else:
Expand All @@ -79,7 +79,7 @@ def model_class(request: pytest.FixtureRequest) -> t.Any:


@pytest.fixture
def Todo(app: Flask, db: SQLAlchemy) -> t.Generator[t.Any, None, None]:
def Todo(app: Flask, db: SQLAlchemy[t.Any]) -> t.Generator[t.Any, None, None]:
if issubclass(db.Model, (sa_orm.MappedAsDataclass)):

class Todo(db.Model):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


@pytest.mark.usefixtures("app_ctx")
def test_shell_context(db: SQLAlchemy, Todo: t.Any) -> None:
def test_shell_context(db: SQLAlchemy[t.Any], Todo: t.Any) -> None:
context = add_models_to_shell()
assert context["db"] is db
assert context["Todo"] is Todo
2 changes: 1 addition & 1 deletion tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from flask_sqlalchemy import SQLAlchemy


def test_default_engine(app: Flask, db: SQLAlchemy) -> None:
def test_default_engine(app: Flask, db: SQLAlchemy[t.Any]) -> None:
with app.app_context():
assert db.engine is db.engines[None]

Expand Down
2 changes: 1 addition & 1 deletion tests/test_extension_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


@pytest.mark.usefixtures("app_ctx")
def test_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None:
def test_get_or_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None:
item = Todo()
db.session.add(item)
db.session.commit()
Expand Down
11 changes: 6 additions & 5 deletions tests/test_extension_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from flask import Flask

from flask_sqlalchemy import SQLAlchemy
from flask_sqlalchemy.model import Model


def test_repr_no_context() -> None:
db = SQLAlchemy()
db: SQLAlchemy[type[Model]] = SQLAlchemy()
app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://"

Expand All @@ -15,7 +16,7 @@ def test_repr_no_context() -> None:


def test_repr_default() -> None:
db = SQLAlchemy()
db: SQLAlchemy[type[Model]] = SQLAlchemy()
app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://"

Expand All @@ -25,7 +26,7 @@ def test_repr_default() -> None:


def test_repr_default_plustwo() -> None:
db = SQLAlchemy()
db: SQLAlchemy[type[Model]] = SQLAlchemy()
app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://"
app.config["SQLALCHEMY_BINDS"] = {
Expand All @@ -39,7 +40,7 @@ def test_repr_default_plustwo() -> None:


def test_repr_nodefault() -> None:
db = SQLAlchemy()
db: SQLAlchemy[type[Model]] = SQLAlchemy()
app = Flask(__name__)
app.config["SQLALCHEMY_BINDS"] = {"x": "sqlite:///:memory:"}

Expand All @@ -49,7 +50,7 @@ def test_repr_nodefault() -> None:


def test_repr_nodefault_plustwo() -> None:
db = SQLAlchemy()
db: SQLAlchemy[type[Model]] = SQLAlchemy()
app = Flask(__name__)
app.config["SQLALCHEMY_BINDS"] = {
"a": "sqlite:///:memory:",
Expand Down
Loading