Skip to content

Commit

Permalink
use field annotations for values_list types (#2248)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
asottile and pre-commit-ci[bot] authored Jul 4, 2024
1 parent e0366c5 commit 085b91b
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 23 deletions.
65 changes: 46 additions & 19 deletions mypy_django_plugin/django/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,35 @@ class LookupsAreUnsupported(Exception):
pass


def _get_field_type_from_model_type_info(info: Optional[TypeInfo], field_name: str) -> Optional[Instance]:
if info is None:
return None
field_node = info.get(field_name)
if field_node is None or not isinstance(field_node.type, Instance):
return None
# Field declares a set and a get type arg. Fallback to `None` when we can't find any args
elif len(field_node.type.args) != 2:
return None
else:
return field_node.type


def _get_field_set_type_from_model_type_info(info: Optional[TypeInfo], field_name: str) -> Optional[MypyType]:
field_type = _get_field_type_from_model_type_info(info, field_name)
if field_type is not None:
return field_type.args[0]
else:
return None


def _get_field_get_type_from_model_type_info(info: Optional[TypeInfo], field_name: str) -> Optional[MypyType]:
field_type = _get_field_type_from_model_type_info(info, field_name)
if field_type is not None:
return field_type.args[1]
else:
return None


class DjangoContext:
def __init__(self, django_settings_module: str) -> None:
self.django_settings_module = django_settings_module
Expand Down Expand Up @@ -152,13 +181,13 @@ def get_field_lookup_exact_type(
) -> MypyType:
if isinstance(field, (RelatedField, ForeignObjectRel)):
related_model_cls = self.get_field_related_model_cls(field)
primary_key_field = self.get_primary_key_field(related_model_cls)
primary_key_type = self.get_field_get_type(api, primary_key_field, method="init")

rel_model_info = helpers.lookup_class_typeinfo(api, related_model_cls)
if rel_model_info is None:
return AnyType(TypeOfAny.explicit)

primary_key_field = self.get_primary_key_field(related_model_cls)
primary_key_type = self.get_field_get_type(api, rel_model_info, primary_key_field, method="init")

model_and_primary_key_type = UnionType.make_union([Instance(rel_model_info, []), primary_key_type])
return helpers.make_optional(model_and_primary_key_type)

Expand Down Expand Up @@ -200,19 +229,6 @@ def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], *, method
field_set_type = self.get_field_set_type(api, primary_key_field, method=method)
expected_types["pk"] = field_set_type

def get_field_set_type_from_model_type_info(info: Optional[TypeInfo], field_name: str) -> Optional[MypyType]:
if info is None:
return None
field_node = info.get(field_name)
if field_node is None or not isinstance(field_node.type, Instance):
return None
elif not field_node.type.args:
# Field declares a set and a get type arg. Fallback to `None` when we can't find any args
return None

set_type = field_node.type.args[0]
return set_type

model_info = helpers.lookup_class_typeinfo(api, model_cls)
for field in model_cls._meta.get_fields():
if isinstance(field, Field):
Expand All @@ -223,7 +239,7 @@ def get_field_set_type_from_model_type_info(info: Optional[TypeInfo], field_name
# Try to retrieve set type from a model's TypeInfo object and fallback to retrieving it manually
# from django-stubs own declaration. This is to align with the setter types declared for
# assignment.
field_set_type = get_field_set_type_from_model_type_info(
field_set_type = _get_field_set_type_from_model_type_info(
model_info, field_name
) or self.get_field_set_type(api, field, method=method)
expected_types[field_name] = field_set_type
Expand Down Expand Up @@ -340,20 +356,31 @@ def get_field_set_type(
return field_set_type

def get_field_get_type(
self, api: TypeChecker, field: Union["Field[Any, Any]", ForeignObjectRel], *, method: str
self,
api: TypeChecker,
model_info: Optional[TypeInfo],
field: Union["Field[Any, Any]", ForeignObjectRel],
*,
method: str,
) -> MypyType:
"""Get a type of __get__ for this specific Django field."""
if isinstance(field, Field):
get_type = _get_field_get_type_from_model_type_info(model_info, field.attname)
if get_type is not None:
return get_type

field_info = helpers.lookup_class_typeinfo(api, field.__class__)
if field_info is None:
return AnyType(TypeOfAny.unannotated)

is_nullable = self.get_field_nullability(field, method)
if isinstance(field, RelatedField):
related_model_cls = self.get_field_related_model_cls(field)
rel_model_info = helpers.lookup_class_typeinfo(api, related_model_cls)

if method in ("values", "values_list"):
primary_key_field = self.get_primary_key_field(related_model_cls)
return self.get_field_get_type(api, primary_key_field, method=method)
return self.get_field_get_type(api, rel_model_info, primary_key_field, method=method)

model_info = helpers.lookup_class_typeinfo(api, related_model_cls)
if model_info is None:
Expand Down
13 changes: 9 additions & 4 deletions mypy_django_plugin/transformers/querysets.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,12 @@ def get_field_type_from_lookup(
elif (isinstance(lookup_field, RelatedField) and lookup_field.column == lookup) or isinstance(
lookup_field, ForeignObjectRel
):
related_model_cls = django_context.get_field_related_model_cls(lookup_field)
lookup_field = django_context.get_primary_key_field(related_model_cls)
model_cls = django_context.get_field_related_model_cls(lookup_field)
lookup_field = django_context.get_primary_key_field(model_cls)

field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx), lookup_field, method=method)
api = helpers.get_typechecker_api(ctx)
model_info = helpers.lookup_class_typeinfo(api, model_cls)
field_get_type = django_context.get_field_get_type(api, model_info, lookup_field, method=method)
return field_get_type


Expand All @@ -87,6 +89,7 @@ def get_values_list_row_type(
return AnyType(TypeOfAny.from_error)

typechecker_api = helpers.get_typechecker_api(ctx)
model_info = helpers.lookup_class_typeinfo(typechecker_api, model_cls)
if len(field_lookups) == 0:
if flat:
primary_key_field = django_context.get_primary_key_field(model_cls)
Expand All @@ -98,7 +101,9 @@ def get_values_list_row_type(
elif named:
column_types: OrderedDict[str, MypyType] = OrderedDict()
for field in django_context.get_model_fields(model_cls):
column_type = django_context.get_field_get_type(typechecker_api, field, method="values_list")
column_type = django_context.get_field_get_type(
typechecker_api, model_info, field, method="values_list"
)
column_types[field.attname] = column_type
if is_annotated:
# Return a NamedTuple with a fallback so that it's possible to access any field
Expand Down
20 changes: 20 additions & 0 deletions tests/typecheck/managers/querysets/test_values_list.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,26 @@
name = models.CharField(max_length=100)
age = models.IntegerField()
- case: values_list_types_are_field_types
main: |
from myapp.models import Concrete
ret = list(Concrete.objects.values_list('id', 'data'))
reveal_type(ret) # N: Revealed type is "builtins.list[Tuple[builtins.int, builtins.dict[builtins.str, builtins.str]]]"
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from __future__ import annotations
from django.db import models
class JSONField(models.TextField): pass # incomplete
class Concrete(models.Model):
id = models.IntegerField()
data: models.Field[dict[str, str], dict[str, str]] = JSONField()
- case: values_list_supports_queryset_methods
main: |
from myapp.models import MyUser
Expand Down

0 comments on commit 085b91b

Please sign in to comment.