Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add better support for ManyToManyField's through model #1719

Merged
merged 3 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions django-stubs/db/models/fields/related.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable, Iterable, Sequence
from typing import Any, Literal, TypeVar, overload
from typing import Any, Generic, Literal, TypeVar, overload
from uuid import UUID

from django.core import validators # due to weird mypy.stubtest error
Expand All @@ -11,14 +11,14 @@ from django.db.models.fields.related_descriptors import ForwardManyToOneDescript
from django.db.models.fields.related_descriptors import ( # noqa: F401
ForwardOneToOneDescriptor as ForwardOneToOneDescriptor,
)
from django.db.models.fields.related_descriptors import ManyRelatedManager
from django.db.models.fields.related_descriptors import ManyToManyDescriptor as ManyToManyDescriptor
from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor as ReverseManyToOneDescriptor
from django.db.models.fields.related_descriptors import ReverseOneToOneDescriptor as ReverseOneToOneDescriptor
from django.db.models.fields.reverse_related import ForeignObjectRel as ForeignObjectRel # noqa: F401
from django.db.models.fields.reverse_related import ManyToManyRel as ManyToManyRel
from django.db.models.fields.reverse_related import ManyToOneRel as ManyToOneRel
from django.db.models.fields.reverse_related import OneToOneRel as OneToOneRel
from django.db.models.manager import RelatedManager
from django.db.models.query_utils import FilteredRelation, PathInfo, Q
from django.utils.functional import _StrOrPromise
from typing_extensions import Self
Expand All @@ -27,6 +27,7 @@ RECURSIVE_RELATIONSHIP_CONSTANT: Literal["self"]

def resolve_relation(scope_model: type[Model], relation: str | type[Model]) -> str | type[Model]: ...

_M = TypeVar("_M", bound=Model)
# __set__ value type
_ST = TypeVar("_ST")
# __get__ return type
Expand Down Expand Up @@ -204,10 +205,9 @@ class OneToOneField(ForeignKey[_ST, _GT]):
@overload
def __get__(self, instance: Any, owner: Any) -> Self: ...

class ManyToManyField(RelatedField[_ST, _GT]):
_pyi_private_set_type: Sequence[Any]
_pyi_private_get_type: RelatedManager[Any]
_To = TypeVar("_To", bound=Model)

class ManyToManyField(RelatedField[Any, Any], Generic[_To, _M]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make it generic on all the things and simplify our plugin a bit?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which parts are you referring to?

The plugin only acts on what is _T and _M here, essentially the to= and through= argument on ManyToManyField.

It does however have to sort some stuff out when there was no explicit through= argument passed, as Django automatically generates a through model behind the scenes in that case.

Additionally it supports resolving lazy references e.g. "myapp.MyModel" and replaces either _T or _M with any found model type.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which parts are you referring to?

What I meant to say here was: which other parts should we make generic to simplify the plugin?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_GT, _ST? Will it help?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this changes so that those arguments aren't used at all any more, for ManyToManyField. It's just the parent class that wants them.

The generics of the descriptor class does almost all the work now instead (except when through is implicit, then the plugin does that part)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit hard to follow, but instead of using _GT and _ST and let our plugin do all the work, I've updated the descriptor class ManyToManyDescriptor to become generic over a model, and that will be the through= model for a ManyToManyField().

And if you now look a bit further, you find: ManyToManyDescriptor.through -> type[_M] which means that the through= argument align with the MyModel.m2m_field.through without any plugin work (again, except for special cases of lazy referencing or not passing the through= argument, which is all the new plugin code you see in this PR)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corresponding part is the to= argument, plugin code will only be running for lazy referencing. Previously this was completely managed by the plugin, now it's instead mypy doing the work.

description: str
has_null_arg: bool
swappable: bool
Expand All @@ -221,12 +221,12 @@ class ManyToManyField(RelatedField[_ST, _GT]):
rel_class: type[ManyToManyRel]
def __init__(
self,
to: type[Model] | str,
to: type[_To] | str,
related_name: str | None = ...,
related_query_name: str | None = ...,
limit_choices_to: _AllLimitChoicesTo | None = ...,
symmetrical: bool | None = ...,
through: str | type[Model] | None = ...,
through: type[_M] | str | None = ...,
through_fields: tuple[str, str] | None = ...,
db_constraint: bool = ...,
db_table: str | None = ...,
Expand Down Expand Up @@ -255,10 +255,10 @@ class ManyToManyField(RelatedField[_ST, _GT]):
) -> None: ...
# class access
@overload
def __get__(self, instance: None, owner: Any) -> ManyToManyDescriptor[Self]: ...
def __get__(self, instance: None, owner: Any) -> ManyToManyDescriptor[_M]: ...
# Model instance access
@overload
def __get__(self, instance: Model, owner: Any) -> _GT: ...
def __get__(self, instance: Model, owner: Any) -> ManyRelatedManager[_To]: ...
# non-Model instances
@overload
def __get__(self, instance: Any, owner: Any) -> Self: ...
Expand Down
70 changes: 53 additions & 17 deletions django-stubs/db/models/fields/related_descriptors.pyi
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from collections.abc import Callable
from typing import Any, Generic, TypeVar, overload
from collections.abc import Callable, Iterable
from typing import Any, Generic, NoReturn, TypeVar, overload

from django.core.exceptions import ObjectDoesNotExist
from django.db.models.base import Model
from django.db.models.fields import Field
from django.db.models.fields.related import ForeignKey, RelatedField
from django.db.models.fields.related import ForeignKey, ManyToManyField, RelatedField
from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel
from django.db.models.manager import RelatedManager
from django.db.models.manager import BaseManager, RelatedManager
from django.db.models.query import QuerySet
from django.db.models.query_utils import DeferredAttribute
from typing_extensions import Self

_T = TypeVar("_T")
_M = TypeVar("_M", bound=Model)
_F = TypeVar("_F", bound=Field)
_From = TypeVar("_From", bound=Model)
_To = TypeVar("_To", bound=Model)
Expand Down Expand Up @@ -65,28 +66,63 @@ class ReverseOneToOneDescriptor(Generic[_From, _To]):
def __reduce__(self) -> tuple[Callable[..., Any], tuple[type[_To], str]]: ...

class ReverseManyToOneDescriptor:
"""
In the example::

class Child(Model):
parent = ForeignKey(Parent, related_name='children')

``Parent.children`` is a ``ReverseManyToOneDescriptor`` instance.
"""

rel: ManyToOneRel
field: ForeignKey
def __init__(self, rel: ManyToOneRel) -> None: ...
@property
def related_manager_cls(self) -> type[RelatedManager]: ...
def __get__(self, instance: Model | None, cls: type[Model] | None = ...) -> ReverseManyToOneDescriptor: ...
def __set__(self, instance: Model, value: list[Model]) -> Any: ...
def related_manager_cls(self) -> type[RelatedManager[Any]]: ...
@overload
def __get__(self, instance: None, cls: Any = ...) -> Self: ...
@overload
def __get__(self, instance: Model, cls: Any = ...) -> type[RelatedManager[Any]]: ...
def __set__(self, instance: Any, value: Any) -> NoReturn: ...

def create_reverse_many_to_one_manager(
superclass: type[BaseManager[_M]], rel: ManyToOneRel
) -> type[RelatedManager[_M]]: ...

def create_reverse_many_to_one_manager(superclass: type, rel: Any) -> type[RelatedManager]: ...
class ManyToManyDescriptor(ReverseManyToOneDescriptor, Generic[_M]):
"""
In the example::

class Pizza(Model):
toppings = ManyToManyField(Topping, related_name='pizzas')

``Pizza.toppings`` and ``Topping.pizzas`` are ``ManyToManyDescriptor``
instances.
"""

class ManyToManyDescriptor(ReverseManyToOneDescriptor, Generic[_F]):
field: _F # type: ignore[assignment]
# 'field' here is 'rel.field'
rel: ManyToManyRel # type: ignore[assignment]
field: ManyToManyField[Any, _M] # type: ignore[assignment]
reverse: bool
def __init__(self, rel: ManyToManyRel, reverse: bool = ...) -> None: ...
@property
def through(self) -> type[Model]: ...
def through(self) -> type[_M]: ...
@property
def related_manager_cls(self) -> type[Any]: ... # ManyRelatedManager
def related_manager_cls(self) -> type[ManyRelatedManager[Any]]: ... # type: ignore[override]

# fake
class _ForwardManyToManyManager(Generic[_T]):
def all(self) -> QuerySet: ...
class ManyRelatedManager(BaseManager[_M], Generic[_M]):
related_val: tuple[int, ...]
def add(self, *objs: _M | int, bulk: bool = ...) -> None: ...
async def aadd(self, *objs: _M | int, bulk: bool = ...) -> None: ...
def remove(self, *objs: _M | int, bulk: bool = ...) -> None: ...
async def aremove(self, *objs: _M | int, bulk: bool = ...) -> None: ...
def set(self, objs: QuerySet[_M] | Iterable[_M | int], *, bulk: bool = ..., clear: bool = ...) -> None: ...
async def aset(self, objs: QuerySet[_M] | Iterable[_M | int], *, bulk: bool = ..., clear: bool = ...) -> None: ...
def clear(self) -> None: ...
async def aclear(self) -> None: ...
def __call__(self, *, manager: str) -> ManyRelatedManager[_M]: ...

def create_forward_many_to_many_manager(superclass: type, rel: Any, reverse: Any) -> _ForwardManyToManyManager: ...
def create_forward_many_to_many_manager(
superclass: type[BaseManager[_M]], rel: ManyToManyRel, reverse: bool
) -> type[ManyRelatedManager[_M]]: ...
4 changes: 2 additions & 2 deletions django-stubs/db/models/fields/reverse_related.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,13 @@ class OneToOneRel(ManyToOneRel):
) -> None: ...

class ManyToManyRel(ForeignObjectRel):
field: ManyToManyField # type: ignore[assignment]
field: ManyToManyField[Any, Any] # type: ignore[assignment]
through: type[Model] | None
through_fields: tuple[str, str] | None
db_constraint: bool
def __init__(
self,
field: ManyToManyField,
field: ManyToManyField[Any, Any],
to: type[Model] | str,
related_name: str | None = ...,
related_query_name: str | None = ...,
Expand Down
24 changes: 23 additions & 1 deletion mypy_django_plugin/django/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,21 @@
from collections import defaultdict
from contextlib import contextmanager
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, Literal, Optional, Sequence, Set, Tuple, Type, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
Iterator,
Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)

from django.core.exceptions import FieldDoesNotExist, FieldError
from django.db import models
Expand Down Expand Up @@ -270,6 +284,14 @@ def all_registered_model_classes(self) -> Set[Type[models.Model]]:
def all_registered_model_class_fullnames(self) -> Set[str]:
return {helpers.get_class_fullname(cls) for cls in self.all_registered_model_classes}

@cached_property
def model_class_fullnames_by_label(self) -> Mapping[str, str]:
return {
klass._meta.label: helpers.get_class_fullname(klass)
for klass in self.all_registered_model_classes
if klass is not models.Model
}

def get_field_nullability(self, field: Union["Field[Any, Any]", ForeignObjectRel], method: Optional[str]) -> bool:
if method in ("values", "values_list"):
return field.null
Expand Down
1 change: 0 additions & 1 deletion mypy_django_plugin/lib/fullnames.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
FOREIGN_OBJECT_FULLNAME,
FOREIGN_KEY_FULLNAME,
ONETOONE_FIELD_FULLNAME,
MANYTOMANY_FIELD_FULLNAME,
flaeppe marked this conversation as resolved.
Show resolved Hide resolved
)
)

Expand Down
65 changes: 64 additions & 1 deletion mypy_django_plugin/lib/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
from mypy.nodes import (
GDEF,
MDEF,
AssignmentStmt,
Block,
ClassDef,
Context,
Expression,
MemberExpr,
MypyFile,
Expand All @@ -33,7 +35,8 @@
SemanticAnalyzerPluginInterface,
)
from mypy.semanal import SemanticAnalyzer
from mypy.types import AnyType, Instance, NoneTyp, TupleType, TypedDictType, TypeOfAny, UnionType
from mypy.semanal_shared import parse_bool
from mypy.types import AnyType, Instance, LiteralType, NoneTyp, TupleType, TypedDictType, TypeOfAny, UnionType
from mypy.types import Type as MypyType
from typing_extensions import TypedDict

Expand All @@ -45,12 +48,14 @@


class DjangoTypeMetadata(TypedDict, total=False):
is_abstract_model: bool
from_queryset_manager: str
reverse_managers: Dict[str, str]
baseform_bases: Dict[str, int]
manager_bases: Dict[str, int]
model_bases: Dict[str, int]
queryset_bases: Dict[str, int]
m2m_throughs: Dict[str, str]


def get_django_metadata(model_info: TypeInfo) -> DjangoTypeMetadata:
Expand Down Expand Up @@ -385,3 +390,61 @@ def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) ->
if sym is not None and isinstance(sym.node, TypeInfo):
bases = get_django_metadata_bases(sym.node, "manager_bases")
bases[fullname] = 1


def is_abstract_model(model: TypeInfo) -> bool:
if model.metaclass_type is None or model.metaclass_type.type.fullname != fullnames.MODEL_METACLASS_FULLNAME:
return False

metadata = get_django_metadata(model)
if metadata.get("is_abstract_model") is not None:
return metadata["is_abstract_model"]

meta = model.names.get("Meta")
# Check if 'abstract' is declared in this model's 'class Meta' as
# 'abstract = True' won't be inherited from a parent model.
if meta is not None and isinstance(meta.node, TypeInfo) and "abstract" in meta.node.names:
for stmt in meta.node.defn.defs.body:
if (
# abstract =
isinstance(stmt, AssignmentStmt)
and len(stmt.lvalues) == 1
and isinstance(stmt.lvalues[0], NameExpr)
and stmt.lvalues[0].name == "abstract"
):
# abstract = True (builtins.bool)
rhs_is_true = parse_bool(stmt.rvalue) is True
# abstract: Literal[True]
is_literal_true = isinstance(stmt.type, LiteralType) and stmt.type.value is True
metadata["is_abstract_model"] = rhs_is_true or is_literal_true
return metadata["is_abstract_model"]

metadata["is_abstract_model"] = False
return False


def resolve_lazy_reference(
reference: str, *, api: Union[TypeChecker, SemanticAnalyzer], django_context: "DjangoContext", ctx: Context
) -> Optional[TypeInfo]:
"""
Attempts to resolve a lazy reference(e.g. "<app_label>.<object_name>") to a
'TypeInfo' instance.
"""
if "." not in reference:
# <object_name> -- needs prefix of <app_label>. We can't implicitly solve
# what app label this should be, yet.
return None
Comment on lines +433 to +436
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a case of this at least (the referenced model is defined right below in the file), which now causes "Needs type annotation" errors


# Reference conforms to the structure of a lazy reference: '<app_label>.<object_name>'
fullname = django_context.model_class_fullnames_by_label.get(reference)
if fullname is not None:
model_info = lookup_fully_qualified_typeinfo(api, fullname)
if model_info is not None:
return model_info
elif isinstance(api, SemanticAnalyzer) and not api.final_iteration:
# Getting this far, where Django matched the reference but we still can't
# find it, we want to defer
api.defer()
else:
api.fail("Could not match lazy reference with any model", ctx)
return None
5 changes: 5 additions & 0 deletions mypy_django_plugin/transformers/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import manytomany

if TYPE_CHECKING:
from django.contrib.contenttypes.fields import GenericForeignKey
Expand Down Expand Up @@ -213,6 +214,10 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan

assert isinstance(outer_model_info, TypeInfo)

if default_return_type.type.has_base(fullnames.MANYTOMANY_FIELD_FULLNAME):
return manytomany.fill_model_args_for_many_to_many_field(
ctx=ctx, model_info=outer_model_info, default_return_type=default_return_type, django_context=django_context
)
if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES):
return fill_descriptor_types_for_related_field(ctx, django_context)

Expand Down
Loading