diff --git a/src/flask_sqlalchemy/cli.py b/src/flask_sqlalchemy/cli.py index d7d7e4be..9f13a280 100644 --- a/src/flask_sqlalchemy/cli.py +++ b/src/flask_sqlalchemy/cli.py @@ -2,15 +2,17 @@ import typing as t +import sqlalchemy as sa from flask import current_app def add_models_to_shell() -> dict[str, t.Any]: """Registered with :meth:`~flask.Flask.shell_context_processor` if ``add_models_to_shell`` is enabled. Adds the ``db`` instance and all model classes - to ``flask shell``. + to ``flask shell``. Adds the ``sqlalchemy`` namespace as ``sa`` to ``flask shell``. """ db = current_app.extensions["sqlalchemy"] out = {m.class_.__name__: m.class_ for m in db.Model._sa_registry.mappers} + out["sa"] = sa out["db"] = db return out diff --git a/tests/test_cli.py b/tests/test_cli.py index 91672733..8d2f16ca 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -11,5 +11,6 @@ @pytest.mark.usefixtures("app_ctx") def test_shell_context(db: SQLAlchemy, Todo: t.Any) -> None: context = add_models_to_shell() + assert "sa" in context assert context["db"] is db assert context["Todo"] is Todo