From 6b8e4c405471be1de91a28a4792987ac2da47111 Mon Sep 17 00:00:00 2001 From: Petter Friberg Date: Sun, 17 Sep 2023 19:29:43 +0200 Subject: [PATCH 1/2] Add better support for `ManyToManyField`'s `through` model --- django-stubs/db/models/fields/related.pyi | 18 +- .../db/models/fields/related_descriptors.pyi | 70 ++++-- .../db/models/fields/reverse_related.pyi | 4 +- mypy_django_plugin/django/context.py | 24 +- mypy_django_plugin/lib/fullnames.py | 1 - mypy_django_plugin/lib/helpers.py | 65 ++++- mypy_django_plugin/transformers/fields.py | 5 + mypy_django_plugin/transformers/manytomany.py | 144 +++++++++++ mypy_django_plugin/transformers/models.py | 224 ++++++++++++++++-- scripts/stubtest/allowlist.txt | 3 + tests/typecheck/fields/test_related.yml | 177 +++++++++++++- .../typecheck/models/test_contrib_models.yml | 6 + tests/typecheck/test_request.yml | 1 + 13 files changed, 678 insertions(+), 64 deletions(-) create mode 100644 mypy_django_plugin/transformers/manytomany.py diff --git a/django-stubs/db/models/fields/related.pyi b/django-stubs/db/models/fields/related.pyi index f8fde239c..7f09d3373 100644 --- a/django-stubs/db/models/fields/related.pyi +++ b/django-stubs/db/models/fields/related.pyi @@ -1,5 +1,5 @@ from collections.abc import Callable, Iterable, Sequence -from typing import Any, Literal, TypeVar, overload +from typing import Any, Generic, Literal, TypeVar, overload from uuid import UUID from django.core import validators # due to weird mypy.stubtest error @@ -11,6 +11,7 @@ from django.db.models.fields.related_descriptors import ForwardManyToOneDescript from django.db.models.fields.related_descriptors import ( # noqa: F401 ForwardOneToOneDescriptor as ForwardOneToOneDescriptor, ) +from django.db.models.fields.related_descriptors import ManyRelatedManager from django.db.models.fields.related_descriptors import ManyToManyDescriptor as ManyToManyDescriptor from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor as ReverseManyToOneDescriptor from django.db.models.fields.related_descriptors import ReverseOneToOneDescriptor as ReverseOneToOneDescriptor @@ -18,7 +19,6 @@ from django.db.models.fields.reverse_related import ForeignObjectRel as ForeignO from django.db.models.fields.reverse_related import ManyToManyRel as ManyToManyRel from django.db.models.fields.reverse_related import ManyToOneRel as ManyToOneRel from django.db.models.fields.reverse_related import OneToOneRel as OneToOneRel -from django.db.models.manager import RelatedManager from django.db.models.query_utils import FilteredRelation, PathInfo, Q from django.utils.functional import _StrOrPromise from typing_extensions import Self @@ -27,6 +27,7 @@ RECURSIVE_RELATIONSHIP_CONSTANT: Literal["self"] def resolve_relation(scope_model: type[Model], relation: str | type[Model]) -> str | type[Model]: ... +_M = TypeVar("_M", bound=Model) # __set__ value type _ST = TypeVar("_ST") # __get__ return type @@ -204,10 +205,9 @@ class OneToOneField(ForeignKey[_ST, _GT]): @overload def __get__(self, instance: Any, owner: Any) -> Self: ... -class ManyToManyField(RelatedField[_ST, _GT]): - _pyi_private_set_type: Sequence[Any] - _pyi_private_get_type: RelatedManager[Any] +_To = TypeVar("_To", bound=Model) +class ManyToManyField(RelatedField[Any, Any], Generic[_To, _M]): description: str has_null_arg: bool swappable: bool @@ -221,12 +221,12 @@ class ManyToManyField(RelatedField[_ST, _GT]): rel_class: type[ManyToManyRel] def __init__( self, - to: type[Model] | str, + to: type[_To] | str, related_name: str | None = ..., related_query_name: str | None = ..., limit_choices_to: _AllLimitChoicesTo | None = ..., symmetrical: bool | None = ..., - through: str | type[Model] | None = ..., + through: type[_M] | str | None = ..., through_fields: tuple[str, str] | None = ..., db_constraint: bool = ..., db_table: str | None = ..., @@ -255,10 +255,10 @@ class ManyToManyField(RelatedField[_ST, _GT]): ) -> None: ... # class access @overload - def __get__(self, instance: None, owner: Any) -> ManyToManyDescriptor[Self]: ... + def __get__(self, instance: None, owner: Any) -> ManyToManyDescriptor[_M]: ... # Model instance access @overload - def __get__(self, instance: Model, owner: Any) -> _GT: ... + def __get__(self, instance: Model, owner: Any) -> ManyRelatedManager[_To]: ... # non-Model instances @overload def __get__(self, instance: Any, owner: Any) -> Self: ... diff --git a/django-stubs/db/models/fields/related_descriptors.pyi b/django-stubs/db/models/fields/related_descriptors.pyi index e02b2d67d..ccabc4f36 100644 --- a/django-stubs/db/models/fields/related_descriptors.pyi +++ b/django-stubs/db/models/fields/related_descriptors.pyi @@ -1,16 +1,17 @@ -from collections.abc import Callable -from typing import Any, Generic, TypeVar, overload +from collections.abc import Callable, Iterable +from typing import Any, Generic, NoReturn, TypeVar, overload from django.core.exceptions import ObjectDoesNotExist from django.db.models.base import Model from django.db.models.fields import Field -from django.db.models.fields.related import ForeignKey, RelatedField +from django.db.models.fields.related import ForeignKey, ManyToManyField, RelatedField from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel -from django.db.models.manager import RelatedManager +from django.db.models.manager import BaseManager, RelatedManager from django.db.models.query import QuerySet from django.db.models.query_utils import DeferredAttribute +from typing_extensions import Self -_T = TypeVar("_T") +_M = TypeVar("_M", bound=Model) _F = TypeVar("_F", bound=Field) _From = TypeVar("_From", bound=Model) _To = TypeVar("_To", bound=Model) @@ -65,28 +66,63 @@ class ReverseOneToOneDescriptor(Generic[_From, _To]): def __reduce__(self) -> tuple[Callable[..., Any], tuple[type[_To], str]]: ... class ReverseManyToOneDescriptor: + """ + In the example:: + + class Child(Model): + parent = ForeignKey(Parent, related_name='children') + + ``Parent.children`` is a ``ReverseManyToOneDescriptor`` instance. + """ + rel: ManyToOneRel field: ForeignKey def __init__(self, rel: ManyToOneRel) -> None: ... @property - def related_manager_cls(self) -> type[RelatedManager]: ... - def __get__(self, instance: Model | None, cls: type[Model] | None = ...) -> ReverseManyToOneDescriptor: ... - def __set__(self, instance: Model, value: list[Model]) -> Any: ... + def related_manager_cls(self) -> type[RelatedManager[Any]]: ... + @overload + def __get__(self, instance: None, cls: Any = ...) -> Self: ... + @overload + def __get__(self, instance: Model, cls: Any = ...) -> type[RelatedManager[Any]]: ... + def __set__(self, instance: Any, value: Any) -> NoReturn: ... + +def create_reverse_many_to_one_manager( + superclass: type[BaseManager[_M]], rel: ManyToOneRel +) -> type[RelatedManager[_M]]: ... -def create_reverse_many_to_one_manager(superclass: type, rel: Any) -> type[RelatedManager]: ... +class ManyToManyDescriptor(ReverseManyToOneDescriptor, Generic[_M]): + """ + In the example:: + + class Pizza(Model): + toppings = ManyToManyField(Topping, related_name='pizzas') + + ``Pizza.toppings`` and ``Topping.pizzas`` are ``ManyToManyDescriptor`` + instances. + """ -class ManyToManyDescriptor(ReverseManyToOneDescriptor, Generic[_F]): - field: _F # type: ignore[assignment] + # 'field' here is 'rel.field' rel: ManyToManyRel # type: ignore[assignment] + field: ManyToManyField[Any, _M] # type: ignore[assignment] reverse: bool def __init__(self, rel: ManyToManyRel, reverse: bool = ...) -> None: ... @property - def through(self) -> type[Model]: ... + def through(self) -> type[_M]: ... @property - def related_manager_cls(self) -> type[Any]: ... # ManyRelatedManager + def related_manager_cls(self) -> type[ManyRelatedManager[Any]]: ... # type: ignore[override] -# fake -class _ForwardManyToManyManager(Generic[_T]): - def all(self) -> QuerySet: ... +class ManyRelatedManager(BaseManager[_M], Generic[_M]): + related_val: tuple[int, ...] + def add(self, *objs: _M | int, bulk: bool = ...) -> None: ... + async def aadd(self, *objs: _M | int, bulk: bool = ...) -> None: ... + def remove(self, *objs: _M | int, bulk: bool = ...) -> None: ... + async def aremove(self, *objs: _M | int, bulk: bool = ...) -> None: ... + def set(self, objs: QuerySet[_M] | Iterable[_M | int], *, bulk: bool = ..., clear: bool = ...) -> None: ... + async def aset(self, objs: QuerySet[_M] | Iterable[_M | int], *, bulk: bool = ..., clear: bool = ...) -> None: ... + def clear(self) -> None: ... + async def aclear(self) -> None: ... + def __call__(self, *, manager: str) -> ManyRelatedManager[_M]: ... -def create_forward_many_to_many_manager(superclass: type, rel: Any, reverse: Any) -> _ForwardManyToManyManager: ... +def create_forward_many_to_many_manager( + superclass: type[BaseManager[_M]], rel: ManyToManyRel, reverse: bool +) -> type[ManyRelatedManager[_M]]: ... diff --git a/django-stubs/db/models/fields/reverse_related.pyi b/django-stubs/db/models/fields/reverse_related.pyi index 21bfeb85f..15410f17c 100644 --- a/django-stubs/db/models/fields/reverse_related.pyi +++ b/django-stubs/db/models/fields/reverse_related.pyi @@ -112,13 +112,13 @@ class OneToOneRel(ManyToOneRel): ) -> None: ... class ManyToManyRel(ForeignObjectRel): - field: ManyToManyField # type: ignore[assignment] + field: ManyToManyField[Any, Any] # type: ignore[assignment] through: type[Model] | None through_fields: tuple[str, str] | None db_constraint: bool def __init__( self, - field: ManyToManyField, + field: ManyToManyField[Any, Any], to: type[Model] | str, related_name: str | None = ..., related_query_name: str | None = ..., diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index 79d5e8fdd..c6c13c381 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -3,7 +3,21 @@ from collections import defaultdict from contextlib import contextmanager from functools import cached_property -from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, Literal, Optional, Sequence, Set, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + Iterator, + Literal, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) from django.core.exceptions import FieldDoesNotExist, FieldError from django.db import models @@ -270,6 +284,14 @@ def all_registered_model_classes(self) -> Set[Type[models.Model]]: def all_registered_model_class_fullnames(self) -> Set[str]: return {helpers.get_class_fullname(cls) for cls in self.all_registered_model_classes} + @cached_property + def model_class_fullnames_by_label(self) -> Mapping[str, str]: + return { + klass._meta.label: helpers.get_class_fullname(klass) + for klass in self.all_registered_model_classes + if klass is not models.Model + } + def get_field_nullability(self, field: Union["Field[Any, Any]", ForeignObjectRel], method: Optional[str]) -> bool: if method in ("values", "values_list"): return field.null diff --git a/mypy_django_plugin/lib/fullnames.py b/mypy_django_plugin/lib/fullnames.py index 5e9c8d4e9..261fd4a47 100644 --- a/mypy_django_plugin/lib/fullnames.py +++ b/mypy_django_plugin/lib/fullnames.py @@ -38,7 +38,6 @@ FOREIGN_OBJECT_FULLNAME, FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME, - MANYTOMANY_FIELD_FULLNAME, ) ) diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 9f5289970..e90498e9b 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -10,8 +10,10 @@ from mypy.nodes import ( GDEF, MDEF, + AssignmentStmt, Block, ClassDef, + Context, Expression, MemberExpr, MypyFile, @@ -33,7 +35,8 @@ SemanticAnalyzerPluginInterface, ) from mypy.semanal import SemanticAnalyzer -from mypy.types import AnyType, Instance, NoneTyp, TupleType, TypedDictType, TypeOfAny, UnionType +from mypy.semanal_shared import parse_bool +from mypy.types import AnyType, Instance, LiteralType, NoneTyp, TupleType, TypedDictType, TypeOfAny, UnionType from mypy.types import Type as MypyType from typing_extensions import TypedDict @@ -45,12 +48,14 @@ class DjangoTypeMetadata(TypedDict, total=False): + is_abstract_model: bool from_queryset_manager: str reverse_managers: Dict[str, str] baseform_bases: Dict[str, int] manager_bases: Dict[str, int] model_bases: Dict[str, int] queryset_bases: Dict[str, int] + m2m_throughs: Dict[str, str] def get_django_metadata(model_info: TypeInfo) -> DjangoTypeMetadata: @@ -385,3 +390,61 @@ def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) -> if sym is not None and isinstance(sym.node, TypeInfo): bases = get_django_metadata_bases(sym.node, "manager_bases") bases[fullname] = 1 + + +def is_abstract_model(model: TypeInfo) -> bool: + if model.metaclass_type is None or model.metaclass_type.type.fullname != fullnames.MODEL_METACLASS_FULLNAME: + return False + + metadata = get_django_metadata(model) + if metadata.get("is_abstract_model") is not None: + return metadata["is_abstract_model"] + + meta = model.names.get("Meta") + # Check if 'abstract' is declared in this model's 'class Meta' as + # 'abstract = True' won't be inherited from a parent model. + if meta is not None and isinstance(meta.node, TypeInfo) and "abstract" in meta.node.names: + for stmt in meta.node.defn.defs.body: + if ( + # abstract = + isinstance(stmt, AssignmentStmt) + and len(stmt.lvalues) == 1 + and isinstance(stmt.lvalues[0], NameExpr) + and stmt.lvalues[0].name == "abstract" + ): + # abstract = True (builtins.bool) + rhs_is_true = parse_bool(stmt.rvalue) is True + # abstract: Literal[True] + is_literal_true = isinstance(stmt.type, LiteralType) and stmt.type.value is True + metadata["is_abstract_model"] = rhs_is_true or is_literal_true + return metadata["is_abstract_model"] + + metadata["is_abstract_model"] = False + return False + + +def resolve_lazy_reference( + reference: str, *, api: Union[TypeChecker, SemanticAnalyzer], django_context: "DjangoContext", ctx: Context +) -> Optional[TypeInfo]: + """ + Attempts to resolve a lazy reference(e.g. ".") to a + 'TypeInfo' instance. + """ + if "." not in reference: + # -- needs prefix of . We can't implicitly solve + # what app label this should be, yet. + return None + + # Reference conforms to the structure of a lazy reference: '.' + fullname = django_context.model_class_fullnames_by_label.get(reference) + if fullname is not None: + model_info = lookup_fully_qualified_typeinfo(api, fullname) + if model_info is not None: + return model_info + elif isinstance(api, SemanticAnalyzer) and not api.final_iteration: + # Getting this far, where Django matched the reference but we still can't + # find it, we want to defer + api.defer() + else: + api.fail("Could not match lazy reference with any model", ctx) + return None diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index ff8c09646..15a8dc43a 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -13,6 +13,7 @@ from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.exceptions import UnregisteredModelError from mypy_django_plugin.lib import fullnames, helpers +from mypy_django_plugin.transformers import manytomany if TYPE_CHECKING: from django.contrib.contenttypes.fields import GenericForeignKey @@ -213,6 +214,10 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan assert isinstance(outer_model_info, TypeInfo) + if default_return_type.type.has_base(fullnames.MANYTOMANY_FIELD_FULLNAME): + return manytomany.fill_model_args_for_many_to_many_field( + ctx=ctx, model_info=outer_model_info, default_return_type=default_return_type, django_context=django_context + ) if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES): return fill_descriptor_types_for_related_field(ctx, django_context) diff --git a/mypy_django_plugin/transformers/manytomany.py b/mypy_django_plugin/transformers/manytomany.py new file mode 100644 index 000000000..82726d7ca --- /dev/null +++ b/mypy_django_plugin/transformers/manytomany.py @@ -0,0 +1,144 @@ +from typing import NamedTuple, Optional, Union + +from mypy.checker import TypeChecker +from mypy.nodes import AssignmentStmt, Expression, NameExpr, StrExpr, TypeInfo +from mypy.plugin import FunctionContext +from mypy.semanal import SemanticAnalyzer +from mypy.types import Instance, ProperType, UninhabitedType +from mypy.types import Type as MypyType + +from mypy_django_plugin.django.context import DjangoContext +from mypy_django_plugin.lib import fullnames, helpers + + +class M2MThrough(NamedTuple): + arg: Optional[Expression] + model: ProperType + + +class M2MTo(NamedTuple): + arg: Expression + model: ProperType + self: bool # ManyToManyField('self', ...) + + +class M2MArguments(NamedTuple): + to: M2MTo + through: Optional[M2MThrough] + + +def fill_model_args_for_many_to_many_field( + *, + ctx: FunctionContext, + model_info: TypeInfo, + default_return_type: Instance, + django_context: DjangoContext, +) -> Instance: + if not ctx.args or not ctx.args[0] or len(default_return_type.args) < 2: + return default_return_type + + args = get_m2m_arguments(ctx=ctx, model_info=model_info, django_context=django_context) + if args is None: + return default_return_type + + to_arg: MypyType + if isinstance(default_return_type.args[0], UninhabitedType): + to_arg = args.to.model + else: + # Avoid overwriting a decent 'to' argument + to_arg = default_return_type.args[0] + + if isinstance(default_return_type.args[1], UninhabitedType): + if helpers.is_abstract_model(model_info): + # Many to many on abstract models doesn't create any implicit, concrete + # through model, so we populate it with the upper bound to avoid error messages + through_arg = default_return_type.type.defn.type_vars[1].upper_bound + elif args.through is None: + through_arg = default_return_type.args[1] + else: + through_arg = args.through.model + else: + # Avoid overwriting a decent 'through' argument + through_arg = default_return_type.args[1] + + return default_return_type.copy_modified(args=[to_arg, through_arg]) + + +def get_m2m_arguments( + *, + ctx: FunctionContext, + model_info: TypeInfo, + django_context: DjangoContext, +) -> Optional[M2MArguments]: + checker = helpers.get_typechecker_api(ctx) + to_arg = ctx.args[0][0] + to_model: Optional[ProperType] + if isinstance(to_arg, StrExpr) and to_arg.value == "self": + to_model = Instance(model_info, []) + to_self = True + else: + to_model = get_model_from_expression(to_arg, api=checker, django_context=django_context) + to_self = False + + if to_model is None: + # 'ManyToManyField()' requires the 'to' argument + return None + to = M2MTo(arg=to_arg, model=to_model, self=to_self) + + through = None + if len(ctx.args) > 5 and ctx.args[5]: + # 'ManyToManyField(..., through=)' was called + through_arg = ctx.args[5][0] + through_model = get_model_from_expression(through_arg, api=checker, django_context=django_context) + if through_model is not None: + through = M2MThrough(arg=through_arg, model=through_model) + elif not helpers.is_abstract_model(model_info): + # No explicit 'through' argument was provided and model is concrete. We need + # to dig up any generated through model for this 'ManyToManyField()' field + through_arg = None + m2m_throughs = helpers.get_django_metadata(model_info).get("m2m_throughs", {}) + if m2m_throughs: + field_name = None + for defn in model_info.defn.defs.body: + if ( + isinstance(defn, AssignmentStmt) + and defn.rvalue is ctx.context + and len(defn.lvalues) == 1 + and isinstance(defn.lvalues[0], NameExpr) + ): + field_name = defn.lvalues[0].name + break + + if field_name is not None: + through_model_fullname = m2m_throughs.get(field_name) + if through_model_fullname is not None: + through_model_info = helpers.lookup_fully_qualified_typeinfo(checker, through_model_fullname) + if through_model_info is not None: + through = M2MThrough(arg=through_arg, model=Instance(through_model_info, [])) + + return M2MArguments(to=to, through=through) + + +def get_model_from_expression( + expr: Expression, + *, + api: Union[TypeChecker, SemanticAnalyzer], + django_context: DjangoContext, +) -> Optional[ProperType]: + """ + Attempts to resolve an expression to a 'TypeInfo' instance. Any lazy reference + argument(e.g. ".") to a Django model is also attempted. + """ + # TODO: Handle settings.AUTH_USER_MODEL? + if isinstance(expr, NameExpr) and isinstance(expr.node, TypeInfo): + if ( + expr.node.metaclass_type is not None + and expr.node.metaclass_type.type.fullname == fullnames.MODEL_METACLASS_FULLNAME + ): + return Instance(expr.node, []) + elif isinstance(expr, StrExpr): + model_info = helpers.resolve_lazy_reference(expr.value, api=api, django_context=django_context, ctx=expr) + if model_info is not None: + return Instance(model_info, []) + + return None diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 909f5dc50..86171aa5c 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -12,7 +12,10 @@ AssignmentStmt, CallExpr, Context, + Expression, NameExpr, + RefExpr, + StrExpr, SymbolTableNode, TypeInfo, Var, @@ -21,7 +24,7 @@ from mypy.plugins import common from mypy.semanal import SemanticAnalyzer from mypy.typeanal import TypeAnalyser -from mypy.types import AnyType, Instance, LiteralType, ProperType, TypedDictType, TypeOfAny, TypeType, get_proper_type +from mypy.types import AnyType, Instance, ProperType, TypedDictType, TypeOfAny, TypeType, get_proper_type from mypy.types import Type as MypyType from mypy.typevars import fill_typevars @@ -30,12 +33,12 @@ from mypy_django_plugin.exceptions import UnregisteredModelError from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME, MODEL_CLASS_FULLNAME -from mypy_django_plugin.transformers import fields from mypy_django_plugin.transformers.fields import get_field_descriptor_types from mypy_django_plugin.transformers.managers import ( MANAGER_METHODS_RETURNING_QUERYSET, create_manager_info_from_from_queryset_call, ) +from mypy_django_plugin.transformers.manytomany import M2MArguments, M2MThrough, M2MTo, get_model_from_expression class ModelClassInitializer: @@ -47,6 +50,10 @@ def __init__(self, ctx: ClassDefContext, django_context: DjangoContext) -> None: self.django_context = django_context self.ctx = ctx + @property + def is_model_abstract(self) -> bool: + return helpers.is_abstract_model(self.model_classdef.info) + def lookup_typeinfo(self, fullname: str) -> Optional[TypeInfo]: return helpers.lookup_fully_qualified_typeinfo(self.api, fullname) @@ -233,7 +240,7 @@ def create_autofield( auto_field_fullname = helpers.get_class_fullname(auto_field.__class__) auto_field_info = self.lookup_typeinfo_or_incomplete_defn_error(auto_field_fullname) - set_type, get_type = fields.get_field_descriptor_types( + set_type, get_type = get_field_descriptor_types( auto_field_info, is_set_nullable=True, is_get_nullable=False, @@ -613,6 +620,195 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None: self.add_new_node_to_model_class("_meta", Instance(options_info, [Instance(self.model_classdef.info, [])])) +class ProcessManyToManyFields(ModelClassInitializer): + """ + Processes 'ManyToManyField()' fields and generates any implicit through tables that + Django also generates. It won't do anything if the model is abstract or for fields + where an explicit 'through' argument has been passed. + """ + + def run(self) -> None: + if self.is_model_abstract: + # TODO: Create abstract through models? + return + + # Start out by prefetching a couple of dependencies needed to be able to declare any + # new, implicit, through model class. + model_base = self.lookup_typeinfo(fullnames.MODEL_CLASS_FULLNAME) + fk_field = self.lookup_typeinfo(fullnames.FOREIGN_KEY_FULLNAME) + manager_info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME) + if model_base is None or fk_field is None or manager_info is None: + raise helpers.IncompleteDefnException() + + from_pk = self.get_pk_instance(self.model_classdef.info) + fk_set_type, fk_get_type = get_field_descriptor_types(fk_field, is_set_nullable=False, is_get_nullable=False) + + for defn in self.model_classdef.defs.body: + # Check if this part of the class body is an assignment from a 'ManyToManyField' call + # = ManyToManyField(...) + if ( + isinstance(defn, AssignmentStmt) + and len(defn.lvalues) == 1 + and isinstance(defn.lvalues[0], NameExpr) + and isinstance(defn.rvalue, CallExpr) + and len(defn.rvalue.args) > 0 # Need at least the 'to' argument + and isinstance(defn.rvalue.callee, RefExpr) + and isinstance(defn.rvalue.callee.node, TypeInfo) + and defn.rvalue.callee.node.has_base(fullnames.MANYTOMANY_FIELD_FULLNAME) + ): + m2m_field_name = defn.lvalues[0].name + m2m_field_symbol = self.model_classdef.info.names.get(m2m_field_name) + # The symbol referred to by the assignment expression is expected to be a variable + if m2m_field_symbol is None or not isinstance(m2m_field_symbol.node, Var): + continue + # Resolve argument information of the 'ManyToManyField(...)' call + args = self.resolve_many_to_many_arguments(defn.rvalue, context=defn) + if ( + # Ignore calls without required 'to' argument, mypy will complain + args is None + or not isinstance(args.to.model, Instance) + # Call has explicit 'through=', no need to create any implicit through table + or args.through is not None + ): + continue + + # Get the names of the implicit through model that will be generated + through_model_name = f"{self.model_classdef.name}_{m2m_field_name}" + through_model_fullname = f"{self.model_classdef.info.module_name}.{through_model_name}" + # If implicit through model is already declared there's nothing more we should do + through_model = self.lookup_typeinfo(through_model_fullname) + if through_model is not None: + continue + # Declare a new, empty, implicitly generated through model class named: '_' + through_model = self.add_new_class_for_current_module( + through_model_name, bases=[Instance(model_base, [])] + ) + # We attempt to be a bit clever here and store the generated through model's fullname in + # the metadata of the class containing the 'ManyToManyField' call expression, where its + # identifier is the field name of the 'ManyToManyField'. This would allow the containing + # model to always find the implicit through model, so that it doesn't get lost. + model_metadata = helpers.get_django_metadata(self.model_classdef.info) + model_metadata.setdefault("m2m_throughs", {}) + model_metadata["m2m_throughs"][m2m_field_name] = through_model.fullname + # Add a 'pk' symbol to the model class + helpers.add_new_sym_for_info( + through_model, name="pk", sym_type=self.default_pk_instance.copy_modified() + ) + # Add an 'id' symbol to the model class + helpers.add_new_sym_for_info( + through_model, name="id", sym_type=self.default_pk_instance.copy_modified() + ) + # Add the foreign key to the model containing the 'ManyToManyField' call: + # or from_ + from_name = ( + f"from_{self.model_classdef.name.lower()}" if args.to.self else self.model_classdef.name.lower() + ) + helpers.add_new_sym_for_info( + through_model, + name=from_name, + sym_type=Instance( + fk_field, + [ + helpers.convert_any_to_type(fk_set_type, Instance(self.model_classdef.info, [])), + helpers.convert_any_to_type(fk_get_type, Instance(self.model_classdef.info, [])), + ], + ), + ) + # Add the foreign key's '_id' field: _id or from__id + helpers.add_new_sym_for_info(through_model, name=f"{from_name}_id", sym_type=from_pk.copy_modified()) + # Add the foreign key to the model on the opposite side of the relation + # i.e. the model given as 'to' argument to the 'ManyToManyField' call: + # or to_ + to_name = f"to_{args.to.model.type.name.lower()}" if args.to.self else args.to.model.type.name.lower() + helpers.add_new_sym_for_info( + through_model, + name=to_name, + sym_type=Instance( + fk_field, + [ + helpers.convert_any_to_type(fk_set_type, args.to.model), + helpers.convert_any_to_type(fk_get_type, args.to.model), + ], + ), + ) + # Add the foreign key's '_id' field: _id or to__id + other_pk = self.get_pk_instance(args.to.model.type) + helpers.add_new_sym_for_info(through_model, name=f"{to_name}_id", sym_type=other_pk.copy_modified()) + # Add a manager named 'objects' + helpers.add_new_sym_for_info( + through_model, + name="objects", + sym_type=Instance(manager_info, [Instance(through_model, [])]), + ) + + @cached_property + def default_pk_instance(self) -> Instance: + default_pk_field = self.lookup_typeinfo(self.django_context.settings.DEFAULT_AUTO_FIELD) + if default_pk_field is None: + raise helpers.IncompleteDefnException() + return Instance( + default_pk_field, + list(get_field_descriptor_types(default_pk_field, is_set_nullable=True, is_get_nullable=False)), + ) + + def get_pk_instance(self, model: TypeInfo, /) -> Instance: + """ + Get a primary key instance of provided model's type info. If primary key can't be resolved, + return a default declaration. + """ + contains_from_pk_info = model.get_containing_type_info("pk") + if contains_from_pk_info is not None: + pk = contains_from_pk_info.names["pk"].node + if isinstance(pk, Var) and isinstance(pk.type, Instance): + return pk.type + return self.default_pk_instance + + def resolve_many_to_many_arguments(self, call: CallExpr, /, context: Context) -> Optional[M2MArguments]: + """ + Inspect a 'ManyToManyField(...)' call to collect argument data on any 'to' and + 'through' arguments. + """ + look_for: Dict[str, Optional[Expression]] = {"to": None, "through": None} + # Look for 'to', being declared as the first positional argument + if call.arg_kinds[0].is_positional(): + look_for["to"] = call.args[0] + # Look for 'through', being declared as the sixth positional argument. + if len(call.args) > 5 and call.arg_kinds[5].is_positional(): + look_for["through"] = call.args[5] + + # Sort out if any of the expected arguments was provided as keyword arguments + for pos, (arg_expr, arg_kind, arg_name) in enumerate(zip(call.args, call.arg_kinds, call.arg_names), start=1): + if arg_name in look_for and look_for[arg_name] is None: + look_for[arg_name] = arg_expr + + # 'to' is a required argument of 'ManyToManyField()', we can't do anything if it's not provided + to_arg = look_for["to"] + if to_arg is None: + return None + + # Resolve the type of the 'to' argument expression + to_model: Optional[ProperType] + if isinstance(to_arg, StrExpr) and to_arg.value == "self": + to_model = Instance(self.model_classdef.info, []) + to_self = True + else: + to_model = get_model_from_expression(to_arg, api=self.api, django_context=self.django_context) + to_self = False + if to_model is None: + return None + to = M2MTo(arg=to_arg, model=to_model, self=to_self) + + # Resolve the type of the 'through' argument expression + through_arg = look_for["through"] + through = None + if through_arg is not None: + through_model = get_model_from_expression(through_arg, api=self.api, django_context=self.django_context) + if through_model is not None: + through = M2MThrough(arg=through_arg, model=through_model) + + return M2MArguments(to=to, through=through) + + class MetaclassAdjustments(ModelClassInitializer): @classmethod def adjust_model_class(cls, ctx: ClassDefContext) -> None: @@ -662,27 +858,6 @@ def get_exception_bases(self, name: str) -> List[Instance]: return bases - @cached_property - def is_model_abstract(self) -> bool: - meta = self.model_classdef.info.names.get("Meta") - # Check if 'abstract' is declared in this model's 'class Meta' as - # 'abstract = True' won't be inherited from a parent model. - if meta is not None and isinstance(meta.node, TypeInfo) and "abstract" in meta.node.names: - for stmt in meta.node.defn.defs.body: - if ( - # abstract = - isinstance(stmt, AssignmentStmt) - and len(stmt.lvalues) == 1 - and isinstance(stmt.lvalues[0], NameExpr) - and stmt.lvalues[0].name == "abstract" - ): - # abstract = True (builtins.bool) - rhs_is_true = self.api.parse_bool(stmt.rvalue) is True - # abstract: Literal[True] - is_literal_true = isinstance(stmt.type, LiteralType) and stmt.type.value is True - return rhs_is_true or is_literal_true - return False - def add_exception_classes(self) -> None: """ Adds exception classes 'DoesNotExist' and 'MultipleObjectsReturned' to a model @@ -744,6 +919,7 @@ def process_model_class(ctx: ClassDefContext, django_context: DjangoContext) -> AddReverseLookups, AddExtraFieldMethods, AddMetaOptionsAttribute, + ProcessManyToManyFields, MetaclassAdjustments, ] for initializer_cls in initializers: diff --git a/scripts/stubtest/allowlist.txt b/scripts/stubtest/allowlist.txt index a230a84aa..864016373 100644 --- a/scripts/stubtest/allowlist.txt +++ b/scripts/stubtest/allowlist.txt @@ -17,6 +17,9 @@ django.contrib.contenttypes.migrations.* # default_storage is actually an instance of DefaultStorage, but it proxies through to a Storage django.core.files.storage.default_storage +# 'ManyRelatedManager' does exist and is declared locally, inside a function body +django.db.models.fields.related_descriptors.ManyRelatedManager + # BaseArchive abstract methods that take no argument, but typed with arguments to match the Archive and TarArchive Implementations django.utils.archive.BaseArchive.list django.utils.archive.BaseArchive.extract diff --git a/tests/typecheck/fields/test_related.yml b/tests/typecheck/fields/test_related.yml index fcee5a6d1..8f21557e1 100644 --- a/tests/typecheck/fields/test_related.yml +++ b/tests/typecheck/fields/test_related.yml @@ -278,9 +278,11 @@ - case: many_to_many_field_converts_to_queryset_of_model_type main: | from myapp.models import App, Member - reveal_type(Member().apps) # N: Revealed type is "django.db.models.manager.RelatedManager[myapp.models.App]" + reveal_type(Member().apps) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyRelatedManager[myapp.models.App]" + reveal_type(Member().apps.get()) # N: Revealed type is "myapp.models.App" reveal_type(App().members) # N: Revealed type is "django.db.models.manager.RelatedManager[myapp.models.Member]" - reveal_type(Member.apps.field) # N: Revealed type is "django.db.models.fields.related.ManyToManyField[typing.Sequence[myapp.models.App], django.db.models.manager.RelatedManager[myapp.models.App]]" + reveal_type(App().members.get()) # N: Revealed type is "myapp.models.Member" + reveal_type(Member.apps.field) # N: Revealed type is "django.db.models.fields.related.ManyToManyField[Any, myapp.models.Member_apps]" # XXX the following is not correct: reveal_type(App.members) # N: Revealed type is "django.db.models.manager.RelatedManager[myapp.models.Member]" installed_apps: @@ -298,7 +300,7 @@ - case: many_to_many_works_with_string_if_imported main: | from myapp.models import Member - reveal_type(Member().apps) # N: Revealed type is "django.db.models.manager.RelatedManager[myapp2.models.App]" + reveal_type(Member().apps) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyRelatedManager[myapp2.models.App]" installed_apps: - myapp - myapp2 @@ -333,7 +335,12 @@ - case: many_to_many_with_self main: | from myapp.models import User - reveal_type(User().friends) # N: Revealed type is "django.db.models.manager.RelatedManager[myapp.models.User]" + reveal_type(User().friends) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyRelatedManager[myapp.models.User]" + reveal_type(User.friends.through.objects.get()) # N: Revealed type is "myapp.models.User_friends" + reveal_type(User.friends.through().from_user) # N: Revealed type is "myapp.models.User" + reveal_type(User.friends.through().from_user_id) # N: Revealed type is "builtins.int" + reveal_type(User.friends.through().to_user) # N: Revealed type is "myapp.models.User" + reveal_type(User.friends.through().to_user_id) # N: Revealed type is "builtins.int" installed_apps: - myapp files: @@ -580,7 +587,9 @@ - case: test_related_fields_returned_as_descriptors_from_model_class main: | from myapp.models import Author, Blog, Publisher, Profile - reveal_type(Author.blogs) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.db.models.fields.related.ManyToManyField[typing.Sequence[myapp.models.Blog], django.db.models.manager.RelatedManager[myapp.models.Blog]]]" + reveal_type(Author.blogs) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[myapp.models.Author_blogs]" + reveal_type(Author.blogs.through) # N: Revealed type is "Type[myapp.models.Author_blogs]" + reveal_type(Author().blogs) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyRelatedManager[myapp.models.Blog]" reveal_type(Blog.publisher) # N: Revealed type is "django.db.models.fields.related_descriptors.ForwardManyToOneDescriptor[django.db.models.fields.related.ForeignKey[Union[myapp.models.Publisher, django.db.models.expressions.Combinable], myapp.models.Publisher]]" reveal_type(Publisher.profile) # N: Revealed type is "django.db.models.fields.related_descriptors.ForwardOneToOneDescriptor[django.db.models.fields.related.OneToOneField[Union[myapp.models.Profile, django.db.models.expressions.Combinable], myapp.models.Profile]]" reveal_type(Author.file) # N: Revealed type is "django.db.models.fields.files.FileDescriptor" @@ -917,14 +926,14 @@ from typing import TypeVar, Sequence, Union from django.db import models T = TypeVar("T", bound=models.Model) - ManyToManyFieldAlias = Union["models.ManyToManyField[Sequence[T], models.manager.RelatedManager[T]]"] + ManyToManyFieldAlias = Union["models.ManyToManyField[T, T]"] - case: callable_reverse_manager main: | from myapp.models import SalesMan sales_man = SalesMan() - reveal_type(sales_man.client) # N: Revealed type is "django.db.models.manager.RelatedManager[myapp.models.CustomUser]" - reveal_type(sales_man.client(manager="staffs")) # N: Revealed type is "django.db.models.manager.RelatedManager[myapp.models.CustomUser]" + reveal_type(sales_man.client) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyRelatedManager[myapp.models.CustomUser]" + reveal_type(sales_man.client(manager="staffs")) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyRelatedManager[myapp.models.CustomUser]" installed_apps: - myapp files: @@ -1051,9 +1060,159 @@ class Other(models.Model): ... - class MyModel(models.Model): first = models.OneToOneField(Other, on_delete=models.CASCADE) second = models.OneToOneField( Other, on_delete=models.CASCADE, related_name="has_explicit_name" ) + +- case: test_many_to_many + main: | + from myapp.models import MyModel, Other + reveal_type(MyModel.auto_through.through.objects.get()) + reveal_type(MyModel().auto_through.get()) + reveal_type(Other().autos.get()) + + reveal_type(MyModel.custom_through.through.objects.get()) + reveal_type(MyModel().custom_through.get()) + reveal_type(MyModel.custom_through.through.objects.custom_qs_method()) + reveal_type(Other().customs.get()) + + auto_through = MyModel.auto_through.through.objects.get() + reveal_type(auto_through.id) + reveal_type(auto_through.pk) + reveal_type(auto_through.mymodel) + reveal_type(auto_through.mymodel_id) + reveal_type(auto_through.other) + reveal_type(auto_through.other_id) + + reveal_type(MyModel.auto_through.through) + reveal_type(MyModel.auto_through.through.mymodel) + + reveal_type(MyModel.other_again.through) + out: | + main:2: note: Revealed type is "myapp.models.MyModel_auto_through" + main:3: note: Revealed type is "myapp.models.Other" + main:4: note: Revealed type is "myapp.models.MyModel" + main:6: note: Revealed type is "myapp.models.CustomThrough" + main:7: note: Revealed type is "myapp.models.Other" + main:8: note: Revealed type is "builtins.int" + main:9: note: Revealed type is "myapp.models.MyModel" + main:12: note: Revealed type is "builtins.int" + main:13: note: Revealed type is "builtins.int" + main:14: note: Revealed type is "myapp.models.MyModel" + main:15: note: Revealed type is "builtins.int" + main:16: note: Revealed type is "myapp.models.Other" + main:17: note: Revealed type is "builtins.int" + main:19: note: Revealed type is "Type[myapp.models.MyModel_auto_through]" + main:20: note: Revealed type is "django.db.models.fields.related_descriptors.ForwardManyToOneDescriptor[django.db.models.fields.related.ForeignKey[Union[myapp.models.MyModel, django.db.models.expressions.Combinable], myapp.models.MyModel]]" + main:22: note: Revealed type is "Type[myapp.models.MyModel_other_again]" + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models/__init__.py + content: | + from django.db import models + + class Other(models.Model): + ... + + class CustomThroughQuerySet(models.QuerySet["CustomThrough"]): + def custom_qs_method(self) -> int: + return 1 + + CustomThroughManager = models.Manager.from_queryset(CustomThroughQuerySet) + class CustomThrough(models.Model): + other = models.ForeignKey(Other, on_delete=models.CASCADE) + my_model = models.ForeignKey("myapp.MyModel", on_delete=models.CASCADE) + + objects = CustomThroughManager() + + class MyModel(models.Model): + auto_through = models.ManyToManyField(Other, related_name="autos") + # Have multiple M2Ms with implicit through + other_again = models.ManyToManyField(Other, related_name="others_again") + custom_through = models.ManyToManyField(Other, through=CustomThrough, related_name="customs") + +- case: test_many_to_many_with_lazy_references + main: | + from first.models import First + reveal_type(First().thirds.get()) + reveal_type(First.thirds.through.objects.get()) + reveal_type(First.thirds.through.objects.get().first) + reveal_type(First.thirds.through.objects.get().third) + + from third.models import Third + reveal_type(Third().fourths.get()) + reveal_type(Third.fourths.through.objects.get()) + out: | + main:2: note: Revealed type is "third.models.Third" + main:3: note: Revealed type is "second.models.Second" + main:4: note: Revealed type is "first.models.First" + main:5: note: Revealed type is "third.models.Third" + main:8: note: Revealed type is "third.models.Fourth" + main:9: note: Revealed type is "third.models.Third_fourths" + installed_apps: + - first + - second + - third + files: + - path: first/__init__.py + - path: first/models/__init__.py + content: | + from django.db import models + + class First(models.Model): + thirds = models.ManyToManyField("third.Third", through="second.Second") + - path: second/__init__.py + - path: second/models/__init__.py + content: | + from django.db import models + + class Second(models.Model): + first = models.ForeignKey("first.First", on_delete=models.CASCADE) + third = models.ForeignKey("third.Third", on_delete=models.CASCADE) + - path: third/__init__.py + - path: third/models/__init__.py + content: | + from django.db import models + + class Third(models.Model): + fourths = models.ManyToManyField("third.Fourth", blank=True) + + class Fourth(models.Model): + ... + +- case: test_many_to_many_lazy_references_with_implicit_app_label + main: | + from myapp.models import Child + reveal_type(Child.parents) + reveal_type(Child().parents) + out: | + main:2: note: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[Any]" + main:3: note: Revealed type is "django.db.models.fields.related_descriptors.ManyRelatedManager[Any]" + myapp/models/child:5: error: Need type annotation for "parents" + myapp/models/child:6: error: Need type annotation for "other_parents" + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models/__init__.py + content: | + from .child import Child + from .parent import Parent + - path: myapp/models/parent.py + content: | + from django.db import models + + class Parent(models.Model): + ... + - path: myapp/models/child.py + content: | + from django.db import models + + class Child(models.Model): + # Reference without explicit app label + parents = models.ManyToManyField("Parent") + other_parents = models.ManyToManyField(to="Parent") diff --git a/tests/typecheck/models/test_contrib_models.yml b/tests/typecheck/models/test_contrib_models.yml index 9c8ed4f10..ade3f6877 100644 --- a/tests/typecheck/models/test_contrib_models.yml +++ b/tests/typecheck/models/test_contrib_models.yml @@ -13,6 +13,10 @@ reveal_type(User().last_login) # N: Revealed type is "Union[datetime.datetime, None]" reveal_type(User().is_authenticated) # N: Revealed type is "Literal[True]" reveal_type(User().is_anonymous) # N: Revealed type is "Literal[False]" + reveal_type(User().groups.get()) # N: Revealed type is "django.contrib.auth.models.Group" + reveal_type(User().user_permissions.get()) # N: Revealed type is "django.contrib.auth.models.Permission" + reveal_type(User.groups) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.db.models.base.Model]" + reveal_type(User.user_permissions) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.db.models.base.Model]" from django.contrib.auth.models import AnonymousUser reveal_type(AnonymousUser().is_authenticated) # N: Revealed type is "Literal[False]" @@ -28,6 +32,8 @@ from django.contrib.auth.models import Group reveal_type(Group().name) # N: Revealed type is "builtins.str" + reveal_type(Group().permissions.get()) # N: Revealed type is "django.contrib.auth.models.Permission" + reveal_type(Group.permissions) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.contrib.auth.models.Group_permissions]" - case: can_override_abstract_user_manager main: | diff --git a/tests/typecheck/test_request.yml b/tests/typecheck/test_request.yml index dd3a39d90..8a930ba4c 100644 --- a/tests/typecheck/test_request.yml +++ b/tests/typecheck/test_request.yml @@ -25,6 +25,7 @@ reveal_type(request.user) # N: Revealed type is "django.contrib.auth.models.User" if request.user.is_authenticated: reveal_type(request.user) # N: Revealed type is "django.contrib.auth.models.User" + reveal_type(request.user.groups.all().get()) # N: Revealed type is "django.contrib.auth.models.Group" custom_settings: | INSTALLED_APPS = ('django.contrib.contenttypes', 'django.contrib.auth') - case: request_object_user_without_auth_and_contenttypes_apps From cebd5255c0e1882687941c13b37f827c2f2d8f70 Mon Sep 17 00:00:00 2001 From: Petter Friberg Date: Mon, 25 Sep 2023 21:48:46 +0200 Subject: [PATCH 2/2] fixup! Add better support for `ManyToManyField`'s `through` model --- mypy_django_plugin/transformers/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 86171aa5c..c8cf52a68 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -777,7 +777,7 @@ def resolve_many_to_many_arguments(self, call: CallExpr, /, context: Context) -> look_for["through"] = call.args[5] # Sort out if any of the expected arguments was provided as keyword arguments - for pos, (arg_expr, arg_kind, arg_name) in enumerate(zip(call.args, call.arg_kinds, call.arg_names), start=1): + for arg_expr, arg_kind, arg_name in zip(call.args, call.arg_kinds, call.arg_names): if arg_name in look_for and look_for[arg_name] is None: look_for[arg_name] = arg_expr