diff --git a/README.md b/README.md index c174c0e1f..dd4d26ca0 100644 --- a/README.md +++ b/README.md @@ -428,23 +428,34 @@ Use this setting on your own risk, because it can hide valid errors. > This require type generic support, see this section to enable it. -Django `models.Field` (and subclasses) are generic types with two parameters: +Django `models.Field` (and subclasses) are generic types with three parameters: - `_ST`: type that can be used when setting a value - `_GT`: type that will be returned when getting a value +- `_NT`: `Literal[True]` or `Literal[False]`, tracking the field's `null=...` flag so + `None` can be added to `_GT`/ `_ST` automatically when the field is nullable When you create a subclass, you have two options depending on how strict you want the type to be for consumers of your custom field. +> [!IMPORTANT] +> Each `TypeVar` you forward to `models.Field` (or one of its subclasses) **must** +> declare a `default=` value (PEP 696). Without a default, mypy will not be able +> to instantiate your field without explicit type arguments and the plugin will +> not be able to infer the right types for your model attributes. + 1. Generic subclass: ```python -from typing import TypeVar, reveal_type +from typing import Literal, reveal_type +from typing_extensions import TypeVar # for `default=` (PEP 696) from django.db import models +from django.db.models.expressions import Combinable -_ST = TypeVar("_ST", contravariant=True) -_GT = TypeVar("_GT", covariant=True) +_ST = TypeVar("_ST", contravariant=True, default=float | int | str | Combinable) +_GT = TypeVar("_GT", covariant=True, default=int) +_NT = TypeVar("_NT", Literal[True], Literal[False], default=Literal[False]) -class MyIntegerField(models.IntegerField[_ST, _GT]): +class MyIntegerField(models.IntegerField[_ST, _GT, _NT]): ... class User(models.Model): @@ -458,12 +469,14 @@ User().my_field = "12" # OK (because Django IntegerField allows str and will tr 2. Non-generic subclass (more strict): ```python -from typing import reveal_type +from typing import Literal, reveal_type, TypeVar from django.db import models +_NT = TypeVar("_NT", Literal[True], Literal[False], default=Literal[False]) + # This is a non-generic subclass being very explicit # that it expects only int when setting values. -class MyStrictIntegerField(models.IntegerField[int, int]): +class MyStrictIntegerField(models.IntegerField[int, int, _NT]): ... class User(models.Model): @@ -476,6 +489,34 @@ User().my_field = "12" # E: Incompatible types in assignment (expression has typ See mypy section on [generic classes subclasses](https://mypy.readthedocs.io/en/stable/generics.html#defining-subclasses-of-generic-classes). +#### Overriding `__init__` + +If you override `__init__`, expose `null: _NT` in the signature so type +checkers can track the `null=` flag — a plain `*args, **kwargs` passthrough +loses it and `_NT` falls back to `Literal[False]`: + +```python +from typing import Any, Literal +from typing_extensions import TypeVar, assert_type +from django.db import models + +_ST = TypeVar("_ST", contravariant=True, default=float | int | str) +_GT = TypeVar("_GT", covariant=True, default=int) +_NT = TypeVar("_NT", Literal[True], Literal[False], default=Literal[False]) + +class MyIntegerField(models.IntegerField[_ST, _GT, _NT]): + def __init__(self, *args: Any, null: _NT = False, **kwargs: Any) -> None: # type: ignore[assignment] + kwargs["null"] = null + super().__init__(*args, **kwargs) + +class User(models.Model): + custom_int = MyIntegerField(null=False) + custom_int_nullable = MyIntegerField(null=True) + +assert_type(User().custom_int, int) +assert_type(User().custom_int_nullable, int | None) +``` + ## Related projects - [`awesome-python-typing`](https://github.com/typeddjango/awesome-python-typing) - Awesome list of all typing-related things in Python. diff --git a/django-stubs/contrib/admin/filters.pyi b/django-stubs/contrib/admin/filters.pyi index b63cb4453..49bc65b8a 100644 --- a/django-stubs/contrib/admin/filters.pyi +++ b/django-stubs/contrib/admin/filters.pyi @@ -51,11 +51,11 @@ class SimpleListFilter(FacetsMixin, ListFilter): class FieldListFilter(FacetsMixin, ListFilter): list_separator: ClassVar[str] - field: Field + field: Field[Any, Any, Any] field_path: str def __init__( self, - field: Field, + field: Field[Any, Any, Any], request: HttpRequest, params: dict[str, list[str]], model: type[Model], @@ -64,12 +64,15 @@ class FieldListFilter(FacetsMixin, ListFilter): ) -> None: ... @classmethod def register( - cls, test: Callable[[Field], Any], list_filter_class: type[FieldListFilter], take_priority: bool = ... + cls, + test: Callable[[Field[Any, Any, Any]], Any], + list_filter_class: type[FieldListFilter], + take_priority: bool = ..., ) -> None: ... @classmethod def create( cls, - field: Field, + field: Field[Any, Any, Any], request: HttpRequest, params: dict[str, list[str]], model: type[Model], diff --git a/django-stubs/contrib/admin/options.pyi b/django-stubs/contrib/admin/options.pyi index 2127d7bf5..f9caed640 100644 --- a/django-stubs/contrib/admin/options.pyi +++ b/django-stubs/contrib/admin/options.pyi @@ -95,7 +95,7 @@ class BaseModelAdmin(Generic[_ModelT], metaclass=MediaDefiningClass): filter_horizontal: ClassVar[_ListOrTuple[str]] radio_fields: ClassVar[Mapping[str, _Direction]] prepopulated_fields: ClassVar[dict[str, Sequence[str]]] - formfield_overrides: ClassVar[Mapping[type[Field], Mapping[str, Any]]] + formfield_overrides: ClassVar[Mapping[type[Field[Any, Any, Any]], Mapping[str, Any]]] readonly_fields: ClassVar[_ListOrTuple[str]] ordering: ClassVar[_ListOrTuple[_OrderByFieldName] | None] sortable_by: ClassVar[_ListOrTuple[str] | None] @@ -106,9 +106,11 @@ class BaseModelAdmin(Generic[_ModelT], metaclass=MediaDefiningClass): admin_site: AdminSite def __init__(self) -> None: ... def check(self, **kwargs: Any) -> list[CheckMessage]: ... - def formfield_for_dbfield(self, db_field: Field, request: HttpRequest, **kwargs: Any) -> FormField | None: ... + def formfield_for_dbfield( + self, db_field: Field[Any, Any, Any], request: HttpRequest, **kwargs: Any + ) -> FormField | None: ... def formfield_for_choice_field( - self, db_field: Field, request: HttpRequest, **kwargs: Any + self, db_field: Field[Any, Any, Any], request: HttpRequest, **kwargs: Any ) -> TypedChoiceField | None: ... def get_field_queryset(self, db: str | None, db_field: RelatedField, request: HttpRequest) -> QuerySet | None: ... def formfield_for_foreignkey( diff --git a/django-stubs/contrib/admin/utils.pyi b/django-stubs/contrib/admin/utils.pyi index b8eedc3c4..44d3deb55 100644 --- a/django-stubs/contrib/admin/utils.pyi +++ b/django-stubs/contrib/admin/utils.pyi @@ -65,7 +65,10 @@ class NestedObjects(Collector): ) -> None: ... @override def related_objects( - self, related_model: type[Model], related_fields: Iterable[Field], objs: _IndexableCollection[Model] + self, + related_model: type[Model], + related_fields: Iterable[Field[Any, Any, Any]], + objs: _IndexableCollection[Model], ) -> QuerySet[Model]: ... @overload def nested(self, format_callback: None = None) -> list[Any]: ... @@ -83,7 +86,7 @@ def model_format_dict(obj: Model | type[Model] | QuerySet | Options[Model]) -> _ def model_ngettext(obj: Options | QuerySet, n: int | None = ...) -> str: ... def lookup_field( name: Callable | str, obj: Model, model_admin: BaseModelAdmin | None = ... -) -> tuple[Field | None, str | None, Any]: ... +) -> tuple[Field[Any, Any, Any] | None, str | None, Any]: ... @overload def label_for_field( name: Callable | str, @@ -101,14 +104,16 @@ def label_for_field( form: BaseForm | None = ..., ) -> str: ... def help_text_for_field(name: str, model: type[Model]) -> str: ... -def display_for_field(value: Any, field: Field, empty_value_display: str, avoid_link: bool = False) -> str: ... +def display_for_field( + value: Any, field: Field[Any, Any, Any], empty_value_display: str, avoid_link: bool = False +) -> str: ... def display_for_value(value: Any, empty_value_display: str, boolean: bool = ...) -> str: ... class NotRelationField(Exception): ... -def get_model_from_relation(field: Field | reverse_related.ForeignObjectRel) -> type[Model]: ... +def get_model_from_relation(field: Field[Any, Any, Any] | reverse_related.ForeignObjectRel) -> type[Model]: ... def reverse_field_path(model: type[Model], path: str) -> tuple[type[Model], str]: ... -def get_fields_from_path(model: type[Model], path: str) -> list[Field]: ... +def get_fields_from_path(model: type[Model], path: str) -> list[Field[Any, Any, Any]]: ... def construct_change_message( form: Form, formsets: Iterable[BaseFormSet], add: bool ) -> list[dict[str, dict[str, list[str]]]]: ... diff --git a/django-stubs/contrib/admindocs/views.pyi b/django-stubs/contrib/admindocs/views.pyi index f015bc5e4..2173f0d33 100644 --- a/django-stubs/contrib/admindocs/views.pyi +++ b/django-stubs/contrib/admindocs/views.pyi @@ -24,7 +24,7 @@ class ModelDetailView(BaseAdminDocsView): ... class TemplateDetailView(BaseAdminDocsView): ... def get_return_data_type(func_name: Any) -> str: ... -def get_readable_field_data_type(field: Field | str) -> str: ... +def get_readable_field_data_type(field: Field[Any, Any, Any] | str) -> str: ... def extract_views_from_urlpatterns( urlpatterns: Iterable[_AnyURL], base: str = ..., namespace: str | None = ... ) -> list[tuple[Callable, Pattern[str], str | None, str | None]]: ... diff --git a/django-stubs/contrib/auth/base_user.pyi b/django-stubs/contrib/auth/base_user.pyi index 6336d1366..3b5a2aa7b 100644 --- a/django-stubs/contrib/auth/base_user.pyi +++ b/django-stubs/contrib/auth/base_user.pyi @@ -3,7 +3,6 @@ from typing import Any, ClassVar, Literal, overload from django.db import models from django.db.models.base import Model -from django.db.models.expressions import Combinable from django.db.models.fields import BooleanField from typing_extensions import TypeVar @@ -23,7 +22,7 @@ class AbstractBaseUser(models.Model): password = models.CharField(max_length=128) last_login = models.DateTimeField(blank=True, null=True) - is_active: bool | BooleanField[bool | Combinable, bool] + is_active: bool | BooleanField[bool, bool] backend: str # Set dynamically by authenticate(), used by login() def get_username(self) -> str: ... diff --git a/django-stubs/contrib/auth/management/commands/createsuperuser.pyi b/django-stubs/contrib/auth/management/commands/createsuperuser.pyi index 4bc27c592..47f90446d 100644 --- a/django-stubs/contrib/auth/management/commands/createsuperuser.pyi +++ b/django-stubs/contrib/auth/management/commands/createsuperuser.pyi @@ -12,8 +12,8 @@ PASSWORD_FIELD: str class Command(BaseCommand): UserModel: type[AbstractBaseUser] - username_field: Field + username_field: Field[Any, Any, Any] stdin: Any - def get_input_data(self, field: Field, message: str, default: str | None = ...) -> str | None: ... + def get_input_data(self, field: Field[Any, Any, Any], message: str, default: str | None = ...) -> str | None: ... @cached_property def username_is_unique(self) -> bool: ... diff --git a/django-stubs/contrib/contenttypes/fields.pyi b/django-stubs/contrib/contenttypes/fields.pyi index 0156aba7e..93de5d66e 100644 --- a/django-stubs/contrib/contenttypes/fields.pyi +++ b/django-stubs/contrib/contenttypes/fields.pyi @@ -4,8 +4,7 @@ from typing import Any from django.contrib.contenttypes.models import ContentType from django.core.checks.messages import CheckMessage from django.db.models.base import Model -from django.db.models.expressions import Combinable -from django.db.models.fields import Field, _AllLimitChoicesTo +from django.db.models.fields import _GT, _NT, _ST, Field, _AllLimitChoicesTo from django.db.models.fields.mixins import FieldCacheMixin from django.db.models.fields.related import ForeignObject from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor @@ -16,11 +15,7 @@ from django.db.models.sql.where import WhereNode from django.utils.functional import cached_property from typing_extensions import override -class GenericForeignKey(FieldCacheMixin, Field): - # django-stubs implementation only fields - _pyi_private_set_type: Any | Combinable - _pyi_private_get_type: Any - # attributes +class GenericForeignKey(FieldCacheMixin, Field[_ST, _GT, _NT]): hidden: bool is_relation: bool many_to_many: bool @@ -74,7 +69,7 @@ class GenericRel(ForeignObjectRel): limit_choices_to: _AllLimitChoicesTo | None = None, ) -> None: ... -class GenericRelation(ForeignObject[Any, Any]): +class GenericRelation(ForeignObject[_ST, _GT, _NT]): rel_class: type[GenericRel] mti_inherited: bool object_id_field_name: str @@ -91,7 +86,7 @@ class GenericRelation(ForeignObject[Any, Any]): **kwargs: Any, ) -> None: ... @override - def resolve_related_fields(self) -> list[tuple[Field, Field]]: ... + def resolve_related_fields(self) -> list[tuple[Field[Any, Any, Any], Field[Any, Any, Any]]]: ... @override def get_local_related_value(self, instance: Model) -> tuple[Any, ...]: ... @override diff --git a/django-stubs/contrib/gis/admin/options.pyi b/django-stubs/contrib/gis/admin/options.pyi index 19dadf275..8f8fc4237 100644 --- a/django-stubs/contrib/gis/admin/options.pyi +++ b/django-stubs/contrib/gis/admin/options.pyi @@ -13,6 +13,8 @@ _ModelT = TypeVar("_ModelT", bound=Model) class GeoModelAdminMixin: gis_widget: type[OSMWidget] gis_widget_kwargs: dict[str, Any] - def formfield_for_dbfield(self, db_field: Field, request: HttpRequest, **kwargs: Any) -> FormField | None: ... + def formfield_for_dbfield( + self, db_field: Field[Any, Any, Any], request: HttpRequest, **kwargs: Any + ) -> FormField | None: ... class GISModelAdmin(GeoModelAdminMixin, ModelAdmin[_ModelT]): ... diff --git a/django-stubs/contrib/gis/db/backends/mysql/schema.pyi b/django-stubs/contrib/gis/db/backends/mysql/schema.pyi index 2e7384a59..070f7212a 100644 --- a/django-stubs/contrib/gis/db/backends/mysql/schema.pyi +++ b/django-stubs/contrib/gis/db/backends/mysql/schema.pyi @@ -11,14 +11,14 @@ logger: Logger class MySQLGISSchemaEditor(DatabaseSchemaEditor): sql_add_spatial_index: str @override - def skip_default(self, field: Field) -> bool: ... + def skip_default(self, field: Field[Any, Any, Any]) -> bool: ... @override def column_sql( - self, model: type[Model], field: Field, include_default: bool = ... + self, model: type[Model], field: Field[Any, Any, Any], include_default: bool = ... ) -> tuple[None, None] | tuple[str, list[Any]]: ... @override def create_model(self, model: type[Model]) -> None: ... @override - def add_field(self, model: type[Model], field: Field) -> None: ... + def add_field(self, model: type[Model], field: Field[Any, Any, Any]) -> None: ... @override - def remove_field(self, model: type[Model], field: Field) -> None: ... + def remove_field(self, model: type[Model], field: Field[Any, Any, Any]) -> None: ... diff --git a/django-stubs/contrib/gis/db/backends/oracle/schema.pyi b/django-stubs/contrib/gis/db/backends/oracle/schema.pyi index a30c90b36..467dbd160 100644 --- a/django-stubs/contrib/gis/db/backends/oracle/schema.pyi +++ b/django-stubs/contrib/gis/db/backends/oracle/schema.pyi @@ -15,14 +15,14 @@ class OracleGISSchemaEditor(DatabaseSchemaEditor): def geo_quote_name(self, name: Any) -> Any: ... @override def column_sql( - self, model: type[Model], field: Field, include_default: bool = ... + self, model: type[Model], field: Field[Any, Any, Any], include_default: bool = ... ) -> tuple[None, None] | tuple[str, list[Any]]: ... @override def create_model(self, model: type[Model]) -> None: ... @override def delete_model(self, model: type[Model]) -> None: ... @override - def add_field(self, model: type[Model], field: Field) -> None: ... + def add_field(self, model: type[Model], field: Field[Any, Any, Any]) -> None: ... @override - def remove_field(self, model: type[Model], field: Field) -> None: ... + def remove_field(self, model: type[Model], field: Field[Any, Any, Any]) -> None: ... def run_geometry_sql(self) -> None: ... diff --git a/django-stubs/contrib/gis/db/backends/spatialite/schema.pyi b/django-stubs/contrib/gis/db/backends/spatialite/schema.pyi index c15f949c8..c32d083de 100644 --- a/django-stubs/contrib/gis/db/backends/spatialite/schema.pyi +++ b/django-stubs/contrib/gis/db/backends/spatialite/schema.pyi @@ -19,17 +19,17 @@ class SpatialiteSchemaEditor(DatabaseSchemaEditor): def geo_quote_name(self, name: Any) -> Any: ... @override def column_sql( - self, model: type[Model], field: Field, include_default: bool = False + self, model: type[Model], field: Field[Any, Any, Any], include_default: bool = False ) -> tuple[None, None] | tuple[str, list[Any]]: ... - def remove_geometry_metadata(self, model: type[Model], field: Field) -> None: ... + def remove_geometry_metadata(self, model: type[Model], field: Field[Any, Any, Any]) -> None: ... @override def create_model(self, model: type[Model]) -> None: ... @override def delete_model(self, model: type[Model], **kwargs: Any) -> None: ... # type: ignore[override] @override - def add_field(self, model: type[Model], field: Field) -> None: ... + def add_field(self, model: type[Model], field: Field[Any, Any, Any]) -> None: ... @override - def remove_field(self, model: type[Model], field: Field) -> None: ... + def remove_field(self, model: type[Model], field: Field[Any, Any, Any]) -> None: ... @override def alter_db_table( self, diff --git a/django-stubs/contrib/gis/db/models/fields.pyi b/django-stubs/contrib/gis/db/models/fields.pyi index 08f542006..0747fe05a 100644 --- a/django-stubs/contrib/gis/db/models/fields.pyi +++ b/django-stubs/contrib/gis/db/models/fields.pyi @@ -14,17 +14,12 @@ from django.contrib.gis.geos import ( ) from django.core.validators import _ValidatorCallable from django.db.models import Model -from django.db.models.expressions import Combinable, Expression -from django.db.models.fields import NOT_PROVIDED, Field, _ErrorMessagesMapping +from django.db.models.expressions import Expression +from django.db.models.fields import _GT, _NT, _ST, NOT_PROVIDED, Field, _ErrorMessagesMapping from django.utils.choices import _Choices from django.utils.functional import _StrOrPromise from typing_extensions import TypeVar, override -# __set__ value type -_ST = TypeVar("_ST") -# __get__ return type -_GT = TypeVar("_GT") - class SRIDCacheEntry(NamedTuple): units: Any units_name: str @@ -33,7 +28,7 @@ class SRIDCacheEntry(NamedTuple): def get_srid_info(srid: int, connection: Any) -> SRIDCacheEntry: ... -class BaseSpatialField(Field[_ST, _GT]): +class BaseSpatialField(Field[_ST, _GT, _NT]): form_class: type[forms.GeometryField] geom_type: str geom_class: type[GEOSGeometry] | None @@ -51,7 +46,7 @@ class BaseSpatialField(Field[_ST, _GT]): max_length: int | None = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., default: Any = ..., db_default: type[NOT_PROVIDED] | Expression | _ST = ..., @@ -83,7 +78,7 @@ class BaseSpatialField(Field[_ST, _GT]): @override def get_prep_value(self, value: Any) -> Any: ... -class GeometryField(BaseSpatialField[_ST, _GT]): +class GeometryField(BaseSpatialField[_ST, _GT, _NT]): dim: int def __init__( self, @@ -100,7 +95,7 @@ class GeometryField(BaseSpatialField[_ST, _GT]): max_length: int | None = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., default: Any = ..., db_default: type[NOT_PROVIDED] | Expression | _ST = ..., @@ -130,67 +125,74 @@ class GeometryField(BaseSpatialField[_ST, _GT]): **kwargs: Any, ) -> forms.GeometryField: ... -class PointField(GeometryField[_ST, _GT]): - _pyi_private_set_type: Point | Combinable - _pyi_private_get_type: Point +_ST_Point = TypeVar("_ST_Point", default=Point) +_GT_Point = TypeVar("_GT_Point", default=Point) + +class PointField(GeometryField[_ST_Point, _GT_Point, _NT]): _pyi_lookup_exact_type: Point geom_class: type[Point] form_class: type[forms.PointField] -class LineStringField(GeometryField[_ST, _GT]): - _pyi_private_set_type: LineString | Combinable - _pyi_private_get_type: LineString +_ST_LineString = TypeVar("_ST_LineString", default=LineString) +_GT_LineString = TypeVar("_GT_LineString", default=LineString) + +class LineStringField(GeometryField[_ST_LineString, _GT_LineString, _NT]): _pyi_lookup_exact_type: LineString geom_class: type[LineString] form_class: type[forms.LineStringField] -class PolygonField(GeometryField[_ST, _GT]): - _pyi_private_set_type: Polygon | Combinable - _pyi_private_get_type: Polygon +_ST_Polygon = TypeVar("_ST_Polygon", default=Polygon) +_GT_Polygon = TypeVar("_GT_Polygon", default=Polygon) + +class PolygonField(GeometryField[_ST_Polygon, _GT_Polygon, _NT]): _pyi_lookup_exact_type: Polygon geom_class: type[Polygon] form_class: type[forms.PolygonField] -class MultiPointField(GeometryField[_ST, _GT]): - _pyi_private_set_type: MultiPoint | Combinable - _pyi_private_get_type: MultiPoint +_ST_MultiPoint = TypeVar("_ST_MultiPoint", default=MultiPoint) +_GT_MultiPoint = TypeVar("_GT_MultiPoint", default=MultiPoint) + +class MultiPointField(GeometryField[_ST_MultiPoint, _GT_MultiPoint, _NT]): _pyi_lookup_exact_type: MultiPoint geom_class: type[MultiPoint] form_class: type[forms.MultiPointField] -class MultiLineStringField(GeometryField[_ST, _GT]): - _pyi_private_set_type: MultiLineString | Combinable - _pyi_private_get_type: MultiLineString +_ST_MultiLineString = TypeVar("_ST_MultiLineString", default=MultiLineString) +_GT_MultiLineString = TypeVar("_GT_MultiLineString", default=MultiLineString) + +class MultiLineStringField(GeometryField[_ST_MultiLineString, _GT_MultiLineString, _NT]): _pyi_lookup_exact_type: MultiLineString geom_class: type[MultiLineString] form_class: type[forms.MultiLineStringField] -class MultiPolygonField(GeometryField[_ST, _GT]): - _pyi_private_set_type: MultiPolygon | Combinable - _pyi_private_get_type: MultiPolygon +_ST_MultiPolygon = TypeVar("_ST_MultiPolygon", default=MultiPolygon) +_GT_MultiPolygon = TypeVar("_GT_MultiPolygon", default=MultiPolygon) + +class MultiPolygonField(GeometryField[_ST_MultiPolygon, _GT_MultiPolygon, _NT]): _pyi_lookup_exact_type: MultiPolygon geom_class: type[MultiPolygon] form_class: type[forms.MultiPolygonField] -class GeometryCollectionField(GeometryField[_ST, _GT]): - _pyi_private_set_type: GeometryCollection | Combinable - _pyi_private_get_type: GeometryCollection +_ST_GeometryCollection = TypeVar("_ST_GeometryCollection", default=GeometryCollection) +_GT_GeometryCollection = TypeVar("_GT_GeometryCollection", default=GeometryCollection) + +class GeometryCollectionField(GeometryField[_ST_GeometryCollection, _GT_GeometryCollection, _NT]): _pyi_lookup_exact_type: GeometryCollection geom_class: type[GeometryCollection] form_class: type[forms.GeometryCollectionField] -class ExtentField(Field[Any, Any]): +class ExtentField(Field[Any, Any, _NT]): @override def get_internal_type(self) -> str: ... -class RasterField(BaseSpatialField): +class RasterField(BaseSpatialField[_ST, _GT, _NT]): @override def db_type(self, connection: Any) -> Any: ... def from_db_value(self, value: Any, expression: Any, connection: Any) -> Any: ... diff --git a/django-stubs/contrib/gis/utils/layermapping.pyi b/django-stubs/contrib/gis/utils/layermapping.pyi index b78c61b2e..b7cf55615 100644 --- a/django-stubs/contrib/gis/utils/layermapping.pyi +++ b/django-stubs/contrib/gis/utils/layermapping.pyi @@ -20,7 +20,7 @@ class _Writer(Protocol): class LayerMapping: MULTI_TYPES: dict[int, OGRGeomType] - FIELD_TYPES: dict[Field, OGRField | tuple[OGRField, ...]] + FIELD_TYPES: dict[Field[Any, Any, Any], OGRField | tuple[OGRField, ...]] ds: DataSource layer: Layer using: str @@ -49,7 +49,7 @@ class LayerMapping: ) -> None: ... def check_fid_range(self, fid_range: Any) -> Any: ... geom_field: str - fields: dict[str, Field] + fields: dict[str, Field[Any, Any, Any]] coord_dim: int def check_layer(self) -> Any: ... def check_srs(self, source_srs: Any) -> Any: ... diff --git a/django-stubs/contrib/postgres/fields/array.pyi b/django-stubs/contrib/postgres/fields/array.pyi index 9a3aa9f75..dbbd67403 100644 --- a/django-stubs/contrib/postgres/fields/array.pyi +++ b/django-stubs/contrib/postgres/fields/array.pyi @@ -8,31 +8,28 @@ from django.core.validators import _ValidatorCallable from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import Field from django.db.models.expressions import Combinable, Expression -from django.db.models.fields import NOT_PROVIDED, _ErrorMessagesDict, _ErrorMessagesMapping +from django.db.models.fields import _NT, NOT_PROVIDED, _ErrorMessagesDict, _ErrorMessagesMapping from django.db.models.fields.mixins import CheckFieldDefaultMixin from django.db.models.lookups import Transform from django.utils.choices import _Choices from django.utils.functional import _StrOrPromise from typing_extensions import TypeVar, override -# __set__ value type -_ST = TypeVar("_ST") -# __get__ return type -_GT = TypeVar("_GT") - -class ArrayField(CheckPostgresInstalledMixin, CheckFieldDefaultMixin, Field[_ST, _GT]): - _pyi_private_set_type: Sequence[Any] | Combinable - _pyi_private_get_type: list[Any] +_ST_Array = TypeVar("_ST_Array", contravariant=True, default=Any) +_GT_Array = TypeVar("_GT_Array", covariant=True, default=Any) +class ArrayField( + CheckPostgresInstalledMixin, CheckFieldDefaultMixin, Field[Sequence[_ST_Array] | Combinable, list[_GT_Array], _NT] +): empty_strings_allowed: bool default_error_messages: ClassVar[_ErrorMessagesDict] - base_field: Field + base_field: Field[_ST_Array, _GT_Array, Any] size: int | None default_validators: Sequence[_ValidatorCallable] from_db_value: Any def __init__( self, - base_field: Field, + base_field: Field[_ST_Array, _GT_Array, Any], size: int | None = None, *, verbose_name: _StrOrPromise | None = ..., @@ -41,10 +38,10 @@ class ArrayField(CheckPostgresInstalledMixin, CheckFieldDefaultMixin, Field[_ST, max_length: int | None = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., default: Any = ..., - db_default: type[NOT_PROVIDED] | Expression | _ST = ..., + db_default: type[NOT_PROVIDED] | Expression | list[_ST_Array] = ..., editable: bool = ..., auto_created: bool = ..., serialize: bool = ..., diff --git a/django-stubs/contrib/postgres/fields/hstore.pyi b/django-stubs/contrib/postgres/fields/hstore.pyi index 340389951..07c72c993 100644 --- a/django-stubs/contrib/postgres/fields/hstore.pyi +++ b/django-stubs/contrib/postgres/fields/hstore.pyi @@ -8,7 +8,7 @@ from django.db.models.fields.mixins import CheckFieldDefaultMixin from django.db.models.sql.compiler import SQLCompiler, _AsSqlType from typing_extensions import override -class HStoreField(CheckPostgresInstalledMixin, CheckFieldDefaultMixin, Field): +class HStoreField(CheckPostgresInstalledMixin, CheckFieldDefaultMixin, Field[Any, Any, Any]): @override def get_transform(self, name: str) -> Any: ... @override diff --git a/django-stubs/contrib/postgres/fields/ranges.pyi b/django-stubs/contrib/postgres/fields/ranges.pyi index ab553aa54..827c7dfe2 100644 --- a/django-stubs/contrib/postgres/fields/ranges.pyi +++ b/django-stubs/contrib/postgres/fields/ranges.pyi @@ -5,9 +5,10 @@ from django.contrib.postgres import forms from django.contrib.postgres.utils import CheckPostgresInstalledMixin from django.db import models from django.db.backends.base.base import BaseDatabaseWrapper +from django.db.models.fields import _NT, _ST from django.db.models.lookups import PostgresOperatorLookup from django.db.models.sql.compiler import SQLCompiler, _AsSqlType -from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range # type: ignore[import-untyped] +from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange # type: ignore[import-untyped] from typing_extensions import TypeVar, override class RangeBoundary(models.Expression): @@ -27,9 +28,9 @@ class RangeOperators: NOT_GT: Literal["&<"] ADJACENT_TO: Literal["-|-"] -_RangeT = TypeVar("_RangeT", bound=Range[Any]) +_RangeT = TypeVar("_RangeT", covariant=True, default=Any) -class RangeField(CheckPostgresInstalledMixin, models.Field[Any, _RangeT]): +class RangeField(CheckPostgresInstalledMixin, models.Field[_ST, _RangeT, _NT]): empty_strings_allowed: bool base_field: type[models.Field] range_type: type[_RangeT] @@ -43,27 +44,27 @@ class RangeField(CheckPostgresInstalledMixin, models.Field[Any, _RangeT]): @override def formfield(self, **kwargs: Any) -> Any: ... # type: ignore[override] -class ContinuousRangeField(RangeField[_RangeT]): +class ContinuousRangeField(RangeField[_ST, _RangeT, _NT]): default_bounds: str def __init__(self, *args: Any, default_bounds: str = "[)", **kwargs: Any) -> None: ... -class IntegerRangeField(RangeField[NumericRange]): +class IntegerRangeField(RangeField[_ST, NumericRange, _NT]): base_field: type[models.IntegerField] form_field: type[forms.IntegerRangeField] -class BigIntegerRangeField(RangeField[NumericRange]): +class BigIntegerRangeField(RangeField[_ST, NumericRange, _NT]): base_field: type[models.BigIntegerField] form_field: type[forms.IntegerRangeField] -class DecimalRangeField(ContinuousRangeField[NumericRange]): +class DecimalRangeField(ContinuousRangeField[_ST, NumericRange, _NT]): base_field: type[models.DecimalField] form_field: type[forms.DecimalRangeField] -class DateTimeRangeField(ContinuousRangeField[DateTimeTZRange]): +class DateTimeRangeField(ContinuousRangeField[_ST, DateTimeTZRange, _NT]): base_field: type[models.DateTimeField] form_field: type[forms.DateTimeRangeField] -class DateRangeField(RangeField[DateRange]): +class DateRangeField(RangeField[_ST, DateRange, _NT]): base_field: type[models.DateField] form_field: type[forms.DateRangeField] diff --git a/django-stubs/contrib/postgres/search.pyi b/django-stubs/contrib/postgres/search.pyi index f8e82c56c..91586c6b9 100644 --- a/django-stubs/contrib/postgres/search.pyi +++ b/django-stubs/contrib/postgres/search.pyi @@ -31,8 +31,8 @@ class SearchVectorExact(Lookup): @override def as_sql(self, qn: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ... -class SearchVectorField(CheckPostgresInstalledMixin, Field): ... -class SearchQueryField(CheckPostgresInstalledMixin, Field): ... +class SearchVectorField(CheckPostgresInstalledMixin, Field[Any, Any, Any]): ... +class SearchQueryField(CheckPostgresInstalledMixin, Field[Any, Any, Any]): ... class SearchConfig(Expression): config: _Expression | None @@ -67,7 +67,7 @@ class CombinedSearchVector(SearchVectorCombinable, CombinedExpression): connector: str, rhs: Combinable, config: _Expression | None, - output_field: Field | None = None, + output_field: Field[Any, Any, Any] | None = None, ) -> None: ... class SearchQueryCombinable: @@ -84,7 +84,7 @@ class SearchQuery(SearchQueryCombinable, Func): # type: ignore[misc] def __init__( self, value: _Expression, - output_field: Field | None = None, + output_field: Field[Any, Any, Any] | None = None, *, config: _Expression | None = None, invert: bool = False, @@ -108,7 +108,7 @@ class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression): # type: i connector: str, rhs: Combinable, config: _Expression | None, - output_field: Field | None = None, + output_field: Field[Any, Any, Any] | None = None, ) -> None: ... class SearchRank(Func): diff --git a/django-stubs/contrib/sessions/backends/db.pyi b/django-stubs/contrib/sessions/backends/db.pyi index 2db64bf67..0678d2ef6 100644 --- a/django-stubs/contrib/sessions/backends/db.pyi +++ b/django-stubs/contrib/sessions/backends/db.pyi @@ -5,13 +5,13 @@ from django.contrib.sessions.base_session import AbstractBaseSession from django.utils.functional import cached_property from typing_extensions import TypeVar -_ST = TypeVar("_ST", bound=AbstractBaseSession, default=AbstractBaseSession) +_SessionT = TypeVar("_SessionT", bound=AbstractBaseSession, default=AbstractBaseSession) -class SessionStore(SessionBase, Generic[_ST]): +class SessionStore(SessionBase, Generic[_SessionT]): def __init__(self, session_key: str | None = None) -> None: ... @classmethod - def get_model_class(cls) -> type[_ST]: ... + def get_model_class(cls) -> type[_SessionT]: ... @cached_property - def model(self) -> type[_ST]: ... - def create_model_instance(self, data: dict[str, Any]) -> _ST: ... - async def acreate_model_instance(self, data: dict[str, Any]) -> _ST: ... + def model(self) -> type[_SessionT]: ... + def create_model_instance(self, data: dict[str, Any]) -> _SessionT: ... + async def acreate_model_instance(self, data: dict[str, Any]) -> _SessionT: ... diff --git a/django-stubs/core/serializers/base.pyi b/django-stubs/core/serializers/base.pyi index ecc3e167b..6da384161 100644 --- a/django-stubs/core/serializers/base.pyi +++ b/django-stubs/core/serializers/base.pyi @@ -70,12 +70,12 @@ class Deserializer: class DeserializedObject: object: Model m2m_data: dict[str, Sequence[Any]] | None - deferred_fields: dict[Field, Any] + deferred_fields: dict[Field[Any, Any, Any], Any] def __init__( self, obj: Model, m2m_data: dict[str, Sequence[Any]] | None = None, - deferred_fields: dict[Field, Any] | None = None, + deferred_fields: dict[Field[Any, Any, Any], Any] | None = None, ) -> None: ... def save(self, save_m2m: bool = True, using: str | None = None, **kwargs: Any) -> None: ... def save_deferred_fields(self, using: str | None = None) -> None: ... diff --git a/django-stubs/core/serializers/pyyaml.pyi b/django-stubs/core/serializers/pyyaml.pyi index c2dbbfff5..6c89a6ae3 100644 --- a/django-stubs/core/serializers/pyyaml.pyi +++ b/django-stubs/core/serializers/pyyaml.pyi @@ -14,7 +14,7 @@ class DjangoSafeDumper(SafeDumper): class Serializer(PythonSerializer): internal_use_only: bool @override - def handle_field(self, obj: Any, field: Field) -> None: ... + def handle_field(self, obj: Any, field: Field[Any, Any, Any]) -> None: ... @override def end_serialization(self) -> None: ... @override diff --git a/django-stubs/db/backends/base/operations.pyi b/django-stubs/db/backends/base/operations.pyi index 7d00d5bee..0c15aab49 100644 --- a/django-stubs/db/backends/base/operations.pyi +++ b/django-stubs/db/backends/base/operations.pyi @@ -33,7 +33,7 @@ class BaseDatabaseOperations: def bulk_batch_size(self, fields: Any, objs: Any) -> int: ... def format_for_duration_arithmetic(self, sql: str) -> str: ... def cache_key_culling_sql(self) -> str: ... - def unification_cast_sql(self, output_field: Field) -> str: ... + def unification_cast_sql(self, output_field: Field[Any, Any, Any]) -> str: ... def date_extract_sql(self, lookup_type: str, sql: Any, params: Any) -> tuple[str, Any]: ... # def date_interval_sql(self, timedelta: None) -> Any: ... def date_trunc_sql(self, lookup_type: str, sql: str, params: Any, tzname: str | None = None) -> tuple[str, Any]: ... @@ -51,7 +51,7 @@ class BaseDatabaseOperations: self, nowait: bool = False, skip_locked: bool = False, of: Any = (), no_key: bool = False ) -> str: ... def limit_offset_sql(self, low_mark: int, high_mark: int | None) -> str: ... - def bulk_insert_sql(self, fields: Iterable[Field], placeholder_rows: Iterable[str]) -> str: ... + def bulk_insert_sql(self, fields: Iterable[Field[Any, Any, Any]], placeholder_rows: Iterable[str]) -> str: ... def last_executed_query(self, cursor: Any, sql: Any, params: Any) -> str: ... def last_insert_id(self, cursor: CursorWrapper, table_name: str, pk_name: str) -> int: ... def lookup_cast(self, lookup_type: str, internal_type: str | None = None) -> str: ... @@ -61,7 +61,7 @@ class BaseDatabaseOperations: def pk_default_value(self) -> str: ... def prepare_sql_script(self, sql: Any) -> list[str]: ... def process_clob(self, value: str) -> str: ... - def returning_columns(self, fields: Iterable[Field]) -> tuple[str, tuple[()]]: ... + def returning_columns(self, fields: Iterable[Field[Any, Any, Any]]) -> tuple[str, tuple[()]]: ... def fetch_returned_rows(self, cursor: CursorWrapper, returning_params: tuple[()]) -> list[tuple[Any, ...]]: ... def compiler(self, compiler_name: str) -> type[SQLCompiler]: ... def quote_name(self, name: str) -> str: ... @@ -115,7 +115,7 @@ class BaseDatabaseOperations: self, fields: Any, on_conflict: Any, update_fields: Any, unique_fields: Any ) -> str | Any: ... def prepare_join_on_clause( - self, lhs_table: str, lhs_field: Field, rhs_table: str, rhs_field: Field + self, lhs_table: str, lhs_field: Field[Any, Any, Any], rhs_table: str, rhs_field: Field[Any, Any, Any] ) -> tuple[Col, Col]: ... def format_debug_sql(self, sql: str) -> str: ... def format_json_path_numeric_index(self, num: int) -> str: ... diff --git a/django-stubs/db/backends/base/schema.pyi b/django-stubs/db/backends/base/schema.pyi index 7def0ffe6..b3ee7a056 100644 --- a/django-stubs/db/backends/base/schema.pyi +++ b/django-stubs/db/backends/base/schema.pyi @@ -75,13 +75,13 @@ class BaseDatabaseSchemaEditor(AbstractContextManager[Any]): def quote_name(self, name: str) -> str: ... def table_sql(self, model: type[Model]) -> tuple[str, list[Any]]: ... def column_sql( - self, model: type[Model], field: Field, include_default: bool = False + self, model: type[Model], field: Field[Any, Any, Any], include_default: bool = False ) -> tuple[None, None] | tuple[str, list[Any]]: ... - def skip_default(self, field: Field) -> bool: ... - def skip_default_on_alter(self, field: Field) -> bool: ... + def skip_default(self, field: Field[Any, Any, Any]) -> bool: ... + def skip_default_on_alter(self, field: Field[Any, Any, Any]) -> bool: ... def prepare_default(self, value: Any) -> Any: ... - def db_default_sql(self, field: Field) -> _AsSqlType: ... - def effective_default(self, field: Field) -> int | str: ... + def db_default_sql(self, field: Field[Any, Any, Any]) -> _AsSqlType: ... + def effective_default(self, field: Field[Any, Any, Any]) -> int | str: ... def quote_value(self, value: Any) -> str: ... def create_model(self, model: type[Model]) -> None: ... def delete_model(self, model: type[Model]) -> None: ... @@ -107,7 +107,9 @@ class BaseDatabaseSchemaEditor(AbstractContextManager[Any]): self, model: type[Model], old_db_table_comment: str, new_db_table_comment: str ) -> None: ... def alter_db_tablespace(self, model: type[Model], old_db_tablespace: str, new_db_tablespace: str) -> None: ... - def add_field(self, model: type[Model], field: Field) -> None: ... - def remove_field(self, model: type[Model], field: Field) -> None: ... - def alter_field(self, model: type[Model], old_field: Field, new_field: Field, strict: bool = False) -> None: ... + def add_field(self, model: type[Model], field: Field[Any, Any, Any]) -> None: ... + def remove_field(self, model: type[Model], field: Field[Any, Any, Any]) -> None: ... + def alter_field( + self, model: type[Model], old_field: Field[Any, Any, Any], new_field: Field[Any, Any, Any], strict: bool = False + ) -> None: ... def remove_procedure(self, procedure_name: Any, param_types: Any = ()) -> None: ... diff --git a/django-stubs/db/backends/base/validation.pyi b/django-stubs/db/backends/base/validation.pyi index 6152f8fe0..7c8a1358e 100644 --- a/django-stubs/db/backends/base/validation.pyi +++ b/django-stubs/db/backends/base/validation.pyi @@ -9,4 +9,4 @@ class BaseDatabaseValidation: def __init__(self, connection: BaseDatabaseWrapper) -> None: ... def __del__(self) -> None: ... def check(self, **kwargs: Any) -> list[CheckMessage]: ... - def check_field(self, field: Field, **kwargs: Any) -> list[CheckMessage]: ... + def check_field(self, field: Field[Any, Any, Any], **kwargs: Any) -> list[CheckMessage]: ... diff --git a/django-stubs/db/backends/mysql/schema.pyi b/django-stubs/db/backends/mysql/schema.pyi index 1cc9664b0..f0989f487 100644 --- a/django-stubs/db/backends/mysql/schema.pyi +++ b/django-stubs/db/backends/mysql/schema.pyi @@ -30,6 +30,6 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): @override def quote_value(self, value: Any) -> str: ... @override - def skip_default(self, field: Field) -> bool: ... + def skip_default(self, field: Field[Any, Any, Any]) -> bool: ... @override - def add_field(self, model: type[Model], field: Field) -> None: ... + def add_field(self, model: type[Model], field: Field[Any, Any, Any]) -> None: ... diff --git a/django-stubs/db/backends/mysql/validation.pyi b/django-stubs/db/backends/mysql/validation.pyi index a4b8c5279..e97a460bc 100644 --- a/django-stubs/db/backends/mysql/validation.pyi +++ b/django-stubs/db/backends/mysql/validation.pyi @@ -10,4 +10,4 @@ class DatabaseValidation(BaseDatabaseValidation): connection: DatabaseWrapper @override def check(self, **kwargs: Any) -> list[CheckMessage]: ... - def check_field_type(self, field: Field, field_type: str) -> list[CheckMessage]: ... + def check_field_type(self, field: Field[Any, Any, Any], field_type: str) -> list[CheckMessage]: ... diff --git a/django-stubs/db/backends/oracle/schema.pyi b/django-stubs/db/backends/oracle/schema.pyi index 18d0f59ea..841f9159d 100644 --- a/django-stubs/db/backends/oracle/schema.pyi +++ b/django-stubs/db/backends/oracle/schema.pyi @@ -21,11 +21,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): @override def quote_value(self, value: Any) -> str: ... @override - def remove_field(self, model: type[Model], field: Field) -> None: ... + def remove_field(self, model: type[Model], field: Field[Any, Any, Any]) -> None: ... @override def delete_model(self, model: type[Model]) -> None: ... @override - def alter_field(self, model: type[Model], old_field: Field, new_field: Field, strict: bool = False) -> None: ... + def alter_field( + self, model: type[Model], old_field: Field[Any, Any, Any], new_field: Field[Any, Any, Any], strict: bool = False + ) -> None: ... def normalize_name(self, name: Any) -> str: ... @override def prepare_default(self, value: Any) -> Any: ... diff --git a/django-stubs/db/backends/oracle/validation.pyi b/django-stubs/db/backends/oracle/validation.pyi index 733ec8af9..080755bb4 100644 --- a/django-stubs/db/backends/oracle/validation.pyi +++ b/django-stubs/db/backends/oracle/validation.pyi @@ -1,3 +1,5 @@ +from typing import Any + from django.core.checks.messages import CheckMessage from django.db.backends.base.validation import BaseDatabaseValidation from django.db.backends.oracle.base import DatabaseWrapper @@ -5,4 +7,4 @@ from django.db.models.fields import Field class DatabaseValidation(BaseDatabaseValidation): connection: DatabaseWrapper - def check_field_type(self, field: Field, field_type: str) -> list[CheckMessage]: ... + def check_field_type(self, field: Field[Any, Any, Any], field_type: str) -> list[CheckMessage]: ... diff --git a/django-stubs/db/migrations/autodetector.pyi b/django-stubs/db/migrations/autodetector.pyi index ea9343ab1..ba675a8d2 100644 --- a/django-stubs/db/migrations/autodetector.pyi +++ b/django-stubs/db/migrations/autodetector.pyi @@ -47,7 +47,7 @@ class MigrationAutodetector: ) -> dict[str, list[Migration]]: ... def deep_deconstruct(self, obj: Any) -> Any: ... def only_relation_agnostic_fields( - self, fields: dict[str, Field] + self, fields: dict[str, Field[Any, Any, Any]] ) -> list[tuple[str, list[Any], dict[str, Callable | int | str]]]: ... def check_dependency(self, operation: Operation, dependency: tuple[str, str, str | None, bool | str]) -> bool: ... def add_operation( diff --git a/django-stubs/db/migrations/operations/fields.pyi b/django-stubs/db/migrations/operations/fields.pyi index 3b1a09d63..e0572a31a 100644 --- a/django-stubs/db/migrations/operations/fields.pyi +++ b/django-stubs/db/migrations/operations/fields.pyi @@ -1,3 +1,5 @@ +from typing import Any + from django.db.models.fields import Field from django.utils.functional import cached_property @@ -6,7 +8,7 @@ from .base import Operation class FieldOperation(Operation): model_name: str name: str - def __init__(self, model_name: str, name: str, field: Field | None = None) -> None: ... + def __init__(self, model_name: str, name: str, field: Field[Any, Any, Any] | None = None) -> None: ... @cached_property def name_lower(self) -> str: ... @cached_property @@ -15,16 +17,20 @@ class FieldOperation(Operation): def is_same_field_operation(self, operation: FieldOperation) -> bool: ... class AddField(FieldOperation): - field: Field + field: Field[Any, Any, Any] preserve_default: bool - def __init__(self, model_name: str, name: str, field: Field, preserve_default: bool = True) -> None: ... + def __init__( + self, model_name: str, name: str, field: Field[Any, Any, Any], preserve_default: bool = True + ) -> None: ... class RemoveField(FieldOperation): ... class AlterField(FieldOperation): - field: Field + field: Field[Any, Any, Any] preserve_default: bool - def __init__(self, model_name: str, name: str, field: Field, preserve_default: bool = True) -> None: ... + def __init__( + self, model_name: str, name: str, field: Field[Any, Any, Any], preserve_default: bool = True + ) -> None: ... class RenameField(FieldOperation): old_name: str diff --git a/django-stubs/db/migrations/operations/models.pyi b/django-stubs/db/migrations/operations/models.pyi index 6c644ca38..16d01b1f7 100644 --- a/django-stubs/db/migrations/operations/models.pyi +++ b/django-stubs/db/migrations/operations/models.pyi @@ -20,14 +20,14 @@ class ModelOperation(Operation): def can_reduce_through(self, operation: Operation, app_label: str) -> bool: ... class CreateModel(ModelOperation): - fields: list[tuple[str, Field]] + fields: list[tuple[str, Field[Any, Any, Any]]] options: dict[str, Any] bases: Sequence[type[Model] | str] | None managers: Sequence[tuple[str, Manager]] | None def __init__( self, name: str, - fields: list[tuple[str, Field]], + fields: list[tuple[str, Field[Any, Any, Any]]], options: dict[str, Any] | None = None, bases: Sequence[type[Any] | str] | None = None, managers: Sequence[tuple[str, Manager]] | None = None, diff --git a/django-stubs/db/migrations/questioner.pyi b/django-stubs/db/migrations/questioner.pyi index 3822855a1..ba5a29ce7 100644 --- a/django-stubs/db/migrations/questioner.pyi +++ b/django-stubs/db/migrations/questioner.pyi @@ -18,7 +18,9 @@ class MigrationQuestioner: def ask_initial(self, app_label: str) -> bool: ... def ask_not_null_addition(self, field_name: str, model_name: str) -> Any: ... def ask_not_null_alteration(self, field_name: Any, model_name: Any) -> Any: ... - def ask_rename(self, model_name: str, old_name: str, new_name: str, field_instance: Field) -> bool: ... + def ask_rename( + self, model_name: str, old_name: str, new_name: str, field_instance: Field[Any, Any, Any] + ) -> bool: ... def ask_rename_model(self, old_model_state: ModelState, new_model_state: ModelState) -> bool: ... def ask_merge(self, app_label: str) -> bool: ... def ask_auto_now_add_addition(self, field_name: str, model_name: str) -> Any: ... diff --git a/django-stubs/db/migrations/state.pyi b/django-stubs/db/migrations/state.pyi index df25c5686..b9952b013 100644 --- a/django-stubs/db/migrations/state.pyi +++ b/django-stubs/db/migrations/state.pyi @@ -16,7 +16,7 @@ class AppConfigStub(AppConfig): class ModelState: name: str app_label: str - fields: dict[str, Field] + fields: dict[str, Field[Any, Any, Any]] options: dict[str, Any] bases: Sequence[type[Model] | str] managers: list[tuple[str, Manager]] @@ -24,7 +24,7 @@ class ModelState: self, app_label: str, name: str, - fields: list[tuple[str, Field]] | dict[str, Field], + fields: list[tuple[str, Field[Any, Any, Any]]] | dict[str, Field[Any, Any, Any]], options: dict[str, Any] | None = None, bases: Sequence[type[Model] | str] | None = None, managers: list[tuple[str, Manager]] | None = None, @@ -33,7 +33,7 @@ class ModelState: def construct_managers(self) -> Iterator[tuple[str, Manager]]: ... @classmethod def from_model(cls, model: type[Model], exclude_rels: bool = False) -> ModelState: ... - def get_field(self, field_name: str) -> Field: ... + def get_field(self, field_name: str) -> Field[Any, Any, Any]: ... @cached_property def name_lower(self) -> str: ... def render(self, apps: Apps) -> Any: ... @@ -76,15 +76,24 @@ class ProjectState: def add_constraint(self, app_label: str, model_name: str, constraint: Any) -> None: ... def remove_constraint(self, app_label: str, model_name: str, constraint_name: str) -> None: ... def alter_constraint(self, app_label: str, model_name: str, constraint_name: str, constraint: Any) -> None: ... - def add_field(self, app_label: str, model_name: str, name: str, field: Field, preserve_default: Any) -> None: ... + def add_field( + self, app_label: str, model_name: str, name: str, field: Field[Any, Any, Any], preserve_default: Any + ) -> None: ... def remove_field(self, app_label: str, model_name: str, name: str) -> None: ... - def alter_field(self, app_label: str, model_name: str, name: str, field: Field, preserve_default: Any) -> None: ... + def alter_field( + self, app_label: str, model_name: str, name: str, field: Field[Any, Any, Any], preserve_default: Any + ) -> None: ... def rename_field(self, app_label: str, model_name: str, old_name: str, new_name: str) -> None: ... def update_model_field_relation( - self, model: type[Model], model_key: tuple[str, str], field_name: str, field: Field, concretes: Any + self, + model: type[Model], + model_key: tuple[str, str], + field_name: str, + field: Field[Any, Any, Any], + concretes: Any, ) -> None: ... def resolve_model_field_relations( - self, model_key: tuple[str, str], field_name: str, field: Field, concretes: Any | None = None + self, model_key: tuple[str, str], field_name: str, field: Field[Any, Any, Any], concretes: Any | None = None ) -> None: ... def resolve_model_relations(self, model_key: tuple[str, str], concretes: Any | None = None) -> None: ... def resolve_fields_and_relations(self) -> None: ... diff --git a/django-stubs/db/migrations/utils.pyi b/django-stubs/db/migrations/utils.pyi index 1737c12ae..3790e0308 100644 --- a/django-stubs/db/migrations/utils.pyi +++ b/django-stubs/db/migrations/utils.pyi @@ -23,14 +23,16 @@ class FieldReference(NamedTuple): def field_references( model_tuple: tuple[str, str], - field: Field, + field: Field[Any, Any, Any], reference_model_tuple: tuple[str, str], reference_field_name: str | None = None, - reference_field: Field | None = None, + reference_field: Field[Any, Any, Any] | None = None, ) -> Literal[False] | FieldReference: ... def get_references( state: ProjectState, model_tuple: tuple[str, str], - field_tuple: tuple[()] | tuple[str, Field] = (), -) -> Iterator[tuple[ModelState, str, Field, FieldReference]]: ... -def field_is_referenced(state: ProjectState, model_tuple: tuple[str, str], field_tuple: tuple[str, Field]) -> bool: ... + field_tuple: tuple[()] | tuple[str, Field[Any, Any, Any]] = (), +) -> Iterator[tuple[ModelState, str, Field[Any, Any, Any], FieldReference]]: ... +def field_is_referenced( + state: ProjectState, model_tuple: tuple[str, str], field_tuple: tuple[str, Field[Any, Any, Any]] +) -> bool: ... diff --git a/django-stubs/db/models/base.pyi b/django-stubs/db/models/base.pyi index 87b2e2d0d..cebab7fcc 100644 --- a/django-stubs/db/models/base.pyi +++ b/django-stubs/db/models/base.pyi @@ -104,14 +104,14 @@ class Model(AltersData, metaclass=ModelBase): base_qs: QuerySet[Self], using: str | None, pk_val: Any, - values: Collection[tuple[Field, type[Model] | None, Any]], + values: Collection[tuple[Field[Any, Any, Any], type[Model] | None, Any]], update_fields: Iterable[str] | None, forced_update: bool, returning_fields: Sequence[Field[Any, Any]], ) -> list[Sequence[Any]]: ... def delete(self, using: Any | None = None, keep_parents: bool = False) -> tuple[int, dict[str, int]]: ... async def adelete(self, using: Any | None = None, keep_parents: bool = False) -> tuple[int, dict[str, int]]: ... - def prepare_database_save(self, field: Field) -> Any: ... + def prepare_database_save(self, field: Field[Any, Any, Any]) -> Any: ... def clean(self) -> None: ... def validate_unique(self, exclude: Collection[str] | None = None) -> None: ... def date_error_message(self, lookup_type: str, field_name: str, unique_for: str) -> ValidationError: ... diff --git a/django-stubs/db/models/deletion.pyi b/django-stubs/db/models/deletion.pyi index 56143fe9c..ccd9ae856 100644 --- a/django-stubs/db/models/deletion.pyi +++ b/django-stubs/db/models/deletion.pyi @@ -46,7 +46,7 @@ def RESTRICT( using: str, ) -> None: ... def SET(value: Any) -> Callable[..., Any]: ... -def get_candidate_relations_to_delete(opts: Options) -> Iterable[Field]: ... +def get_candidate_relations_to_delete(opts: Options) -> Iterable[Field[Any, Any, Any]]: ... class ProtectedError(IntegrityError): protected_objects: set[Model] @@ -60,8 +60,8 @@ class Collector: using: str origin: Model | QuerySet[Model] | None data: dict[type[Model], set[Model] | list[Model]] - field_updates: defaultdict[tuple[Field, Any], list[Model]] - restricted_objects: defaultdict[Model, defaultdict[Field, set[Model]]] + field_updates: defaultdict[tuple[Field[Any, Any, Any], Any], list[Model]] + restricted_objects: defaultdict[Model, defaultdict[Field[Any, Any, Any], set[Model]]] fast_deletes: list[QuerySet[Model]] dependencies: defaultdict[Model, set[Model]] def __init__(self, using: str, origin: Model | QuerySet[Model] | None = None) -> None: ... @@ -73,13 +73,15 @@ class Collector: reverse_dependency: bool = False, ) -> list[Model]: ... def add_dependency(self, model: type[Model], dependency: type[Model], reverse_dependency: bool = False) -> None: ... - def add_field_update(self, field: Field, value: Any, objs: _IndexableCollection[Model]) -> None: ... - def add_restricted_objects(self, field: Field, objs: _IndexableCollection[Model]) -> None: ... + def add_field_update(self, field: Field[Any, Any, Any], value: Any, objs: _IndexableCollection[Model]) -> None: ... + def add_restricted_objects(self, field: Field[Any, Any, Any], objs: _IndexableCollection[Model]) -> None: ... def clear_restricted_objects_from_set(self, model: type[Model], objs: set[Model]) -> None: ... def clear_restricted_objects_from_queryset(self, model: type[Model], qs: QuerySet[Model]) -> None: ... - def can_fast_delete(self, objs: Model | Iterable[Model], from_field: Field | None = None) -> bool: ... + def can_fast_delete( + self, objs: Model | Iterable[Model], from_field: Field[Any, Any, Any] | None = None + ) -> bool: ... def get_del_batches( - self, objs: _IndexableCollection[Model], fields: Iterable[Field] + self, objs: _IndexableCollection[Model], fields: Iterable[Field[Any, Any, Any]] ) -> Sequence[Sequence[Model]]: ... def collect( self, @@ -93,7 +95,10 @@ class Collector: fail_on_restricted: bool = True, ) -> None: ... def related_objects( - self, related_model: type[Model], related_fields: Iterable[Field], objs: _IndexableCollection[Model] + self, + related_model: type[Model], + related_fields: Iterable[Field[Any, Any, Any]], + objs: _IndexableCollection[Model], ) -> QuerySet[Model]: ... def instances_with_model(self) -> Iterator[tuple[type[Model], Model]]: ... def sort(self) -> None: ... diff --git a/django-stubs/db/models/expressions.pyi b/django-stubs/db/models/expressions.pyi index 2f56d4e9c..b9e8fb29b 100644 --- a/django-stubs/db/models/expressions.pyi +++ b/django-stubs/db/models/expressions.pyi @@ -16,7 +16,7 @@ from django.utils.deconstruct import _Deconstructible from django.utils.functional import cached_property from typing_extensions import Never, Self, TypeVar, override -_OutputField = TypeVar("_OutputField", bound=Field, default=Field) +_OutputField = TypeVar("_OutputField", bound=Field[Any, Any, Any], default=Field[Any, Any, Any]) _Numeric: TypeAlias = float | Decimal _AddOperand: TypeAlias = datetime.datetime | datetime.timedelta | _Numeric | Combinable | str | None _SubOperand: TypeAlias = datetime.datetime | datetime.timedelta | _Numeric | Combinable | None @@ -74,7 +74,7 @@ class BaseExpression: constraint_validation_compatible: bool set_returning: bool allows_composite_expressions: bool - def __init__(self, output_field: Field | None = None) -> None: ... + def __init__(self, output_field: Field[Any, Any, Any] | None = None) -> None: ... def get_db_converters(self, connection: BaseDatabaseWrapper) -> list[Callable]: ... def get_source_expressions(self) -> list[Any]: ... def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None: ... @@ -98,9 +98,9 @@ class BaseExpression: @property def conditional(self) -> bool: ... @property - def field(self) -> Field: ... + def field(self) -> Field[Any, Any, Any]: ... @cached_property - def output_field(self) -> Field: ... + def output_field(self) -> Field[Any, Any, Any]: ... @cached_property def convert_value(self) -> Callable: ... def get_lookup(self, lookup: str) -> type[Lookup] | None: ... @@ -111,7 +111,7 @@ class BaseExpression: def copy(self) -> Self: ... def prefix_references(self, prefix: str) -> Self: ... def get_group_by_cols(self) -> list[BaseExpression]: ... - def get_source_fields(self) -> list[Field | None]: ... + def get_source_fields(self) -> list[Field[Any, Any, Any] | None]: ... def asc(self, **kwargs: Any) -> OrderBy: ... def desc(self, **kwargs: Any) -> OrderBy: ... def reverse_ordering(self) -> BaseExpression: ... @@ -124,10 +124,10 @@ class Expression(_Deconstructible, BaseExpression, Combinable): def identity(self) -> tuple[Any, ...]: ... def register_combinable_fields( - lhs: type[Field], + lhs: type[Field[Any, Any, Any]], connector: str, - rhs: type[Field], - result: type[Field], + rhs: type[Field[Any, Any, Any]], + result: type[Field[Any, Any, Any]], ) -> None: ... class CombinedExpression(SQLiteNumericMixin, Expression): @@ -137,7 +137,9 @@ class CombinedExpression(SQLiteNumericMixin, Expression): connector: str lhs: Combinable rhs: Combinable - def __init__(self, lhs: Combinable, connector: str, rhs: Combinable, output_field: Field | None = None) -> None: ... + def __init__( + self, lhs: Combinable, connector: str, rhs: Combinable, output_field: Field[Any, Any, Any] | None = None + ) -> None: ... class DurationExpression(CombinedExpression): def compile(self, side: Combinable, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ... @@ -214,7 +216,7 @@ class Func(SQLiteNumericMixin, Expression, Generic[_OutputField]): class Value(Expression): value: Any for_save: bool - def __init__(self, value: Any, output_field: Field | None = None) -> None: ... + def __init__(self, value: Any, output_field: Field[Any, Any, Any] | None = None) -> None: ... @property @override def empty_result_set_value(self) -> Any: ... @@ -222,28 +224,34 @@ class Value(Expression): class RawSQL(Expression): params: list[Any] sql: str - def __init__(self, sql: str, params: Sequence[Any], output_field: Field | None = None) -> None: ... + def __init__(self, sql: str, params: Sequence[Any], output_field: Field[Any, Any, Any] | None = None) -> None: ... class Star(Expression): ... class DatabaseDefault(Expression): - def __init__(self, expression: Expression, output_field: Field | None = None) -> None: ... + def __init__(self, expression: Expression, output_field: Field[Any, Any, Any] | None = None) -> None: ... class Col(Expression): - target: Field + target: Field[Any, Any, Any] alias: str contains_column_references: Literal[True] possibly_multivalued: Literal[False] - def __init__(self, alias: str, target: Field, output_field: Field | None = None) -> None: ... + def __init__( + self, alias: str, target: Field[Any, Any, Any], output_field: Field[Any, Any, Any] | None = None + ) -> None: ... @override def relabeled_clone(self, relabels: Mapping[str, str]) -> Self: ... class ColPairs(Expression): alias: str - targets: Sequence[Field] - sources: Sequence[Field] + targets: Sequence[Field[Any, Any, Any]] + sources: Sequence[Field[Any, Any, Any]] def __init__( - self, alias: str, targets: Sequence[Field], sources: Sequence[Field], output_field: Field | None + self, + alias: str, + targets: Sequence[Field[Any, Any, Any]], + sources: Sequence[Field[Any, Any, Any]], + output_field: Field[Any, Any, Any] | None, ) -> None: ... def __len__(self) -> int: ... def __iter__(self) -> Iterator[Col]: ... @@ -258,7 +266,7 @@ class Ref(Expression): class ExpressionList(Func): def __init__( - self, *expressions: BaseExpression | Combinable, output_field: Field | None = None, **extra: Any + self, *expressions: BaseExpression | Combinable, output_field: Field[Any, Any, Any] | None = None, **extra: Any ) -> None: ... class OrderByList(ExpressionList): @@ -271,7 +279,7 @@ class ExpressionWrapper(Expression, Generic[_E]): @property @override def allowed_default(self) -> bool: ... # type: ignore[override] - def __init__(self, expression: _E, output_field: Field) -> None: ... + def __init__(self, expression: _E, output_field: Field[Any, Any, Any]) -> None: ... expression: _E class NegatedExpression(ExpressionWrapper[_E]): @@ -302,7 +310,7 @@ class Case(Expression): default: Any extra: Any def __init__( - self, *cases: Any, default: Any | None = None, output_field: Field | None = None, **extra: Any + self, *cases: Any, default: Any | None = None, output_field: Field[Any, Any, Any] | None = None, **extra: Any ) -> None: ... @override def as_sql( @@ -370,7 +378,7 @@ class Window(SQLiteNumericMixin, Expression): partition_by: _ExprListCompatible | None = None, order_by: _ExprListCompatible | None = None, frame: WindowFrame | None = None, - output_field: Field | None = None, + output_field: Field[Any, Any, Any] | None = None, ) -> None: ... @override def as_sql( diff --git a/django-stubs/db/models/fields/__init__.pyi b/django-stubs/db/models/fields/__init__.pyi index 5fc8ca280..d88d7e879 100644 --- a/django-stubs/db/models/fields/__init__.pyi +++ b/django-stubs/db/models/fields/__init__.pyi @@ -3,7 +3,7 @@ import uuid from collections.abc import Callable, Iterable, Mapping, Sequence from datetime import date, time, timedelta from datetime import datetime as real_datetime -from typing import Any, ClassVar, Generic, Protocol, TypeAlias, overload, type_check_only +from typing import Any, ClassVar, Generic, Literal, Protocol, TypeAlias, overload, type_check_only from django import forms from django.core import validators # due to weird mypy.stubtest error @@ -29,7 +29,7 @@ _ChoicesList: TypeAlias = Sequence[_Choice] | Sequence[_ChoiceNamedGroup] _LimitChoicesTo: TypeAlias = Q | dict[str, Any] _LimitChoicesToCallable: TypeAlias = Callable[[], _LimitChoicesTo] -_F = TypeVar("_F", bound=Field, covariant=True) +_F = TypeVar("_F", bound=Field[Any, Any, Any], covariant=True) @type_check_only class _FieldDescriptor(Protocol[_F]): @@ -46,11 +46,13 @@ _ErrorMessagesMapping: TypeAlias = Mapping[str, _StrOrPromise] _ErrorMessagesDict: TypeAlias = dict[str, _StrOrPromise] # __set__ value type -_ST = TypeVar("_ST", contravariant=True) +_ST = TypeVar("_ST", contravariant=True, default=Any) # __get__ return type -_GT = TypeVar("_GT", covariant=True) +_GT = TypeVar("_GT", covariant=True, default=Any) +# null flag type +_NT = TypeVar("_NT", Literal[True], Literal[False], default=Literal[False]) -class Field(RegisterLookupMixin, Generic[_ST, _GT]): +class Field(RegisterLookupMixin, Generic[_ST, _GT, _NT]): """ Typing model fields. @@ -104,13 +106,10 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]): Notice, that this is not magic. This is how descriptors work with ``mypy``. - We also need ``_pyi_private_set_type`` attributes - and friends to help inside our plugin. + We also need ``_pyi_lookup_exact_type`` to help inside our plugin. It is required to enhance parts like ``filter`` queries. """ - _pyi_private_set_type: Any - _pyi_private_get_type: Any _pyi_lookup_exact_type: Any help_text: _StrOrPromise @@ -158,7 +157,7 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]): max_length: int | None = None, unique: bool = False, blank: bool = False, - null: bool = False, + null: _NT = ..., db_index: bool = False, rel: ForeignObjectRel | None = None, default: Any = ..., @@ -177,15 +176,30 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]): db_comment: str | None = None, db_default: type[NOT_PROVIDED] | Expression | _ST = ..., ) -> None: ... - def __set__(self, instance: Any, value: _ST) -> None: ... + @overload + @type_check_only + def __set__(self: Field[Any, Any, Literal[False]], instance: Any, value: _ST | Combinable) -> None: ... + @overload + @type_check_only + def __set__(self: Field[Any, Any, Literal[True]], instance: Any, value: _ST | Combinable | None) -> None: ... + @overload + @type_check_only + def __set__(self, instance: Any, value: _ST | Combinable) -> None: ... # class access @overload + @type_check_only def __get__(self, instance: None, owner: Any) -> _FieldDescriptor[Self]: ... - # Model instance access + # non-null Model instance access + @overload + @type_check_only + def __get__(self: Field[Any, Any, Literal[False]], instance: Model, owner: Any) -> _GT: ... + # nullable Model instance access @overload - def __get__(self, instance: Model, owner: Any) -> _GT: ... + @type_check_only + def __get__(self: Field[Any, Any, Literal[True]], instance: Model, owner: Any) -> _GT | None: ... # non-Model instances @overload + @type_check_only def __get__(self, instance: Any, owner: Any) -> Self: ... def check(self, **kwargs: Any) -> list[CheckMessage]: ... def get_col(self, alias: str, output_field: Field | None = None) -> Col: ... @@ -257,9 +271,10 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]): def value_from_object(self, obj: Model) -> _GT: ... def slice_expression(self, expression: Expression, start: int, length: int | None) -> Func: ... -class IntegerField(Field[_ST, _GT]): - _pyi_private_set_type: float | int | str | Combinable - _pyi_private_get_type: int +_ST_Int = TypeVar("_ST_Int", contravariant=True, default=float | int | str) +_GT_Int = TypeVar("_GT_Int", covariant=True, default=int) + +class IntegerField(Field[_ST_Int, _GT_Int, _NT]): _pyi_lookup_exact_type: str | int @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] @@ -267,38 +282,40 @@ class IntegerField(Field[_ST, _GT]): class PositiveIntegerRelDbTypeMixin: def rel_db_type(self, connection: BaseDatabaseWrapper) -> str: ... -class SmallIntegerField(IntegerField[_ST, _GT]): ... +class SmallIntegerField(IntegerField[_ST_Int, _GT_Int, _NT]): ... -class BigIntegerField(IntegerField[_ST, _GT]): +class BigIntegerField(IntegerField[_ST_Int, _GT_Int, _NT]): MAX_BIGINT: ClassVar[int] @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class PositiveIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST, _GT]): +class PositiveIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST_Int, _GT_Int, _NT]): integer_field_class: type[IntegerField] @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class PositiveSmallIntegerField(PositiveIntegerRelDbTypeMixin, SmallIntegerField[_ST, _GT]): +class PositiveSmallIntegerField(PositiveIntegerRelDbTypeMixin, SmallIntegerField[_ST_Int, _GT_Int, _NT]): integer_field_class: type[SmallIntegerField] @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class PositiveBigIntegerField(PositiveIntegerRelDbTypeMixin, BigIntegerField[_ST, _GT]): +class PositiveBigIntegerField(PositiveIntegerRelDbTypeMixin, BigIntegerField[_ST_Int, _GT_Int, _NT]): integer_field_class: type[BigIntegerField] @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class FloatField(Field[_ST, _GT]): - _pyi_private_set_type: float | int | str | Combinable - _pyi_private_get_type: float +_ST_Float = TypeVar("_ST_Float", contravariant=True, default=float | int | str) +_GT_Float = TypeVar("_GT_Float", covariant=True, default=float) + +class FloatField(Field[_ST_Float, _GT_Float, _NT]): _pyi_lookup_exact_type: float @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class DecimalField(Field[_ST, _GT]): - _pyi_private_set_type: str | float | decimal.Decimal | Combinable - _pyi_private_get_type: decimal.Decimal +_ST_Decimal = TypeVar("_ST_Decimal", contravariant=True, default=str | float | decimal.Decimal) +_GT_Decimal = TypeVar("_GT_Decimal", covariant=True, default=decimal.Decimal) + +class DecimalField(Field[_ST_Decimal, _GT_Decimal, _NT]): _pyi_lookup_exact_type: str | int | decimal.Decimal # attributes max_digits: int @@ -313,10 +330,10 @@ class DecimalField(Field[_ST, _GT]): primary_key: bool = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., default: Any = ..., - db_default: type[NOT_PROVIDED] | Expression | _ST = ..., + db_default: type[NOT_PROVIDED] | Expression | _ST_Decimal = ..., editable: bool = ..., auto_created: bool = ..., serialize: bool = ..., @@ -333,9 +350,10 @@ class DecimalField(Field[_ST, _GT]): @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class CharField(Field[_ST, _GT]): - _pyi_private_set_type: str | int | Combinable - _pyi_private_get_type: str +_ST_Char = TypeVar("_ST_Char", contravariant=True, default=str | int) +_GT_Char = TypeVar("_GT_Char", covariant=True, default=str) + +class CharField(Field[_ST_Char, _GT_Char, _NT]): # objects are converted to string before comparison _pyi_lookup_exact_type: Any def __init__( @@ -346,10 +364,10 @@ class CharField(Field[_ST, _GT]): max_length: int | None = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., default: Any = ..., - db_default: type[NOT_PROVIDED] | Expression | _ST = ..., + db_default: type[NOT_PROVIDED] | Expression | _ST_Char = ..., editable: bool = ..., auto_created: bool = ..., serialize: bool = ..., @@ -369,9 +387,9 @@ class CharField(Field[_ST, _GT]): @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class CommaSeparatedIntegerField(CharField[_ST, _GT]): ... +class CommaSeparatedIntegerField(CharField[_ST_Char, _GT_Char, _NT]): ... -class SlugField(CharField[_ST, _GT]): +class SlugField(CharField[_ST_Char, _GT_Char, _NT]): def __init__( self, verbose_name: _StrOrPromise | None = ..., @@ -379,9 +397,9 @@ class SlugField(CharField[_ST, _GT]): primary_key: bool = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., default: Any = ..., - db_default: type[NOT_PROVIDED] | Expression | _ST = ..., + db_default: type[NOT_PROVIDED] | Expression | _ST_Char = ..., editable: bool = ..., auto_created: bool = ..., serialize: bool = ..., @@ -403,12 +421,13 @@ class SlugField(CharField[_ST, _GT]): @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class EmailField(CharField[_ST, _GT]): - _pyi_private_set_type: str | Combinable +_ST_Email = TypeVar("_ST_Email", contravariant=True, default=str) + +class EmailField(CharField[_ST_Email, _GT_Char, _NT]): @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class URLField(CharField[_ST, _GT]): +class URLField(CharField[_ST_Char, _GT_Char, _NT]): def __init__( self, verbose_name: _StrOrPromise | None = None, @@ -418,11 +437,11 @@ class URLField(CharField[_ST, _GT]): max_length: int | None = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., rel: ForeignObjectRel | None = ..., default: Any = ..., - db_default: type[NOT_PROVIDED] | Expression | _ST = ..., + db_default: type[NOT_PROVIDED] | Expression | _ST_Char = ..., editable: bool = ..., serialize: bool = ..., unique_for_date: str | None = ..., @@ -440,9 +459,10 @@ class URLField(CharField[_ST, _GT]): @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class TextField(Field[_ST, _GT]): - _pyi_private_set_type: str | Combinable - _pyi_private_get_type: str +_ST_Text = TypeVar("_ST_Text", contravariant=True, default=str) +_GT_Text = TypeVar("_GT_Text", covariant=True, default=str) + +class TextField(Field[_ST_Text, _GT_Text, _NT]): # objects are converted to string before comparison _pyi_lookup_exact_type: Any def __init__( @@ -453,10 +473,10 @@ class TextField(Field[_ST, _GT]): max_length: int | None = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., default: Any = ..., - db_default: type[NOT_PROVIDED] | Expression | _ST = ..., + db_default: type[NOT_PROVIDED] | Expression | _ST_Text = ..., editable: bool = ..., auto_created: bool = ..., serialize: bool = ..., @@ -476,26 +496,28 @@ class TextField(Field[_ST, _GT]): @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class BooleanField(Field[_ST, _GT]): - _pyi_private_set_type: bool | Combinable - _pyi_private_get_type: bool +_ST_Bool = TypeVar("_ST_Bool", contravariant=True, default=bool) +_GT_Bool = TypeVar("_GT_Bool", covariant=True, default=bool) + +class BooleanField(Field[_ST_Bool, _GT_Bool, _NT]): _pyi_lookup_exact_type: bool @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class NullBooleanField(BooleanField[_ST, _GT]): - _pyi_private_set_type: bool | Combinable | None # type: ignore[assignment] - _pyi_private_get_type: bool | None # type: ignore[assignment] +_ST_NBool = TypeVar("_ST_NBool", contravariant=True, default=bool | None) +_GT_NBool = TypeVar("_GT_NBool", covariant=True, default=bool | None) + +class NullBooleanField(BooleanField[_ST_NBool, _GT_NBool, _NT]): _pyi_lookup_exact_type: bool | None # type: ignore[assignment] -class IPAddressField(Field[_ST, _GT]): - _pyi_private_set_type: str | Combinable - _pyi_private_get_type: str +_ST_IP = TypeVar("_ST_IP", contravariant=True, default=str) +_GT_IP = TypeVar("_GT_IP", covariant=True, default=str) -class GenericIPAddressField(Field[_ST, _GT]): - _pyi_private_set_type: str | int | Callable[..., Any] | Combinable - _pyi_private_get_type: str +class IPAddressField(Field[_ST_IP, _GT_IP, _NT]): ... +_ST_GenIP = TypeVar("_ST_GenIP", contravariant=True, default=str | int | Callable[..., Any]) + +class GenericIPAddressField(Field[_ST_GenIP, _GT_IP, _NT]): default_error_messages: ClassVar[_ErrorMessagesDict] unpack_ipv4: bool protocol: str @@ -508,10 +530,10 @@ class GenericIPAddressField(Field[_ST, _GT]): primary_key: bool = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., default: Any = ..., - db_default: type[NOT_PROVIDED] | Expression | _ST = ..., + db_default: type[NOT_PROVIDED] | Expression | _ST_GenIP = ..., editable: bool = ..., auto_created: bool = ..., serialize: bool = ..., @@ -529,9 +551,10 @@ class GenericIPAddressField(Field[_ST, _GT]): class DateTimeCheckMixin: def check(self, **kwargs: Any) -> list[CheckMessage]: ... -class DateField(DateTimeCheckMixin, Field[_ST, _GT]): - _pyi_private_set_type: str | date | Combinable - _pyi_private_get_type: date +_ST_Date = TypeVar("_ST_Date", contravariant=True, default=str | date) +_GT_Date = TypeVar("_GT_Date", covariant=True, default=date) + +class DateField(DateTimeCheckMixin, Field[_ST_Date, _GT_Date, _NT]): _pyi_lookup_exact_type: str | date auto_now: bool auto_now_add: bool @@ -546,10 +569,10 @@ class DateField(DateTimeCheckMixin, Field[_ST, _GT]): max_length: int | None = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., default: Any = ..., - db_default: type[NOT_PROVIDED] | Expression | _ST = ..., + db_default: type[NOT_PROVIDED] | Expression | _ST_Date = ..., editable: bool = ..., auto_created: bool = ..., serialize: bool = ..., @@ -566,9 +589,10 @@ class DateField(DateTimeCheckMixin, Field[_ST, _GT]): @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class TimeField(DateTimeCheckMixin, Field[_ST, _GT]): - _pyi_private_set_type: str | time | real_datetime | Combinable - _pyi_private_get_type: time +_ST_Time = TypeVar("_ST_Time", contravariant=True, default=str | time | real_datetime) +_GT_Time = TypeVar("_GT_Time", covariant=True, default=time) + +class TimeField(DateTimeCheckMixin, Field[_ST_Time, _GT_Time, _NT]): auto_now: bool auto_now_add: bool def __init__( @@ -581,10 +605,10 @@ class TimeField(DateTimeCheckMixin, Field[_ST, _GT]): primary_key: bool = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., default: Any = ..., - db_default: type[NOT_PROVIDED] | Expression | _ST = ..., + db_default: type[NOT_PROVIDED] | Expression | _ST_Time = ..., editable: bool = ..., auto_created: bool = ..., serialize: bool = ..., @@ -599,16 +623,18 @@ class TimeField(DateTimeCheckMixin, Field[_ST, _GT]): @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class DateTimeField(DateField[_ST, _GT]): - _pyi_private_set_type: str | real_datetime | date | Combinable - _pyi_private_get_type: real_datetime +_ST_DateTime = TypeVar("_ST_DateTime", contravariant=True, default=str | real_datetime | date) +_GT_DateTime = TypeVar("_GT_DateTime", covariant=True, default=real_datetime) + +class DateTimeField(DateField[_ST_DateTime, _GT_DateTime, _NT]): _pyi_lookup_exact_type: str | real_datetime @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class UUIDField(Field[_ST, _GT]): - _pyi_private_set_type: str | uuid.UUID - _pyi_private_get_type: uuid.UUID +_ST_UUID = TypeVar("_ST_UUID", contravariant=True, default=str | uuid.UUID) +_GT_UUID = TypeVar("_GT_UUID", covariant=True, default=uuid.UUID) + +class UUIDField(Field[_ST_UUID, _GT_UUID, _NT]): _pyi_lookup_exact_type: uuid.UUID | str def __init__( self, @@ -619,11 +645,11 @@ class UUIDField(Field[_ST, _GT]): max_length: int | None = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., rel: ForeignObjectRel | None = ..., default: Any = ..., - db_default: type[NOT_PROVIDED] | Expression | _ST = ..., + db_default: type[NOT_PROVIDED] | Expression | _ST_UUID = ..., editable: bool = ..., serialize: bool = ..., unique_for_date: str | None = ..., @@ -641,7 +667,7 @@ class UUIDField(Field[_ST, _GT]): @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class FilePathField(Field[_ST, _GT]): +class FilePathField(Field[_ST, _GT, _NT]): path: Any match: str | None recursive: bool @@ -661,7 +687,7 @@ class FilePathField(Field[_ST, _GT]): max_length: int = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., default: Any = ..., db_default: type[NOT_PROVIDED] | Expression | _ST = ..., @@ -679,12 +705,16 @@ class FilePathField(Field[_ST, _GT]): @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] -class BinaryField(Field[_ST, _GT]): - _pyi_private_get_type: bytes | memoryview +_ST_Binary = TypeVar("_ST_Binary", contravariant=True, default=bytes | bytearray | memoryview) +_GT_Binary = TypeVar("_GT_Binary", covariant=True, default=bytes | memoryview) + +class BinaryField(Field[_ST_Binary, _GT_Binary, _NT]): def get_placeholder(self, value: Any, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> str: ... -class DurationField(Field[_ST, _GT]): - _pyi_private_get_type: timedelta +_ST_Duration = TypeVar("_ST_Duration", contravariant=True, default=str | timedelta) +_GT_Duration = TypeVar("_GT_Duration", covariant=True, default=timedelta) + +class DurationField(Field[_ST_Duration, _GT_Duration, _NT]): @override def formfield(self, **kwargs: Any) -> forms.Field | None: ... # type: ignore[override] @@ -700,13 +730,14 @@ class AutoFieldMixin: class AutoFieldMeta(type): ... -class AutoField(AutoFieldMixin, IntegerField[_ST, _GT], metaclass=AutoFieldMeta): # type: ignore[misc] - _pyi_private_set_type: Combinable | int | str - _pyi_private_get_type: int +_ST_Auto = TypeVar("_ST_Auto", contravariant=True, default=int | str) +_GT_Auto = TypeVar("_GT_Auto", covariant=True, default=int) + +class AutoField(AutoFieldMixin, IntegerField[_ST_Auto, _GT_Auto, _NT], metaclass=AutoFieldMeta): # type: ignore[misc] _pyi_lookup_exact_type: str | int -class BigAutoField(AutoFieldMixin, BigIntegerField[_ST, _GT]): ... # type: ignore[misc] -class SmallAutoField(AutoFieldMixin, SmallIntegerField[_ST, _GT]): ... # type: ignore[misc] +class BigAutoField(AutoFieldMixin, BigIntegerField[_ST_Auto, _GT_Auto, _NT]): ... # type: ignore[misc] +class SmallAutoField(AutoFieldMixin, SmallIntegerField[_ST_Auto, _GT_Auto, _NT]): ... # type: ignore[misc] __all__ = [ "BLANK_CHOICE_DASH", diff --git a/django-stubs/db/models/fields/files.pyi b/django-stubs/db/models/fields/files.pyi index 4bcc92686..d8f06978c 100644 --- a/django-stubs/db/models/fields/files.pyi +++ b/django-stubs/db/models/fields/files.pyi @@ -8,7 +8,7 @@ from django.core.files.images import ImageFile from django.core.files.storage import Storage from django.db.models.base import Model from django.db.models.expressions import Expression -from django.db.models.fields import NOT_PROVIDED, Field, _ErrorMessagesMapping +from django.db.models.fields import _NT, _ST, NOT_PROVIDED, Field, _ErrorMessagesMapping, _FieldDescriptor from django.db.models.query_utils import DeferredAttribute from django.db.models.utils import AltersData from django.utils._os import _PathCompatible @@ -60,7 +60,10 @@ _M = TypeVar("_M", bound=Model, contravariant=True) class _UploadToCallable(Protocol[_M]): def __call__(self, instance: _M, filename: str, /) -> _PathCompatible: ... -class FileField(Field[Any, Any]): +# __get__ return type +_GT_File = TypeVar("_GT_File", covariant=True, default=FieldFile) + +class FileField(Field[_ST, _GT_File, _NT]): attr_class: type[FieldFile] descriptor_class: type[FileDescriptor] storage: Storage @@ -75,10 +78,10 @@ class FileField(Field[Any, Any]): max_length: int | None = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., default: Any = ..., - db_default: type[NOT_PROVIDED] | Expression | str = ..., + db_default: type[NOT_PROVIDED] | Expression | _ST = ..., editable: bool = ..., auto_created: bool = ..., serialize: bool = ..., @@ -93,16 +96,21 @@ class FileField(Field[Any, Any]): validators: Iterable[validators._ValidatorCallable] = ..., error_messages: _ErrorMessagesMapping | None = ..., ) -> None: ... + # At runtime, FileDescriptor.__get__ ALWAYS returns a FieldFile even when the underlying database value is NULL. + # It wraps None in FieldFile(instance, field, name=None). # class access @overload + @type_check_only @override - def __get__(self, instance: None, owner: Any) -> FileDescriptor: ... - # Model instance access + def __get__(self, instance: None, owner: Any) -> _FieldDescriptor[Self]: ... + # Model instance access — null=True does NOT add `| None` @overload + @type_check_only @override - def __get__(self, instance: Model, owner: Any) -> FieldFile: ... + def __get__(self, instance: Model, owner: Any) -> _GT_File: ... # non-Model instances @overload + @type_check_only @override def __get__(self, instance: Any, owner: Any) -> Self: ... @override @@ -121,7 +129,9 @@ class ImageFieldFile(ImageFile, FieldFile): @override def delete(self, save: bool = True) -> None: ... -class ImageField(FileField): +_GT_ImageFile = TypeVar("_GT_ImageFile", covariant=True, default=ImageFieldFile) + +class ImageField(FileField[_ST, _GT_ImageFile, _NT]): attr_class: type[ImageFieldFile] descriptor_class: type[ImageFileDescriptor] def __init__( @@ -130,20 +140,28 @@ class ImageField(FileField): name: str | None = None, width_field: str | None = None, height_field: str | None = None, - **kwargs: Any, + *, + max_length: int | None = ..., + unique: bool = ..., + blank: bool = ..., + null: _NT = ..., + db_index: bool = ..., + default: Any = ..., + db_default: type[NOT_PROVIDED] | Expression | _ST = ..., + editable: bool = ..., + auto_created: bool = ..., + serialize: bool = ..., + unique_for_date: str | None = ..., + unique_for_month: str | None = ..., + unique_for_year: str | None = ..., + choices: _Choices | None = ..., + help_text: _StrOrPromise = ..., + db_column: str | None = ..., + db_comment: str | None = ..., + db_tablespace: str | None = ..., + validators: Iterable[validators._ValidatorCallable] = ..., + error_messages: _ErrorMessagesMapping | None = ..., ) -> None: ... - # class access - @overload - @override - def __get__(self, instance: None, owner: Any) -> ImageFileDescriptor: ... - # Model instance access - @overload - @override - def __get__(self, instance: Model, owner: Any) -> ImageFieldFile: ... - # non-Model instances - @overload - @override - def __get__(self, instance: Any, owner: Any) -> Self: ... def update_dimension_fields(self, instance: Model, force: bool = False, *args: Any, **kwargs: Any) -> None: ... @override def formfield(self, **kwargs: Any) -> Any: ... # type: ignore[override] diff --git a/django-stubs/db/models/fields/json.pyi b/django-stubs/db/models/fields/json.pyi index ab4ff1b3c..6b1253f93 100644 --- a/django-stubs/db/models/fields/json.pyi +++ b/django-stubs/db/models/fields/json.pyi @@ -1,25 +1,23 @@ import json -from collections.abc import Callable +from collections.abc import Callable, Iterable from typing import Any, ClassVar +from django.core import validators from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import Model, lookups from django.db.models.expressions import Expression -from django.db.models.fields import TextField +from django.db.models.fields import _NT, Field, TextField, _ErrorMessagesMapping +from django.db.models.fields.mixins import CheckFieldDefaultMixin from django.db.models.lookups import FieldGetDbPrepValueMixin, PostgresOperatorLookup, Transform from django.db.models.sql.compiler import SQLCompiler, _AsSqlType +from django.utils.choices import _ChoicesInput from django.utils.functional import _StrOrPromise from typing_extensions import Self, TypeVar, override -from . import Field -from .mixins import CheckFieldDefaultMixin +_ST_JSON = TypeVar("_ST_JSON", contravariant=True, default=Any) +_GT_JSON = TypeVar("_GT_JSON", covariant=True, default=Any) -# __set__ value type -_ST = TypeVar("_ST", contravariant=True, default=Any) -# __get__ return type -_GT = TypeVar("_GT", covariant=True, default=Any) - -class JSONField(CheckFieldDefaultMixin, Field[_ST, _GT]): +class JSONField(CheckFieldDefaultMixin, Field[_ST_JSON, _GT_JSON, _NT]): encoder: type[json.JSONEncoder] | None decoder: type[json.JSONDecoder] | None def __init__( @@ -28,7 +26,24 @@ class JSONField(CheckFieldDefaultMixin, Field[_ST, _GT]): name: str | None = None, encoder: type[json.JSONEncoder] | None = None, decoder: type[json.JSONDecoder] | None = None, - **kwargs: Any, + *, + primary_key: bool = ..., + unique: bool = ..., + blank: bool = ..., + null: _NT = ..., + db_index: bool = ..., + default: Any = ..., + db_default: Any = ..., + editable: bool = ..., + auto_created: bool = ..., + serialize: bool = ..., + choices: _ChoicesInput | None = ..., + help_text: _StrOrPromise = ..., + db_column: str | None = ..., + db_comment: str | None = ..., + db_tablespace: str | None = ..., + validators: Iterable[validators._ValidatorCallable] = ..., + error_messages: _ErrorMessagesMapping | None = ..., ) -> None: ... def from_db_value(self, value: str | None, expression: Expression, connection: BaseDatabaseWrapper) -> Any: ... @override diff --git a/django-stubs/db/models/fields/related.pyi b/django-stubs/db/models/fields/related.pyi index efba88ee2..d095174c3 100644 --- a/django-stubs/db/models/fields/related.pyi +++ b/django-stubs/db/models/fields/related.pyi @@ -1,5 +1,5 @@ from collections.abc import Callable, Iterable, Sequence -from typing import Any, Generic, Literal, overload +from typing import Any, Generic, Literal, overload, type_check_only from uuid import UUID from django import forms @@ -7,7 +7,7 @@ from django.core import validators # due to weird mypy.stubtest error from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models.base import Model from django.db.models.expressions import Combinable, Expression -from django.db.models.fields import NOT_PROVIDED, Field, _AllLimitChoicesTo, _ErrorMessagesMapping, _LimitChoicesTo +from django.db.models.fields import _NT, NOT_PROVIDED, Field, _AllLimitChoicesTo, _ErrorMessagesMapping, _LimitChoicesTo from django.db.models.fields.mixins import FieldCacheMixin from django.db.models.fields.related_descriptors import ForeignKeyDeferredAttribute, ManyRelatedManager from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor as ForwardManyToOneDescriptor @@ -33,11 +33,11 @@ def lazy_related_operation( ) -> None: ... # __set__ value type -_ST = TypeVar("_ST", contravariant=True) +_ST = TypeVar("_ST", contravariant=True, default=Any) # __get__ return type -_GT = TypeVar("_GT", covariant=True, default=_ST) +_GT = TypeVar("_GT", covariant=True, default=Any) -class RelatedField(FieldCacheMixin, Field[_ST, _GT]): +class RelatedField(FieldCacheMixin, Field[_ST, _GT, _NT]): one_to_many: bool one_to_one: bool many_to_many: bool @@ -58,7 +58,7 @@ class RelatedField(FieldCacheMixin, Field[_ST, _GT]): max_length: int | None = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., rel: ForeignObjectRel | None = ..., default: Any = ..., @@ -95,7 +95,7 @@ class RelatedField(FieldCacheMixin, Field[_ST, _GT]): @property def target_field(self) -> Field: ... -class ForeignObject(RelatedField[_ST, _GT]): +class ForeignObject(RelatedField[_ST, _GT, _NT]): remote_field: ForeignObjectRel rel_class: type[ForeignObjectRel] column: None @@ -119,7 +119,7 @@ class ForeignObject(RelatedField[_ST, _GT]): primary_key: bool = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., default: Any = ..., db_default: type[NOT_PROVIDED] | Expression | _ST = ..., @@ -134,16 +134,36 @@ class ForeignObject(RelatedField[_ST, _GT]): error_messages: _ErrorMessagesMapping | None = ..., db_comment: str | None = ..., ) -> None: ... + @overload + @type_check_only + @override + def __set__(self: ForeignObject[Any, Any, Literal[False]], instance: Any, value: _ST | Combinable) -> None: ... + @overload + @type_check_only + def __set__( + self: ForeignObject[Any, Any, Literal[True]], instance: Any, value: _ST | Combinable | None + ) -> None: ... + @overload + @type_check_only + def __set__(self, instance: Any, value: _ST | Combinable) -> None: ... # class access @overload + @type_check_only @override def __get__(self, instance: None, owner: Any) -> ForwardManyToOneDescriptor[Self]: ... - # Model instance access + # non-null Model instance access + @overload + @type_check_only + @override + def __get__(self: ForeignObject[Any, Any, Literal[False]], instance: Model, owner: Any) -> _GT: ... + # nullable Model instance access @overload + @type_check_only @override - def __get__(self, instance: Model, owner: Any) -> _GT: ... + def __get__(self: ForeignObject[Any, Any, Literal[True]], instance: Model, owner: Any) -> _GT | None: ... # non-Model instances @overload + @type_check_only @override def __get__(self, instance: Any, owner: Any) -> Self: ... def resolve_related_fields(self) -> list[tuple[Field, Field]]: ... @@ -176,10 +196,7 @@ class ForeignObject(RelatedField[_ST, _GT]): related_accessor_class: type[ReverseManyToOneDescriptor] requires_unique_target: bool -class ForeignKey(ForeignObject[_ST, _GT]): - _pyi_private_set_type: Any | Combinable - _pyi_private_get_type: Any - +class ForeignKey(ForeignObject[_ST, _GT, _NT]): descriptor_class: type[ForeignKeyDeferredAttribute] remote_field: ManyToOneRel rel_class: type[ManyToOneRel] @@ -202,7 +219,7 @@ class ForeignKey(ForeignObject[_ST, _GT]): max_length: int | None = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., default: Any = ..., db_default: type[NOT_PROVIDED] | Expression | _ST = ..., @@ -232,10 +249,7 @@ class ForeignKey(ForeignObject[_ST, _GT]): @override def get_attname_column(self) -> tuple[str, str]: ... # type: ignore[override] -class OneToOneField(ForeignKey[_ST, _GT]): - _pyi_private_set_type: Any | Combinable - _pyi_private_get_type: Any - +class OneToOneField(ForeignKey[_ST, _GT, _NT]): remote_field: OneToOneRel rel_class: type[OneToOneRel] def __init__( @@ -256,7 +270,7 @@ class OneToOneField(ForeignKey[_ST, _GT]): max_length: int | None = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., default: Any = ..., db_default: type[NOT_PROVIDED] | Expression | _ST = ..., @@ -274,16 +288,36 @@ class OneToOneField(ForeignKey[_ST, _GT]): error_messages: _ErrorMessagesMapping | None = ..., db_comment: str | None = ..., ) -> None: ... + @overload + @type_check_only + @override + def __set__(self: OneToOneField[Any, Any, Literal[False]], instance: Any, value: _ST | Combinable) -> None: ... + @overload + @type_check_only + def __set__( + self: OneToOneField[Any, Any, Literal[True]], instance: Any, value: _ST | Combinable | None + ) -> None: ... + @overload + @type_check_only + def __set__(self, instance: Any, value: _ST | Combinable) -> None: ... # class access @overload + @type_check_only @override def __get__(self, instance: None, owner: Any) -> ForwardOneToOneDescriptor[Self]: ... - # Model instance access + # non-null Model instance access + @overload + @type_check_only + @override + def __get__(self: OneToOneField[Any, Any, Literal[False]], instance: Model, owner: Any) -> _GT: ... + # nullable Model instance access @overload + @type_check_only @override - def __get__(self, instance: Model, owner: Any) -> _GT: ... + def __get__(self: OneToOneField[Any, Any, Literal[True]], instance: Model, owner: Any) -> _GT | None: ... # non-Model instances @overload + @type_check_only @override def __get__(self, instance: Any, owner: Any) -> Self: ... @override @@ -294,7 +328,7 @@ class OneToOneField(ForeignKey[_ST, _GT]): _Through = TypeVar("_Through", bound=Model) _To = TypeVar("_To", bound=Model) -class ManyToManyField(RelatedField[Any, Any], Generic[_To, _Through]): +class ManyToManyField(RelatedField[Any, Any, Literal[False]], Generic[_To, _Through]): has_null_arg: bool swappable: bool @@ -325,7 +359,7 @@ class ManyToManyField(RelatedField[Any, Any], Generic[_To, _Through]): max_length: int | None = ..., unique: bool = ..., blank: bool = ..., - null: bool = ..., + null: _NT = ..., db_index: bool = ..., default: Any = ..., editable: bool = ..., diff --git a/django-stubs/db/models/fields/related_descriptors.pyi b/django-stubs/db/models/fields/related_descriptors.pyi index a5771518a..5fd9d3546 100644 --- a/django-stubs/db/models/fields/related_descriptors.pyi +++ b/django-stubs/db/models/fields/related_descriptors.pyi @@ -13,7 +13,7 @@ from django.utils.functional import cached_property from typing_extensions import Never, Self, TypeVar, override _M = TypeVar("_M", bound=Model) -_F = TypeVar("_F", bound=Field) +_F = TypeVar("_F", bound=Field[Any, Any, Any]) _From = TypeVar("_From", bound=Model) _Through = TypeVar("_Through", bound=Model, default=Model) _To = TypeVar("_To", bound=Model) @@ -135,7 +135,7 @@ class ManyToManyDescriptor(ReverseManyToOneDescriptor, Generic[_To, _Through]): # 'field' here is 'rel.field' rel: ManyToManyRel # type: ignore[assignment] - field: ManyToManyField[_To, _Through] # type: ignore[assignment] + field: ManyToManyField[_To, _Through] reverse: bool def __init__(self, rel: ManyToManyRel, reverse: bool = False) -> None: ... @property diff --git a/django-stubs/db/models/fields/reverse_related.pyi b/django-stubs/db/models/fields/reverse_related.pyi index 65e2fc8a9..7b501a637 100644 --- a/django-stubs/db/models/fields/reverse_related.pyi +++ b/django-stubs/db/models/fields/reverse_related.pyi @@ -53,7 +53,7 @@ class ForeignObjectRel(FieldCacheMixin): @property def remote_field(self) -> ForeignObject: ... @property - def target_field(self) -> Field: ... + def target_field(self) -> Field[Any, Any, Any]: ... @cached_property def related_model(self) -> type[Model]: ... @cached_property @@ -79,7 +79,7 @@ class ForeignObjectRel(FieldCacheMixin): limit_choices_to: _LimitChoicesTo | None = None, ordering: Sequence[_OrderByFieldName] = (), ) -> _ChoicesList: ... - def get_joining_fields(self) -> tuple[tuple[Field, Field], ...]: ... + def get_joining_fields(self) -> tuple[tuple[Field[Any, Any, Any], Field[Any, Any, Any]], ...]: ... def get_extra_restriction(self, alias: str, related_alias: str) -> StartsWith | WhereNode | None: ... def set_field_name(self) -> None: ... @property @@ -104,7 +104,7 @@ class ManyToOneRel(ForeignObjectRel): parent_link: bool = False, on_delete: Callable[..., Any] | None = None, ) -> None: ... - def get_related_field(self) -> Field: ... + def get_related_field(self) -> Field[Any, Any, Any]: ... @override def get_accessor_name(self, model: type[Model] | None = None) -> str: ... @property @@ -126,7 +126,7 @@ class OneToOneRel(ManyToOneRel): ) -> None: ... class ManyToManyRel(ForeignObjectRel): - field: ManyToManyField[Any, Any] # type: ignore[assignment] + field: ManyToManyField[Any, Any] through: type[Model] | None through_fields: tuple[str, str] | None db_constraint: bool @@ -142,7 +142,7 @@ class ManyToManyRel(ForeignObjectRel): through_fields: tuple[str, str] | None = None, db_constraint: bool = True, ) -> None: ... - def get_related_field(self) -> Field: ... + def get_related_field(self) -> Field[Any, Any, Any]: ... @property @override def identity(self) -> tuple[Any, ...]: ... diff --git a/django-stubs/db/models/fields/tuple_lookups.pyi b/django-stubs/db/models/fields/tuple_lookups.pyi index c8b04b8f3..c016e5769 100644 --- a/django-stubs/db/models/fields/tuple_lookups.pyi +++ b/django-stubs/db/models/fields/tuple_lookups.pyi @@ -10,7 +10,7 @@ from typing_extensions import override class Tuple(Func): function: str - output_field: Field + output_field: Field[Any, Any, Any] def __len__(self) -> int: ... def __iter__(self) -> Iterator[Expression]: ... @override diff --git a/django-stubs/db/models/functions/comparison.pyi b/django-stubs/db/models/functions/comparison.pyi index 70117dfa6..c62e3fda5 100644 --- a/django-stubs/db/models/functions/comparison.pyi +++ b/django-stubs/db/models/functions/comparison.pyi @@ -9,7 +9,7 @@ from django.db.models.sql.compiler import SQLCompiler, _AsSqlType from typing_extensions import override class Cast(Func): - def __init__(self, expression: Combinable | str, output_field: str | Field) -> None: ... + def __init__(self, expression: Combinable | str, output_field: str | Field[Any, Any, Any]) -> None: ... @override def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... # type: ignore[override] def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... diff --git a/django-stubs/db/models/functions/datetime.pyi b/django-stubs/db/models/functions/datetime.pyi index 167c15455..91a0d01ac 100644 --- a/django-stubs/db/models/functions/datetime.pyi +++ b/django-stubs/db/models/functions/datetime.pyi @@ -49,7 +49,7 @@ class TruncBase(TimezoneMixin, Transform): def __init__( self, expression: Combinable | str, - output_field: Field | None = None, + output_field: Field[Any, Any, Any] | None = None, tzinfo: tzinfo | None = None, **extra: Any, ) -> None: ... @@ -61,7 +61,7 @@ class Trunc(TruncBase): self, expression: Combinable | str, kind: str, - output_field: Field | None = None, + output_field: Field[Any, Any, Any] | None = None, tzinfo: tzinfo | None = None, **extra: Any, ) -> None: ... diff --git a/django-stubs/db/models/functions/text.pyi b/django-stubs/db/models/functions/text.pyi index 390e0a1e1..5307a64b7 100644 --- a/django-stubs/db/models/functions/text.pyi +++ b/django-stubs/db/models/functions/text.pyi @@ -3,9 +3,11 @@ from typing import Any, ClassVar from django.db import models from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import Func, Transform -from django.db.models.expressions import Combinable, Expression, Value, _OutputField +from django.db.models.expressions import Combinable, Expression, Value from django.db.models.sql.compiler import SQLCompiler, _AsSqlType -from typing_extensions import override +from typing_extensions import TypeVar, override + +_SubstrOutputField = TypeVar("_SubstrOutputField", bound=models.Field, default=models.CharField) class MySQLSHA2Mixin: def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... @@ -36,14 +38,14 @@ class ConcatPair(Func): class Concat(Func): def __init__(self, *expressions: Any, **extra: Any) -> None: ... -class Left(Func[_OutputField]): +class Left(Func[_SubstrOutputField]): output_field: ClassVar[models.CharField] def __init__( self, expression: Combinable | str, length: Expression | int, *, - output_field: _OutputField | None = None, + output_field: _SubstrOutputField | None = None, **extra: Any, ) -> None: ... def get_substr(self) -> Substr: ... @@ -79,7 +81,7 @@ class Replace(Func): class Reverse(Transform): def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... -class Right(Left[_OutputField]): +class Right(Left[_SubstrOutputField]): @override def get_substr(self) -> Substr: ... @@ -100,7 +102,7 @@ class StrIndex(Func): self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any ) -> _AsSqlType: ... -class Substr(Func[_OutputField]): +class Substr(Func[_SubstrOutputField]): output_field: ClassVar[models.CharField] def __init__( self, @@ -108,7 +110,7 @@ class Substr(Func[_OutputField]): pos: Expression | int, length: Expression | int | None = None, *, - output_field: _OutputField | None = None, + output_field: _SubstrOutputField | None = None, **extra: Any, ) -> None: ... def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... diff --git a/django-stubs/db/models/options.pyi b/django-stubs/db/models/options.pyi index 30eb489ce..2df6395e0 100644 --- a/django-stubs/db/models/options.pyi +++ b/django-stubs/db/models/options.pyi @@ -42,7 +42,7 @@ class Options(Generic[_M]): FORWARD_PROPERTIES: set[str] REVERSE_PROPERTIES: set[str] default_apps: Any - local_fields: list[Field] + local_fields: list[Field[Any, Any, Any]] local_many_to_many: list[ManyToManyField] private_fields: list[Any] local_managers: list[Manager] @@ -67,15 +67,15 @@ class Options(Generic[_M]): required_db_features: _ListOrTuple[str] required_db_vendor: str | None meta: type | None - pk: Field - auto_field: AutoField | None + pk: Field[Any, Any, Any] + auto_field: AutoField[Any, Any, Any] | None abstract: bool managed: bool proxy: bool proxy_for_model: type[Model] | None concrete_model: type[Model] | None swappable: str | None - parents: dict[type[Model], Field | None] + parents: dict[type[Model], Field[Any, Any, Any] | None] auto_created: bool related_fkey_lookups: list[Any] apps: Apps diff --git a/django-stubs/db/models/query_utils.pyi b/django-stubs/db/models/query_utils.pyi index e89c094fc..031144f2c 100644 --- a/django-stubs/db/models/query_utils.pyi +++ b/django-stubs/db/models/query_utils.pyi @@ -28,8 +28,8 @@ class class_or_instance_method: class PathInfo(NamedTuple): from_opts: Options to_opts: Options - target_fields: tuple[Field, ...] - join_field: Field + target_fields: tuple[Field[Any, Any, Any], ...] + join_field: Field[Any, Any, Any] m2m: bool direct: bool filtered_relation: FilteredRelation | None @@ -67,8 +67,8 @@ class Q(tree.Node): def referenced_base_fields(self) -> set[str]: ... class DeferredAttribute: - field: Field - def __init__(self, field: Field) -> None: ... + field: Field[Any, Any, Any] + def __init__(self, field: Field[Any, Any, Any]) -> None: ... def __get__(self, instance: Model | None, cls: type[Model] | None = None) -> Any: ... _R = TypeVar("_R", bound=type) @@ -96,10 +96,10 @@ class RegisterLookupMixin: def _unregister_lookup(cls, lookup: type[Lookup], lookup_name: str | None = None) -> None: ... def select_related_descend( - field: Field, + field: Field[Any, Any, Any], restricted: bool, requested: Mapping[str, Any] | None, - select_mask: set[Field] | None, + select_mask: set[Field[Any, Any, Any]] | None, ) -> bool: ... _E = TypeVar("_E", bound=BaseExpression) diff --git a/django-stubs/db/models/sql/query.pyi b/django-stubs/db/models/sql/query.pyi index 2b5d1be54..f5a5ffb89 100644 --- a/django-stubs/db/models/sql/query.pyi +++ b/django-stubs/db/models/sql/query.pyi @@ -15,12 +15,12 @@ from django.utils.functional import cached_property from typing_extensions import override class JoinInfo(NamedTuple): - final_field: Field + final_field: Field[Any, Any, Any] targets: tuple[Any, ...] opts: Any joins: list[str] path: list[Any] - transform_function: Callable[[Field, str], Expression] + transform_function: Callable[[Field[Any, Any, Any], str], Expression] class RawQuery: high_mark: int | None @@ -88,7 +88,7 @@ class Query(BaseExpression): def __init__(self, model: type[Model] | None, alias_cols: bool = True) -> None: ... @property @override - def output_field(self) -> Field: ... + def output_field(self) -> Field[Any, Any, Any]: ... @property def has_select_fields(self) -> bool: ... @cached_property @@ -165,8 +165,8 @@ class Query(BaseExpression): allow_many: bool = True, ) -> JoinInfo: ... def trim_joins( - self, targets: tuple[Field, ...], joins: list[str], path: list[PathInfo] - ) -> tuple[tuple[Field, ...], str, list[str]]: ... + self, targets: tuple[Field[Any, Any, Any], ...], joins: list[str], path: list[PathInfo] + ) -> tuple[tuple[Field[Any, Any, Any], ...], str, list[str]]: ... def resolve_ref( self, name: str, allow_joins: bool = True, reuse: set[str] | None = None, summarize: bool = False ) -> Expression: ... @@ -216,7 +216,7 @@ class Query(BaseExpression): @property def extra_select(self) -> dict[str, Any]: ... def trim_start(self, names_with_path: list[tuple[str, list[PathInfo]]]) -> tuple[str, bool]: ... - def is_nullable(self, field: Field) -> bool: ... + def is_nullable(self, field: Field[Any, Any, Any]) -> bool: ... def check_filterable(self, expression: Any) -> None: ... def build_lookup(self, lookups: Sequence[str], lhs: Expression | Query, rhs: Any) -> Lookup: ... def try_transform(self, lhs: Expression | Query, name: str, lookups: Sequence[str] | None = ...) -> Transform: ... diff --git a/django-stubs/db/models/sql/subqueries.pyi b/django-stubs/db/models/sql/subqueries.pyi index cd38504fc..0b78c4bab 100644 --- a/django-stubs/db/models/sql/subqueries.pyi +++ b/django-stubs/db/models/sql/subqueries.pyi @@ -15,12 +15,12 @@ class UpdateQuery(Query): def __init__(self, *args: Any, **kwargs: Any) -> None: ... def update_batch(self, pk_list: list[int], values: dict[str, int | None], using: str) -> None: ... def add_update_values(self, values: dict[str, Any]) -> None: ... - def add_update_fields(self, values_seq: list[tuple[Field, type[Model] | None, Case]]) -> None: ... - def add_related_update(self, model: type[Model], field: Field, value: int | str) -> None: ... + def add_update_fields(self, values_seq: list[tuple[Field[Any, Any, Any], type[Model] | None, Case]]) -> None: ... + def add_related_update(self, model: type[Model], field: Field[Any, Any, Any], value: int | str) -> None: ... def get_related_updates(self) -> list[UpdateQuery]: ... class InsertQuery(Query): - fields: Iterable[Field] + fields: Iterable[Field[Any, Any, Any]] objs: list[Model] raw: bool def __init__( @@ -31,7 +31,7 @@ class InsertQuery(Query): unique_fields: Any | None = ..., **kwargs: Any, ) -> None: ... - def insert_values(self, fields: Iterable[Field], objs: list[Model], raw: bool = False) -> None: ... + def insert_values(self, fields: Iterable[Field[Any, Any, Any]], objs: list[Model], raw: bool = False) -> None: ... class AggregateQuery(Query): inner_query: Query diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index 73376c720..b2730b1e6 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -11,7 +11,7 @@ from django.db import models from django.db.models.base import Model from django.db.models.constants import LOOKUP_SEP -from django.db.models.fields import AutoField, CharField, Field +from django.db.models.fields import CharField, Field from django.db.models.fields.related import ForeignKey, RelatedField from django.db.models.fields.reverse_related import ForeignObjectRel from django.db.models.lookups import Exact, In @@ -40,7 +40,7 @@ class ArrayField: # type: ignore[no-redef] from django.db.models.expressions import Expression from django.db.models.options import _AnyField from mypy.checker import TypeChecker - from mypy.nodes import TypeInfo + from mypy.nodes import Context from mypy.plugin import MethodContext @@ -82,35 +82,6 @@ class LookupsAreUnsupported(Exception): pass -def get_field_type_from_model_type_info(info: TypeInfo | None, field_name: str) -> Instance | None: - if info is None: - return None - field_node = info.get(field_name) - if field_node is None: - return None - field_type = get_proper_type(field_node.type) - if not isinstance(field_type, Instance): - return None - # Field declares a set and a get type arg. Fallback to `None` when we can't find any args - if len(field_type.args) != 2: - return None - return field_type - - -def _get_field_set_type_from_model_type_info(info: TypeInfo | None, field_name: str) -> MypyType | None: - field_type = get_field_type_from_model_type_info(info, field_name) - if field_type is not None: - return field_type.args[0] - return None - - -def _get_field_get_type_from_model_type_info(info: TypeInfo | None, field_name: str) -> MypyType | None: - field_type = get_field_type_from_model_type_info(info, field_name) - if field_type is not None: - return field_type.args[1] - return None - - class DjangoContext: def __init__(self, django_settings_module: str) -> None: self.django_settings_module = django_settings_module @@ -158,7 +129,9 @@ def get_model_relations(self, model_cls: type[Model]) -> Iterator[ForeignObjectR if isinstance(field, ForeignObjectRel): yield field - def get_field_lookup_exact_type(self, api: TypeChecker, field: Field[Any, Any] | ForeignObjectRel) -> MypyType: + def get_field_lookup_exact_type( + self, api: TypeChecker, field: Field[Any, Any, Any] | ForeignObjectRel, context: Context + ) -> MypyType: if isinstance(field, RelatedField | ForeignObjectRel): related_model_cls = self.get_field_related_model_cls(field) rel_model_info = helpers.lookup_class_typeinfo(api, related_model_cls) @@ -166,7 +139,9 @@ def get_field_lookup_exact_type(self, api: TypeChecker, field: Field[Any, Any] | return AnyType(TypeOfAny.explicit) primary_key_field = self.get_primary_key_field(related_model_cls) - primary_key_type = self.get_field_get_type(api, rel_model_info, primary_key_field, method="init") + primary_key_type = helpers.get_field_get_type_from_model_type_info( + api, context, rel_model_info, primary_key_field.attname + ) model_and_primary_key_type = UnionType.make_union([Instance(rel_model_info, []), primary_key_type]) return make_optional_type(model_and_primary_key_type) @@ -196,32 +171,28 @@ def get_primary_key_field(self, model_cls: type[Model]) -> Field[Any, Any]: return field raise ValueError("No primary key defined") - def get_expected_types(self, api: TypeChecker, model_cls: type[Model], *, method: str) -> dict[str, MypyType]: + def get_expected_types(self, api: TypeChecker, model_cls: type[Model], context: Context) -> dict[str, MypyType]: + """Return a mapping of field name to the type accepted when assigning/passing it as a kwarg.""" contenttypes_in_apps = self.apps_registry.is_installed("django.contrib.contenttypes") + model_info = helpers.lookup_class_typeinfo(api, model_cls) expected_types = {} - # add pk if not abstract=True if not model_cls._meta.abstract: primary_key_field = self.get_primary_key_field(model_cls) - field_set_type = self.get_field_set_type(api, primary_key_field, method=method) - expected_types["pk"] = field_set_type + pk_set_type = helpers.get_field_set_type_from_model_type_info( + api, context, model_info, primary_key_field.attname + ) + # Setting pk to None is allowed to copy instances + # https://docs.djangoproject.com/en/6.0/topics/db/queries/#copying-model-instances + expected_types["pk"] = make_optional_type(pk_set_type) - model_info = helpers.lookup_class_typeinfo(api, model_cls) for field in model_cls._meta.get_fields(): if contenttypes_in_apps: from django.contrib.contenttypes.fields import GenericForeignKey if isinstance(field, GenericForeignKey): - # it's generic, so cannot set specific model - field_name = field.name - gfk_info = helpers.lookup_class_typeinfo(api, field.__class__) - if gfk_info is None: - gfk_set_type: MypyType = AnyType(TypeOfAny.unannotated) - else: - gfk_set_type = helpers.get_private_descriptor_type( - gfk_info, "_pyi_private_set_type", is_nullable=True - ) - expected_types[field_name] = gfk_set_type + # A GenericForeignKey can reference any model, so we can't narrow the assignment type (yet!) + expected_types[field.name] = AnyType(TypeOfAny.unannotated) continue if isinstance(field, Field): @@ -231,44 +202,17 @@ def get_expected_types(self, api: TypeChecker, model_cls: type[Model], *, method # recursive abstract model and we need to not crash and gracefully exit. if field.related_model == "self" and model_cls._meta.abstract: # type: ignore[comparison-overlap, unreachable] continue # type: ignore[unreachable] - # Try to retrieve set type from a model's TypeInfo object and fallback to retrieving it manually - # from django-stubs own declaration. This is to align with the setter types declared for - # assignment. - field_set_type = _get_field_set_type_from_model_type_info( - model_info, field_name - ) or self.get_field_set_type(api, field, method=method) - expected_types[field_name] = field_set_type + expected_types[field_name] = helpers.get_field_set_type_from_model_type_info( + api, context, model_info, field_name + ) if isinstance(field, ForeignKey): - field_name = field.name - foreign_key_info = helpers.lookup_class_typeinfo(api, field.__class__) - if foreign_key_info is None: - # maybe there's no type annotation for the field - expected_types[field_name] = AnyType(TypeOfAny.unannotated) - continue - - try: - related_model = self.get_field_related_model_cls(field) - except UnregisteredModelError: - # Recognise the field but don't say anything about its type.. - expected_types[field_name] = AnyType(TypeOfAny.from_error) - continue - - if related_model._meta.proxy_for_model is not None: - related_model = related_model._meta.proxy_for_model - - related_model_info = helpers.lookup_class_typeinfo(api, related_model) - if related_model_info is None: - expected_types[field_name] = AnyType(TypeOfAny.unannotated) - continue - - is_nullable = self.get_field_nullability(field, method) - foreign_key_set_type = helpers.get_private_descriptor_type( - foreign_key_info, "_pyi_private_set_type", is_nullable=is_nullable + # In the case of a FK, we need to register both `fk_name` and `fk_name_id` + # - `field.attname` -> `fk_name_id` + # - `field.name` -> `fk_name` + expected_types[field.name] = helpers.get_field_set_type_from_model_type_info( + api, context, model_info, field.name ) - model_set_type = helpers.convert_any_to_type(foreign_key_set_type, Instance(related_model_info, [])) - - expected_types[field_name] = model_set_type return expected_types @@ -292,82 +236,14 @@ def model_class_fullnames_by_label(self) -> Mapping[str, str]: if klass is not models.Model } - def get_field_nullability(self, field: Field[Any, Any] | ForeignObjectRel, method: str | None) -> bool: - if method in ("values", "values_list"): - return field.null - + def get_field_nullability(self, field: Field[Any, Any] | ForeignObjectRel) -> bool: nullable = field.null if not nullable and isinstance(field, CharField) and field.blank: return True - if method == "__init__": - if (isinstance(field, Field) and field.primary_key) or isinstance(field, ForeignKey): - return True - if method == "create": - if isinstance(field, AutoField): - return True if isinstance(field, Field) and field.has_default(): return True return nullable - def get_field_set_type( - self, api: TypeChecker, field: Field[Any, Any] | ForeignObjectRel, *, method: str - ) -> MypyType: - """Get a type of __set__ for this specific Django field.""" - target_field = field - if isinstance(field, ForeignKey): - try: - # We gotta be careful for exceptions when we're triggering '__get__'. - # Related model could very well be unresolvable - target_field = field.target_field - except ValueError: - return AnyType(TypeOfAny.from_error) - - field_info = helpers.lookup_class_typeinfo(api, target_field.__class__) - if field_info is None: - return AnyType(TypeOfAny.from_error) - - field_set_type = helpers.get_private_descriptor_type( - field_info, "_pyi_private_set_type", is_nullable=self.get_field_nullability(field, method) - ) - if isinstance(target_field, ArrayField): - argument_field_type = self.get_field_set_type(api, target_field.base_field, method=method) - field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type) - return field_set_type - - def get_field_get_type( - self, - api: TypeChecker, - model_info: TypeInfo | None, - field: Field[Any, Any] | ForeignObjectRel, - *, - method: str, - ) -> MypyType: - """Get a type of __get__ for this specific Django field.""" - if isinstance(field, Field): - get_type = _get_field_get_type_from_model_type_info(model_info, getattr(field, "attname", field.name)) - if get_type is not None: - return get_type - - field_info = helpers.lookup_class_typeinfo(api, field.__class__) - if field_info is None: - return AnyType(TypeOfAny.unannotated) - - is_nullable = self.get_field_nullability(field, method) - if isinstance(field, RelatedField): - related_model_cls = self.get_field_related_model_cls(field) - rel_model_info = helpers.lookup_class_typeinfo(api, related_model_cls) - - if method in ("values", "values_list"): - primary_key_field = self.get_primary_key_field(related_model_cls) - return self.get_field_get_type(api, rel_model_info, primary_key_field, method=method) - - model_info = helpers.lookup_class_typeinfo(api, related_model_cls) - if model_info is None: - return AnyType(TypeOfAny.unannotated) - - return Instance(model_info, []) - return helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_nullable) - def get_field_related_model_cls(self, field: RelatedField[Any, Any] | ForeignObjectRel) -> type[Model]: if isinstance(field, RelatedField): related_model_cls = field.remote_field.model @@ -477,7 +353,8 @@ def _resolve_lookup_type_from_lookup_class( Returns: The resolved type, or None if it couldn't be determined """ - lookup_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), lookup_cls) + api = helpers.get_typechecker_api(ctx) + lookup_info = helpers.lookup_class_typeinfo(api, lookup_cls) if lookup_info is None: return None @@ -488,12 +365,16 @@ def _resolve_lookup_type_from_lookup_class( if field is None: # No field available (e.g., annotation), can't resolve further return None - field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__) + field_info = helpers.lookup_class_typeinfo(api, field.__class__) if field_info is None: return None - return get_proper_type( - helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=field.null) - ) + defaults = helpers.fill_field_defaults(field_info, api) + field_type_args = helpers.get_field_type_args(defaults) + assert field_type_args is not None + result: MypyType = field_type_args.get + if field.null: + result = make_optional_type(result) + return result return lookup_type return None @@ -568,10 +449,10 @@ def resolve_lookup_expected_type( return AnyType(TypeOfAny.explicit) if lookup_cls is None or issubclass(lookup_cls, Exact): - return self.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field) + return self.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field, ctx.context) if issubclass(lookup_cls, In): - exact_type = self.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field) + exact_type = self.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field, ctx.context) return ctx.api.named_generic_type("typing.Iterable", [exact_type]) resolved_type = self._resolve_lookup_type_from_lookup_class(ctx, lookup_cls, field) diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 0450e2fd6..738ced6c2 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -38,7 +38,7 @@ MethodContext, ) from mypy.semanal import SemanticAnalyzer -from mypy.typeanal import make_optional_type +from mypy.typeanal import fix_instance, make_optional_type from mypy.types import ( AnyType, CallableType, @@ -436,6 +436,27 @@ def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is return AnyType(TypeOfAny.explicit) +def fill_field_defaults( + field_info: TypeInfo, + api: TypeChecker | SemanticAnalyzer, + *, + is_set_nullable: bool = False, +) -> Instance: + """Build an Instance of `field_info` with PEP 696 TypeVar defaults applied.""" + inst = Instance(field_info, ()) + fix_instance( + inst, + api.fail, + api.note, + disallow_any=False, + options=api.options, + use_generic_error=True, + ) + if is_set_nullable and inst.args: + inst = inst.copy_modified(args=[make_optional_type(inst.args[0]), *inst.args[1:]]) + return inst + + class FieldTypeArgs(NamedTuple): set: ProperType get: ProperType @@ -562,6 +583,26 @@ def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType: return typ +def reparametrize_field_type( + field_type: Instance, set_type: MypyType, get_type: MypyType, *, is_nullable: bool | None = None +) -> Instance: + """Replace the set/get type args of a Field Instance, preserving any additional args. + + Uses `convert_any_to_type` to substitute `Any` placeholders in the existing args; + the ``_NT`` nullability flag (``args[2]``) is updated when *is_nullable* is given. + """ + trailing = list(field_type.args[2:]) + nt_proper = get_proper_type(trailing[0]) if trailing else None + if is_nullable is not None and isinstance(nt_proper, LiteralType): + trailing[0] = LiteralType(value=is_nullable, fallback=nt_proper.fallback) + args = [ + convert_any_to_type(field_type.args[0], set_type), + convert_any_to_type(field_type.args[1], get_type), + *trailing, + ] + return field_type.copy_modified(args=args) + + def _get_fallback_typeddict(api: SemanticAnalyzer | CheckerPluginInterface) -> Instance: if isinstance(api, CheckerPluginInterface): return api.named_generic_type("typing._TypedDict", []) @@ -787,6 +828,7 @@ def analyze_member_access( is_operator: bool, original_type: MypyType, chk: TypeChecker, + suppress_errors: bool = False, ) -> MypyType: # TODO: [mypy 1.16+] Remove this workaround for passing `msg` to `analyze_member_access()`. extra: dict[str, Any] = {} @@ -801,5 +843,69 @@ def analyze_member_access( is_operator=is_operator, original_type=original_type, chk=chk, + suppress_errors=suppress_errors, **extra, ) + + +def _resolve_field_descriptor_type( + api: TypeChecker, info: TypeInfo | None, field_name: str, context: Context, *, is_lvalue: bool +) -> ProperType: + if info is None: + return AnyType(TypeOfAny.from_error) + instance = Instance(info, []) + return get_proper_type( + analyze_member_access( + name=field_name, + typ=instance, + context=context, + is_lvalue=is_lvalue, + is_super=False, + is_operator=False, + original_type=instance, + chk=api, + suppress_errors=True, + ) + ) + + +def get_field_set_type_from_model_type_info( + api: TypeChecker, context: Context, info: TypeInfo | None, field_name: str +) -> ProperType: + """Resolve `. = ...` rvalue type via Field's `__set__` overloads. + + For `name = CharField()` -> `str | int | Combinable`. + For `name = CharField(null=True)` -> `str | int | Combinable | None`. + """ + return _resolve_field_descriptor_type(api, info, field_name, context, is_lvalue=True) + + +def get_field_get_type_from_model_type_info( + api: TypeChecker, context: Context, info: TypeInfo | None, field_name: str +) -> ProperType: + """Resolve `.` read type via Field's `__get__` overloads. + + For `name = CharField()` -> `str`. + For `name = CharField(null=True)` -> `str | None`. + """ + return _resolve_field_descriptor_type(api, info, field_name, context, is_lvalue=False) + + +def get_field_type_from_model_type_info( + api: TypeChecker, context: Context, model_info: TypeInfo, field_name: str +) -> Instance | None: + """Resolve `.`'s full Field descriptor type with set/get args filled in. + + For `name = CharField()` -> `CharField[str | int | Combinable, str, Literal[False]]`. + For `name = CharField(null=True)` -> `CharField[str | int | Combinable | None, str | None, Literal[True]]`. + """ + if ( + (field_sym_node := model_info.get(field_name)) is None + or not isinstance(field_type := get_proper_type(field_sym_node.type), Instance) + or len(field_type.args) < 2 + ): + return None + + resolved_get_type = get_field_get_type_from_model_type_info(api, context, model_info, field_name) + resolved_set_type = get_field_set_type_from_model_type_info(api, context, model_info, field_name) + return field_type.copy_modified(args=[resolved_set_type, resolved_get_type, *field_type.args[2:]]) diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index 2eb8b7ec7..9d6ea7ec4 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -1,12 +1,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, NamedTuple, cast +from typing import TYPE_CHECKING, Any, cast from django.core.exceptions import FieldDoesNotExist from django.db.models.fields import AutoField, Field from django.db.models.fields.related import RelatedField from mypy.nodes import AssignmentStmt, NameExpr, TypeInfo -from mypy.types import AnyType, Instance, NoneType, ProperType, TypeOfAny, UninhabitedType, UnionType, get_proper_type +from mypy.typeanal import make_optional_type +from mypy.types import AnyType, Instance, LiteralType, ProperType, TypeOfAny, get_proper_type from mypy.types import Type as MypyType from mypy_django_plugin.exceptions import UnregisteredModelError @@ -48,14 +49,6 @@ def _get_current_field_from_assignment( return None -def reparametrize_related_field_type(related_field_type: Instance, set_type: MypyType, get_type: MypyType) -> Instance: - args = [ - helpers.convert_any_to_type(related_field_type.args[0], set_type), - helpers.convert_any_to_type(related_field_type.args[1], get_type), - ] - return related_field_type.copy_modified(args=args) - - def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: current_field = _get_current_field_from_assignment(ctx, django_context) if current_field is None: @@ -80,7 +73,7 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context derived_model_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), model_cls) if derived_model_info is not None: fk_ref_type = Instance(derived_model_info, []) - derived_fk_type = reparametrize_related_field_type( + derived_fk_type = helpers.reparametrize_field_type( default_related_field_type, set_type=fk_ref_type, get_type=fk_ref_type ) helpers.add_new_sym_for_info(derived_model_info, name=current_field.name, sym_type=derived_fk_type) @@ -108,23 +101,18 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context else: related_model_to_set_type = Instance(related_model_to_set_info, []) - # replace Any with referred_to_type - return reparametrize_related_field_type( - default_related_field_type, set_type=related_model_to_set_type, get_type=related_model_type - ) - - -class FieldDescriptorTypes(NamedTuple): - set: MypyType - get: MypyType + is_nullable = helpers.get_bool_call_argument_by_name(ctx, "null", default=False) + set_type: MypyType = related_model_to_set_type + get_type: MypyType = related_model_type + if is_nullable: + set_type = make_optional_type(set_type) + get_type = make_optional_type(get_type) -def get_field_descriptor_types( - field_info: TypeInfo, *, is_set_nullable: bool, is_get_nullable: bool -) -> FieldDescriptorTypes: - set_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_set_type", is_nullable=is_set_nullable) - get_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_get_nullable) - return FieldDescriptorTypes(set=set_type, get=get_type) + # replace Any with referred_to_type + return helpers.reparametrize_field_type( + default_related_field_type, set_type=set_type, get_type=get_type, is_nullable=is_nullable + ) def set_descriptor_types_for_field_callback(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: @@ -136,95 +124,34 @@ def set_descriptor_types_for_field_callback(ctx: FunctionContext, django_context return set_descriptor_types_for_field(ctx) -def set_descriptor_types_for_field( - ctx: FunctionContext, *, is_set_nullable: bool = False, is_get_nullable: bool = False -) -> Instance: +def set_descriptor_types_for_field(ctx: FunctionContext, *, is_set_nullable: bool = False) -> Instance: default_return_type = cast("Instance", ctx.default_return_type) + if len(default_return_type.args) != 3: + # Explicitly bound fields. For ex: + # `class CustomValueField(fields.Field[CustomFieldValue | int, CustomFieldValue])` + return default_return_type is_nullable = helpers.get_bool_call_argument_by_name(ctx, "null", default=False) + is_primary_key = helpers.get_bool_call_argument_by_name(ctx, "primary_key", default=False) - # Allow setting field value to `None` when a field is primary key and has a default that can produce a value default_expr = helpers.get_call_argument_by_name(ctx, "default") if default_expr is not None: is_set_nullable = is_primary_key - set_type, get_type = get_field_descriptor_types( - default_return_type.type, - is_set_nullable=is_set_nullable or is_nullable, - is_get_nullable=is_get_nullable or is_nullable, - ) - - # reconcile set and get types with the base field class - mapped_types = helpers.get_field_type_args(default_return_type) - assert mapped_types is not None - - # bail if either mapped set or get type is Never - if not (isinstance(mapped_types.set, UninhabitedType) or isinstance(mapped_types.get, UninhabitedType)): - # always replace set_type and get_type with (non-Any) mapped types - set_type = helpers.convert_any_to_type(mapped_types.set, set_type) - get_type = get_proper_type(helpers.convert_any_to_type(mapped_types.get, get_type)) + set_type = default_return_type.args[0] + get_type = default_return_type.args[1] + # Handle `primary_key` + `default` allows setting to None + if is_set_nullable: + set_type = make_optional_type(set_type) - # the get_type must be optional if the field is nullable - if (is_get_nullable or is_nullable) and not ( - isinstance(get_type, NoneType) or helpers.is_optional(get_type) or isinstance(get_type, AnyType) - ): - ctx.api.fail( - f"{default_return_type.type.name} is nullable but its generic get type parameter is not optional", - ctx.context, - ) - - return default_return_type.copy_modified(args=[set_type, get_type]) - - -def determine_type_of_array_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: - default_return_type = set_descriptor_types_for_field(ctx) - - base_field_arg_type = get_proper_type(helpers.get_call_argument_type_by_name(ctx, "base_field")) - if not base_field_arg_type or not isinstance(base_field_arg_type, Instance): - return default_return_type - - def drop_combinable(_type: MypyType) -> MypyType | None: - _type = get_proper_type(_type) - if isinstance(_type, Instance) and _type.type.has_base(fullnames.COMBINABLE_EXPRESSION_FULLNAME): - return None - if isinstance(_type, UnionType): - items_without_combinable = [] - for item in _type.items: - reduced = drop_combinable(item) - if reduced is not None: - items_without_combinable.append(reduced) - - if len(items_without_combinable) > 1: - return UnionType( - items_without_combinable, - line=_type.line, - column=_type.column, - is_evaluated=_type.is_evaluated, - uses_pep604_syntax=_type.uses_pep604_syntax, - ) - if len(items_without_combinable) == 1: - return items_without_combinable[0] - return None - - return _type - - # Both base_field and return type should derive from Field and thus expect 2 arguments - assert len(base_field_arg_type.args) == len(default_return_type.args) == 2 - args = [] - for new_type, default_arg in zip(base_field_arg_type.args, default_return_type.args, strict=False): - # Drop any base_field Combinable type - reduced = drop_combinable(new_type) - if reduced is None: - ctx.api.fail( - f"Can't have ArrayField expecting {fullnames.COMBINABLE_EXPRESSION_FULLNAME!r} as data type", - ctx.context, - ) - else: - new_type = reduced - - args.append(helpers.convert_any_to_type(default_arg, new_type)) - - return default_return_type.copy_modified(args=args) + # Update the _NT (null flag) type argument to match the resolved nullability. + # In the future, we should be able to remove that once `primary_key` and `default` + # are also part of the type and can hence be used to derive the `_NT` value + trailing = list(default_return_type.args[2:]) + nt_proper = get_proper_type(trailing[0]) if trailing else None + if isinstance(nt_proper, LiteralType): + trailing[0] = LiteralType(value=is_nullable, fallback=nt_proper.fallback) + return default_return_type.copy_modified(args=[set_type, get_type, *trailing]) def transform_into_proper_return_type(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: @@ -233,7 +160,7 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() if outer_model_info is None or not helpers.is_model_type(outer_model_info): - return set_descriptor_types_for_field(ctx) + return default_return_type assert isinstance(outer_model_info, TypeInfo) @@ -244,7 +171,4 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES): return fill_descriptor_types_for_related_field(ctx, django_context) - if default_return_type.type.has_base(fullnames.ARRAY_FIELD_FULLNAME): - return determine_type_of_array_field(ctx, django_context) - return set_descriptor_types_for_field_callback(ctx, django_context) diff --git a/mypy_django_plugin/transformers/init_create.py b/mypy_django_plugin/transformers/init_create.py index 837c5d660..1eabb6f6a 100644 --- a/mypy_django_plugin/transformers/init_create.py +++ b/mypy_django_plugin/transformers/init_create.py @@ -43,7 +43,7 @@ def typecheck_model_method( ) -> None: """Type-checks positional and keyword arguments for Model methods like __init__(), create(), and acreate().""" typechecker_api = helpers.get_typechecker_api(ctx) - expected_types = django_context.get_expected_types(typechecker_api, model_cls, method=method) + expected_types = django_context.get_expected_types(typechecker_api, model_cls, ctx.context) expected_keys = [key for key in expected_types.keys() if key != "pk"] min_arg_count = helpers.get_min_argument_count(ctx) diff --git a/mypy_django_plugin/transformers/managers.py b/mypy_django_plugin/transformers/managers.py index 2f8900aea..48ecef1af 100644 --- a/mypy_django_plugin/transformers/managers.py +++ b/mypy_django_plugin/transformers/managers.py @@ -9,6 +9,7 @@ Decorator, FuncBase, FuncDef, + IndexExpr, MemberExpr, Node, OverloadedFuncDef, @@ -16,6 +17,7 @@ RefExpr, StrExpr, SymbolTableNode, + TupleExpr, TypeInfo, ) from mypy.plugins.common import add_method_to_class @@ -606,10 +608,25 @@ def _defer() -> None: ) -def _is_omitted_generic(arg: MypyType) -> bool: - """True if ``arg`` is the ``Any`` mypy substitutes for an unparametrized generic.""" - proper_arg = get_proper_type(arg) - return isinstance(proper_arg, AnyType) and proper_arg.type_of_any is TypeOfAny.from_omitted_generics +def _count_user_supplied_type_args(ctx: ClassDefContext, parent_type: TypeInfo) -> int: + """Count type args the user actually wrote at ``parent_type``'s bracket position. + + We can't infer this from the parent ``Instance.args`` because mypy's PEP 696 + default substitution leaves no marker distinguishing user-supplied slots from bdefaulted ones. + """ + base_index_expr = next( + ( + expr + for expr in ctx.cls.base_type_exprs + if isinstance(expr, IndexExpr) and isinstance(expr.base, RefExpr) and expr.base.node is parent_type + ), + None, + ) + if base_index_expr is None: + return 0 + if isinstance(base_index_expr.index, TupleExpr): + return len(base_index_expr.index.items) + return 1 def reparametrize_generic_class( @@ -648,31 +665,27 @@ def reparametrize_generic_class( if parent_class is None or not parent_class.args: return - # Bind explicit args only when the direct parent is the canonical base class - is_direct_parent = parent_class.type.fullname == base_class_fullname - if not (bind_explicit_args and is_direct_parent): - if not _is_omitted_generic(parent_class.args[0]): + parent_type_vars = list(parent_class.type.defn.type_vars) + num_type_arg_supplied = _count_user_supplied_type_args(ctx, parent_class.type) + + if bind_explicit_args and num_type_arg_supplied > 0: + if parent_class.type.fullname != base_class_fullname: + # User wrote `[...]` on a non-direct parent. For ex `HtmlField(TextField[...])` return - type_vars = list(parent_class.type.defn.type_vars) + + # Synthesize a TypeVar reusing parent bound and default. + type_vars = [ + type_var.copy_modified(id=type_var.id, upper_bound=parent_type_var, default=parent_type_var) + for type_var, parent_type_var in zip(parent_type_vars, parent_class.args, strict=True) + ] + new_parent_args: list[MypyType] = list(type_vars) + elif num_type_arg_supplied < len(parent_type_vars): + # Bare class or partial parametrization: reuse parent's trailing TypeVar(s) to keep the subclass generic. + type_vars = parent_type_vars[num_type_arg_supplied:] + new_parent_args = list(parent_class.args[:num_type_arg_supplied]) + list(type_vars) else: - type_vars = [] - for type_var, parent_type_var in zip(parent_class.type.defn.type_vars, parent_class.args, strict=True): - if _is_omitted_generic(parent_type_var): - # Arg was omitted -- reuse the parent's TypeVar so its existing - # PEP 696 default (e.g. ``_Row = TypeVar(default=_Model)``) is preserved. - type_vars.append(type_var) - else: - # Arg was supplied explicitly -- synthesize a TypeVar bounded by the - # user's arg so the subclass stays effectively concrete in user-facing - # contexts (defaults + bound + covariance) while still being generic - # enough for the plugin to flow annotation row types through. - type_vars.append( - type_var.copy_modified( - id=type_var.id, - upper_bound=parent_type_var, - default=parent_type_var, - ) - ) + # User wrote a full parametrization; respect it. + return # If we end up with placeholders we need to defer so the placeholders are # resolved in a future iteration @@ -682,7 +695,7 @@ def reparametrize_generic_class( else: return - parent_class.args = tuple(type_vars) + parent_class.args = tuple(new_parent_args) class_info.node.defn.type_vars = type_vars class_info.node.add_type_vars() diff --git a/mypy_django_plugin/transformers/meta.py b/mypy_django_plugin/transformers/meta.py index f063c66ab..6ffd6d7e6 100644 --- a/mypy_django_plugin/transformers/meta.py +++ b/mypy_django_plugin/transformers/meta.py @@ -6,34 +6,32 @@ from mypy.types import AnyType, Instance, TypeOfAny, get_proper_type from mypy.types import Type as MypyType -from mypy_django_plugin.django.context import DjangoContext, get_field_type_from_model_type_info from mypy_django_plugin.lib import helpers from mypy_django_plugin.lib.helpers import DjangoModel if TYPE_CHECKING: from mypy.plugin import MethodContext + from mypy_django_plugin.django.context import DjangoContext + def return_proper_field_type_from_get_field(ctx: MethodContext, django_context: DjangoContext) -> MypyType: if not ( isinstance(ctx.type, Instance) and ctx.type.args and isinstance(model_type := get_proper_type(ctx.type.args[0]), Instance) + and (django_model := DjangoModel.from_model_type(model_type, django_context)) is not None and (field_name_expr := helpers.get_call_argument_by_name(ctx, "field_name")) is not None and (field_name := helpers.resolve_string_attribute_value(field_name_expr, django_context)) is not None ): return ctx.default_return_type - field_type = get_field_type_from_model_type_info(model_type.type, field_name) - if field_type is not None: - return field_type - - if (django_model := DjangoModel.from_model_type(model_type, django_context)) is None: - return ctx.default_return_type - try: field = django_model.cls._meta.get_field(field_name) - if field_info := helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__): + api = helpers.get_typechecker_api(ctx) + if field_type := helpers.get_field_type_from_model_type_info(api, ctx.context, model_type.type, field_name): + return field_type + if field_info := helpers.lookup_class_typeinfo(api, field.__class__): return Instance(field_info, []) except FieldDoesNotExist as e: ctx.api.fail(str(e), ctx.context) diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index ddae84828..9366d296e 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -23,7 +23,7 @@ ) from mypy.plugins import common from mypy.semanal import SemanticAnalyzer -from mypy.typeanal import TypeAnalyser +from mypy.typeanal import TypeAnalyser, make_optional_type from mypy.types import AnyType, Instance, ProperType, TypedDictType, TypeOfAny, TypeType, TypeVarType, get_proper_type from mypy.types import Type as MypyType from mypy.typevars import fill_typevars @@ -33,7 +33,6 @@ from mypy_django_plugin.exceptions import UnregisteredModelError from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME -from mypy_django_plugin.transformers.fields import FieldDescriptorTypes, get_field_descriptor_types from mypy_django_plugin.transformers.managers import ( MANAGER_METHODS_RETURNING_QUERYSET, create_manager_info_from_from_queryset_call, @@ -295,14 +294,9 @@ def create_autofield( if existing_field: auto_field_fullname = helpers.get_class_fullname(auto_field.__class__) auto_field_info = self.lookup_typeinfo_or_incomplete_defn_error(auto_field_fullname) - - set_type, get_type = get_field_descriptor_types( - auto_field_info, - is_set_nullable=True, - is_get_nullable=False, - ) - - self.add_new_var_to_model_class(dest_name, Instance(auto_field_info, [set_type, get_type])) + # Auto fields accept None (for auto-increment / omitting the value) + field_instance = helpers.fill_field_defaults(auto_field_info, self.api, is_set_nullable=True) + self.add_new_var_to_model_class(dest_name, field_instance) class AddPrimaryKeyAlias(AddDefaultPrimaryKey): @@ -350,11 +344,17 @@ def run_with_model_cls(self, model_cls: type[Model]) -> None: raise exc continue - is_nullable = self.django_context.get_field_nullability(field, None) - set_type, get_type = get_field_descriptor_types( - field_info, is_set_nullable=is_nullable, is_get_nullable=is_nullable - ) - self.add_new_var_to_model_class(field.attname, Instance(field_info, [set_type, get_type])) + field_instance = helpers.fill_field_defaults(field_info, self.api) + is_nullable = self.django_context.get_field_nullability(field) + if is_nullable: + field_instance = field_instance.copy_modified( + args=[ + make_optional_type(field_instance.args[0]), + make_optional_type(field_instance.args[1]), + *field_instance.args[2:], + ] + ) + self.add_new_var_to_model_class(field.attname, field_instance) class AddManagers(ModelClassInitializer): @@ -770,10 +770,7 @@ def default_pk_instance(self) -> Instance: default_pk_field = self.lookup_typeinfo(self.django_context.settings.DEFAULT_AUTO_FIELD) if default_pk_field is None: raise helpers.IncompleteDefnException() - return Instance( - default_pk_field, - list(get_field_descriptor_types(default_pk_field, is_set_nullable=True, is_get_nullable=False)), - ) + return helpers.fill_field_defaults(default_pk_field, self.api, is_set_nullable=True) @cached_property def model_pk_instance(self) -> Instance: @@ -808,8 +805,8 @@ def manager_info(self) -> TypeInfo: return info @cached_property - def fk_field_types(self) -> FieldDescriptorTypes: - return get_field_descriptor_types(self.fk_field, is_set_nullable=False, is_get_nullable=False) + def fk_field_defaults(self) -> Instance: + return helpers.fill_field_defaults(self.fk_field, self.api) @cached_property def many_related_manager(self) -> TypeInfo: @@ -842,6 +839,12 @@ def create_through_table_class( through_model = self.lookup_typeinfo(model_fullname) if through_model is not None: return through_model + # Ensure the default PK field type is available before creating the through + # model class. Accessing this cached_property may raise IncompleteDefnException + # if the DEFAULT_AUTO_FIELD hasn't been analyzed yet. By resolving it first, + # we avoid registering an empty through model that would be returned on the + # next pass without its attributes populated. + default_pk = self.default_pk_instance # Declare a new, empty, implicitly generated through model class named: '_' through_model = self.add_new_class_for_current_module(model_name, bases=[Instance(self.model_base, [])]) # We attempt to be a bit clever here and store the generated through model's fullname in @@ -852,21 +855,18 @@ def create_through_table_class( model_metadata.setdefault("m2m_throughs", {}) model_metadata["m2m_throughs"][field_name] = through_model.fullname # Add a 'pk' symbol to the model class - helpers.add_new_sym_for_info(through_model, name="pk", sym_type=self.default_pk_instance.copy_modified()) + helpers.add_new_sym_for_info(through_model, name="pk", sym_type=default_pk.copy_modified()) # Add an 'id' symbol to the model class - helpers.add_new_sym_for_info(through_model, name="id", sym_type=self.default_pk_instance.copy_modified()) + helpers.add_new_sym_for_info(through_model, name="id", sym_type=default_pk.copy_modified()) # Add the foreign key to the model containing the 'ManyToManyField' call: # or from_ from_name = f"from_{self.model_classdef.name.lower()}" if m2m_args.to.self else self.model_classdef.name.lower() + containing_model_type = Instance(self.model_classdef.info, []) helpers.add_new_sym_for_info( through_model, name=from_name, - sym_type=Instance( - self.fk_field, - [ - helpers.convert_any_to_type(self.fk_field_types.set, Instance(self.model_classdef.info, [])), - helpers.convert_any_to_type(self.fk_field_types.get, Instance(self.model_classdef.info, [])), - ], + sym_type=helpers.reparametrize_field_type( + self.fk_field_defaults, set_type=containing_model_type, get_type=containing_model_type ), ) # Add the foreign key's '_id' field: _id or from__id @@ -882,12 +882,8 @@ def create_through_table_class( helpers.add_new_sym_for_info( through_model, name=to_name, - sym_type=Instance( - self.fk_field, - [ - helpers.convert_any_to_type(self.fk_field_types.set, m2m_args.to.model), - helpers.convert_any_to_type(self.fk_field_types.get, m2m_args.to.model), - ], + sym_type=helpers.reparametrize_field_type( + self.fk_field_defaults, set_type=m2m_args.to.model, get_type=m2m_args.to.model ), ) # Add the foreign key's '_id' field: _id or to__id diff --git a/mypy_django_plugin/transformers/orm_lookups.py b/mypy_django_plugin/transformers/orm_lookups.py index d58652827..a842f1df0 100644 --- a/mypy_django_plugin/transformers/orm_lookups.py +++ b/mypy_django_plugin/transformers/orm_lookups.py @@ -76,7 +76,7 @@ def _typecheck_defaults_kwarg( return api = helpers.get_typechecker_api(ctx) - expected_types = django_context.get_expected_types(api, django_model.cls, method="create") + expected_types = django_context.get_expected_types(api, django_model.cls, ctx.context) model_name = django_model.cls.__name__ for idx in defaults_positions: diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index da21359f1..6408cbee7 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -106,15 +106,13 @@ def get_field_type_from_lookup( if lookup_field is None: return AnyType(TypeOfAny.implementation_artifact) - if (isinstance(lookup_field, RelatedField) and lookup_field.column == lookup) or isinstance( - lookup_field, ForeignObjectRel - ): + if isinstance(lookup_field, RelatedField) or isinstance(lookup_field, ForeignObjectRel): model_cls = django_context.get_field_related_model_cls(lookup_field) lookup_field = django_context.get_primary_key_field(model_cls) api = helpers.get_typechecker_api(ctx) model_info = helpers.lookup_class_typeinfo(api, model_cls) - return django_context.get_field_get_type(api, model_info, lookup_field, method=method) + return helpers.get_field_get_type_from_model_type_info(api, ctx.context, model_info, lookup_field.attname) def get_values_list_row_type( @@ -143,8 +141,8 @@ def get_values_list_row_type( if named: column_types: dict[str, MypyType] = {} for field in django_context.get_model_fields(model_cls): - column_type = django_context.get_field_get_type( - typechecker_api, model_info, field, method="values_list" + column_type = helpers.get_field_get_type_from_model_type_info( + typechecker_api, ctx.context, model_info, field.attname ) column_types[field.attname] = column_type column_types.update(annotation_types) @@ -291,11 +289,14 @@ def reparameterize_func_output_field(ctx: FunctionContext) -> MypyType: return ctx.default_return_type # Use the output_field argument type to fill the generic param - output_field_type = helpers.get_call_argument_type_by_name(ctx, "output_field") - if output_field_type is not None: - field_type = get_proper_type(output_field_type) - if isinstance(field_type, Instance): - return default.copy_modified(args=[field_type]) + output_field_type = get_proper_type(helpers.get_call_argument_type_by_name(ctx, "output_field")) + if output_field_type is not None and isinstance(output_field_type, Instance): + output_field_type = ( + helpers.fill_field_defaults(output_field_type.type, helpers.get_typechecker_api(ctx)) + if all(isinstance(get_proper_type(arg), AnyType) for arg in output_field_type.args) + else output_field_type + ) + return default.copy_modified(args=[output_field_type]) return ctx.default_return_type @@ -340,10 +341,10 @@ def _resolve_output_field_type(expr_type: MypyType) -> MypyType | None: if isinstance(func_type, CallableType): field_type = get_proper_type(func_type.ret_type) - if isinstance(field_type, Instance): - result = helpers.get_private_descriptor_type(field_type.type, "_pyi_private_get_type", is_nullable=False) - if not isinstance(get_proper_type(result), AnyType): - return result + if isinstance(field_type, Instance) and field_type.type.has_base(fullnames.FIELD_FULLNAME): + type_args = helpers.get_field_type_args(field_type) + if type_args is not None and not isinstance(type_args.get, AnyType): + return type_args.get return None diff --git a/scripts/django_tests_settings.py b/scripts/django_tests_settings.py index f7d92fb0a..0d3bba957 100644 --- a/scripts/django_tests_settings.py +++ b/scripts/django_tests_settings.py @@ -2,6 +2,26 @@ # The following installed apps are required for stubtest to run correctly. from __future__ import annotations +import pathlib + +_REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent +_ASSERT_TYPE_APPS_ROOT = _REPO_ROOT / "tests" / "assert_type" + + +def _discover_assert_type_apps() -> list[str]: + """Discover Django apps under ``tests/assert_type``. + + Each directory containing a ``models.py`` is registered as an app, + with the directory name as the implicit ``app_label``. + """ + if not _ASSERT_TYPE_APPS_ROOT.is_dir(): + return [] + return sorted( + ".".join(models_py.parent.relative_to(_REPO_ROOT).parts) + for models_py in _ASSERT_TYPE_APPS_ROOT.rglob("models.py") + ) + + INSTALLED_APPS = [ "django.contrib.auth", "django.contrib.admin", @@ -10,6 +30,7 @@ "django.contrib.redirects", "django.contrib.sessions", "django.contrib.sites", + *_discover_assert_type_apps(), ] STATIC_URL = "static/" diff --git a/scripts/stubtest/allowlist.txt b/scripts/stubtest/allowlist.txt index 5a9ea0b09..650675596 100644 --- a/scripts/stubtest/allowlist.txt +++ b/scripts/stubtest/allowlist.txt @@ -237,14 +237,15 @@ django.urls.resolvers.ResolverMatch.__iter__ django.template.smartif.key django.template.smartif.op -# Field.__get__/__set__ are stub-only for the mypy plugin, they don't exist at runtime +# Field.__get__/__set__ are stub-only for the type checkers, they don't exist at runtime django.db.models.fields.Field.__get__ django.db.models.fields.Field.__set__ django.db.models.fields.files.FileField.__get__ -django.db.models.fields.files.ImageField.__get__ django.db.models.fields.related.ForeignObject.__get__ +django.db.models.fields.related.ForeignObject.__set__ django.db.models.fields.related.ManyToManyField.__get__ django.db.models.fields.related.OneToOneField.__get__ +django.db.models.fields.related.OneToOneField.__set__ # These are dynamically added at runtime with loose types django.utils.functional.LazyObject.__bool__ diff --git a/tests/assert_type/contrib/admin/test_options.py b/tests/assert_type/contrib/admin/test_options.py index f905b561b..2f32874a8 100644 --- a/tests/assert_type/contrib/admin/test_options.py +++ b/tests/assert_type/contrib/admin/test_options.py @@ -53,7 +53,7 @@ class FullModelAdmin(admin.ModelAdmin[FullAdminModel]): "another_field": admin.HORIZONTAL, } prepopulated_fields = {"slug": ("title",)} - formfield_overrides = {models.TextField: {"widget": Textarea}} # pyright: ignore[reportUnknownVariableType] + formfield_overrides = {models.TextField: {"widget": Textarea}} readonly_fields = ("date_modified",) ordering = ("-pk", "date_modified") sortable_by = ["pk"] @@ -177,7 +177,7 @@ class InlineParentModel(models.Model): class InlineChildModel(models.Model): - parent = models.ForeignKey(InlineParentModel, on_delete=models.CASCADE) # pyright: ignore[reportUnknownVariableType] + parent = models.ForeignKey(InlineParentModel, on_delete=models.CASCADE) class ParentObjInline(admin.StackedInline[InlineChildModel, InlineParentModel]): diff --git a/tests/assert_type/contrib/admin/test_utils.py b/tests/assert_type/contrib/admin/test_utils.py index a258623a6..24e75b357 100644 --- a/tests/assert_type/contrib/admin/test_utils.py +++ b/tests/assert_type/contrib/admin/test_utils.py @@ -19,13 +19,13 @@ @admin.display(description="Name") def upper_case_name(obj: Person) -> str: - return f"{obj.first_name} {obj.last_name}".upper() # pyright: ignore[reportUnknownMemberType] + return f"{obj.first_name} {obj.last_name}".upper() class Person(models.Model): - first_name = models.CharField(max_length=None) # pyright: ignore[reportUnknownVariableType] - last_name = models.CharField(max_length=None) # pyright: ignore[reportUnknownVariableType] - birthday = models.DateField() # pyright: ignore[reportUnknownVariableType] + first_name = models.CharField(max_length=None) + last_name = models.CharField(max_length=None) + birthday = models.DateField() class PersonListAdmin(admin.ModelAdmin[Person]): diff --git a/tests/assert_type/contrib/auth/test_decorators.py b/tests/assert_type/contrib/auth/test_decorators.py index 8f0434e95..38a77a7e0 100644 --- a/tests/assert_type/contrib/auth/test_decorators.py +++ b/tests/assert_type/contrib/auth/test_decorators.py @@ -12,11 +12,11 @@ lazy_url = reverse_lazy("namespace:url") -@user_passes_test(lambda user: user.is_active, login_url=reversed_url) +@user_passes_test(lambda user: user.is_active, login_url=reversed_url) # pyrefly: ignore[bad-argument-type] def my_view1(request: HttpRequest) -> HttpResponse: raise NotImplementedError -@user_passes_test(lambda user: user.is_active, login_url=lazy_url) +@user_passes_test(lambda user: user.is_active, login_url=lazy_url) # pyrefly: ignore[bad-argument-type] def my_view2(request: HttpRequest) -> HttpResponse: raise NotImplementedError diff --git a/tests/assert_type/db/models/fields/direct_field_null_true_does_not_trigger_nullability_check/models.py b/tests/assert_type/db/models/fields/direct_field_null_true_does_not_trigger_nullability_check/models.py new file mode 100644 index 000000000..2337c199c --- /dev/null +++ b/tests/assert_type/db/models/fields/direct_field_null_true_does_not_trigger_nullability_check/models.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import Literal + +from django.db import models +from django.db.models import IntegerField +from django.db.models.expressions import OuterRef, Subquery +from typing_extensions import assert_type + + +class Article(models.Model): + pass + + +def direct_field_null_true_does_not_trigger_nullability_check() -> None: + null_field = models.IntegerField(null=True) + assert_type(null_field, IntegerField[float | int | str, int, Literal[True]]) + + not_null_field = models.IntegerField(null=False) + assert_type(not_null_field, IntegerField[float | int | str, int, Literal[False]]) + + Article.objects.annotate( + other_id=Subquery( + Article.objects.filter(id=OuterRef("id")).values_list("id", flat=True)[:1], + output_field=models.IntegerField(null=False), + ) + ) diff --git a/tests/assert_type/db/models/fields/setting_value_to_an_array_of_ints/models.py b/tests/assert_type/db/models/fields/setting_value_to_an_array_of_ints/models.py new file mode 100644 index 000000000..0b3d345de --- /dev/null +++ b/tests/assert_type/db/models/fields/setting_value_to_an_array_of_ints/models.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from django.contrib.postgres.fields import ArrayField +from django.db import models +from django.db.models import F + + +class MyModel(models.Model): + array = ArrayField(base_field=models.IntegerField()) + + +non_init = MyModel() + + +def setting_value_to_a_tuple_of_ints_ok() -> None: + array_val: tuple[int, ...] = (1,) + MyModel(array=array_val) + non_init.array = array_val + + +def setting_value_to_an_array_of_ints_ok() -> None: + array_val2: list[int] = [1] + MyModel(array=array_val2) + non_init.array = array_val2 + + +def setting_value_to_an_array_of_invalid_type_error() -> None: + class NotAValid: + pass + + array_val3: list[NotAValid] = [NotAValid()] + MyModel(array=array_val3) # type: ignore[misc] + non_init.array = array_val3 # type: ignore[assignment] # pyright: ignore[reportAttributeAccessIssue] # ty:ignore[invalid-assignment] # pyrefly: ignore[no-matching-overload] + + +def setting_value_to_an_array_of_combinable_error() -> None: + array_val4: list[F] = [F("id")] + MyModel(array=array_val4) # type: ignore[misc] + non_init.array = array_val4 # type: ignore[assignment] # pyright: ignore[reportAttributeAccessIssue] # ty:ignore[invalid-assignment] # pyrefly: ignore[no-matching-overload] diff --git a/tests/assert_type/db/models/fields/test_add_id_field_if_no_primary_key_defined/models.py b/tests/assert_type/db/models/fields/test_add_id_field_if_no_primary_key_defined/models.py new file mode 100644 index 000000000..2c3c456dd --- /dev/null +++ b/tests/assert_type/db/models/fields/test_add_id_field_if_no_primary_key_defined/models.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from django.db import models +from typing_extensions import assert_type + + +class User(models.Model): + pass + + +def test_add_id_field_if_no_primary_key_defined() -> None: + assert_type(User().id, int) # pyrefly:ignore[missing-attribute, assert-type] # ty:ignore[type-assertion-failure, unresolved-attribute] # pyright:ignore[reportAttributeAccessIssue,reportUnknownMemberType, reportAssertTypeFailure] diff --git a/tests/assert_type/db/models/fields/test_array_field.py b/tests/assert_type/db/models/fields/test_array_field.py new file mode 100644 index 000000000..e43d490d5 --- /dev/null +++ b/tests/assert_type/db/models/fields/test_array_field.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import uuid +from typing import Any + +from django.contrib.postgres.fields import ArrayField +from django.db import models +from typing_extensions import assert_type + + +def nullable_array_field() -> None: + class MyModel(models.Model): + lst = ArrayField(base_field=models.CharField(max_length=100), null=False) + null_lst = ArrayField(base_field=models.CharField(max_length=100), null=True) + + assert_type(MyModel().lst, list[str]) # False positive -> # pyrefly: ignore[assert-type] + assert_type(MyModel().null_lst, list[str] | None) # False positive -> # pyrefly: ignore[assert-type] + + my_model = MyModel() + random_uuid = uuid.uuid4() + + my_model.lst = None # type: ignore[call-overload] # pyright: ignore[reportAttributeAccessIssue] # ty:ignore[invalid-assignment] # pyrefly: ignore[no-matching-overload] + my_model.lst = [random_uuid, random_uuid] # type: ignore[list-item] # pyright: ignore[reportAttributeAccessIssue] # ty:ignore[invalid-assignment] # pyrefly: ignore[no-matching-overload] + + my_model.null_lst = None # OK + my_model.null_lst = [random_uuid, random_uuid] # type: ignore[list-item] # pyright: ignore[reportAttributeAccessIssue] # ty:ignore[invalid-assignment] # pyrefly: ignore[no-matching-overload] + + +def array_field_base_field_parsed_into_generic_typevar() -> None: + class MyModel(models.Model): + untyped = ArrayField(base_field=models.Field()) + members = ArrayField(base_field=models.IntegerField()) + members_as_text = ArrayField(base_field=models.CharField(max_length=255)) + + my_model = MyModel(untyped=[], members=[1, 2], members_as_text=["A", "B"]) + assert_type(my_model.untyped, list[Any]) + assert_type(my_model.members, list[int]) # False positive -> # pyrefly: ignore[assert-type] + assert_type(my_model.members_as_text, list[str]) # False positive -> # pyrefly: ignore[assert-type] diff --git a/tests/assert_type/db/models/fields/test_base.py b/tests/assert_type/db/models/fields/test_base.py new file mode 100644 index 000000000..f11cbaff6 --- /dev/null +++ b/tests/assert_type/db/models/fields/test_base.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import datetime +import decimal +import uuid +from typing import Any, Literal, NewType, cast + +from django.db import models +from django.db.models import CharField +from django.db.models.fields import IntegerField, _FieldDescriptor +from typing_extensions import assert_type + + +class AllFields(models.Model): + id = models.AutoField(primary_key=True) + + # Integer-family + integer = models.IntegerField() + small_int = models.SmallIntegerField() + big_int = models.BigIntegerField() + pos_int = models.PositiveIntegerField() + pos_small_int = models.PositiveSmallIntegerField() + pos_big_int = models.PositiveBigIntegerField() + null_integer = models.IntegerField(null=True) + null_small_int = models.SmallIntegerField(null=True) + null_big_int = models.BigIntegerField(null=True) + null_pos_int = models.PositiveIntegerField(null=True) + null_pos_small_int = models.PositiveSmallIntegerField(null=True) + null_pos_big_int = models.PositiveBigIntegerField(null=True) + + # Float / Decimal + flt = models.FloatField() + null_flt = models.FloatField(null=True) + dec = models.DecimalField(max_digits=10, decimal_places=5) + null_dec = models.DecimalField(max_digits=10, decimal_places=5, null=True) + + # Char-family + name = models.CharField(max_length=255) + null_name = models.CharField(max_length=255, null=True) + slug = models.SlugField(max_length=255) + null_slug = models.SlugField(max_length=255, null=True) + text = models.TextField() + null_text = models.TextField(null=True) + csv_int = models.CommaSeparatedIntegerField(max_length=255) + null_csv_int = models.CommaSeparatedIntegerField(max_length=255, null=True) + email = models.EmailField() + null_email = models.EmailField(null=True) + url = models.URLField() + null_url = models.URLField(null=True) + + # Boolean + flag = models.BooleanField() + null_flag = models.BooleanField(null=True) + + # IP addresses + ip = models.IPAddressField() + null_ip = models.IPAddressField(null=True) + gen_ip = models.GenericIPAddressField() + null_gen_ip = models.GenericIPAddressField(null=True) + + # Date / time / duration + day = models.DateField() + null_day = models.DateField(null=True) + moment = models.DateTimeField() + null_moment = models.DateTimeField(null=True) + clock = models.TimeField() + null_clock = models.TimeField(null=True) + duration = models.DurationField() + null_duration = models.DurationField(null=True) + + # UUID + uid = models.UUIDField() + null_uid = models.UUIDField(null=True) + + # Binary + blob = models.BinaryField() + null_blob = models.BinaryField(null=True) + + # JSON / FilePath + payload = models.JSONField() + null_payload = models.JSONField(null=True) + payload_with_db_default = models.JSONField(default=dict, db_default={}) + path = models.FilePathField() + null_path = models.FilePathField(null=True) + + +instance = AllFields() +assert_type(instance.id, int) + +assert_type(instance.integer, int) +assert_type(instance.null_integer, int | None) +assert_type(instance.small_int, int) +assert_type(instance.null_small_int, int | None) +assert_type(instance.big_int, int) +assert_type(instance.null_big_int, int | None) +assert_type(instance.pos_int, int) +assert_type(instance.null_pos_int, int | None) +assert_type(instance.pos_small_int, int) +assert_type(instance.null_pos_small_int, int | None) +assert_type(instance.pos_big_int, int) +assert_type(instance.null_pos_big_int, int | None) + +assert_type(instance.flt, float) +assert_type(instance.null_flt, float | None) +assert_type(instance.dec, decimal.Decimal) +assert_type(instance.null_dec, decimal.Decimal | None) + +assert_type(instance.name, str) +assert_type(instance.null_name, str | None) +assert_type(instance.slug, str) +assert_type(instance.null_slug, str | None) +assert_type(instance.text, str) +assert_type(instance.null_text, str | None) +assert_type(instance.csv_int, str) +assert_type(instance.null_csv_int, str | None) +assert_type(instance.email, str) +assert_type(instance.null_email, str | None) +assert_type(instance.url, str) +assert_type(instance.null_url, str | None) + +assert_type(instance.flag, bool) +assert_type(instance.null_flag, bool | None) + +assert_type(instance.ip, str) +assert_type(instance.null_ip, str | None) +assert_type(instance.gen_ip, str) +assert_type(instance.null_gen_ip, str | None) + +assert_type(instance.day, datetime.date) +assert_type(instance.null_day, datetime.date | None) +assert_type(instance.moment, datetime.datetime) +assert_type(instance.null_moment, datetime.datetime | None) +assert_type(instance.clock, datetime.time) +assert_type(instance.null_clock, datetime.time | None) +assert_type(instance.duration, datetime.timedelta) +assert_type(instance.null_duration, datetime.timedelta | None) + +assert_type(instance.uid, uuid.UUID) +assert_type(instance.null_uid, uuid.UUID | None) + +assert_type(instance.blob, bytes | memoryview[int]) +assert_type(instance.null_blob, bytes | memoryview[int] | None) + +assert_type(instance.payload, Any) +assert_type(instance.null_payload, Any | None) +assert_type(instance.payload_with_db_default, Any) +assert_type(instance.path, Any) +assert_type(instance.null_path, Any | None) + + +def if_field_called_on_class_return_field_itself() -> None: + assert_type( + AllFields.name.field, + CharField[str | int, str, Literal[False]], + ) + assert_type( + AllFields.null_name.field, + CharField[str | int, str, Literal[True]], + ) + + +def null_char_field_allows_none() -> None: + AllFields(null_name="") + AllFields(null_name=None) + AllFields().null_name = None + + +def not_null_charfield_does_not_allow_none() -> None: + AllFields(name="") + AllFields(name=None) + AllFields().name = None # type: ignore[call-overload] # pyrefly:ignore[no-matching-overload] # ty:ignore[invalid-assignment] # pyright:ignore[reportAttributeAccessIssue] + + +def fields_on_non_model_classes_resolve_to_field_type() -> None: + class MyClass: + myfield = models.IntegerField[int, int]() + + assert_type(MyClass.myfield, _FieldDescriptor[IntegerField[int, int, Literal[False]]]) + assert_type(MyClass.myfield.field, IntegerField[int, int, Literal[False]]) + assert_type(MyClass().myfield, IntegerField[int, int, Literal[False]]) + + +def fields_inside_mixins_used_in_model_subclasses_resolved_as_primitives() -> None: + class AuthMixin(models.Model): + class Meta: + abstract = True + + username = models.CharField(max_length=100) + null_username = models.CharField(max_length=100, null=True) + + class MyModel(AuthMixin, models.Model): + pass + + assert_type(MyModel().username, str) + assert_type(MyModel().null_username, str | None) + + +def test_small_auto_field_class_presents_as_int() -> None: + class MyModel(models.Model): + small = models.SmallAutoField(primary_key=True) + + obj = MyModel() + + assert_type(obj.small, int) + + +def can_narrow_field_type() -> None: + Year = NewType("Year", int) + + class Book(models.Model): + published = cast("models.Field[Year, Year]", models.IntegerField()) + + book = Book() + assert_type(book.published, Year) + book.published = ( # type: ignore[call-overload] # pyrefly:ignore[no-matching-overload] # ty:ignore[invalid-assignment] # pyright:ignore[reportAttributeAccessIssue] + 2006 + ) + book.published = Year(2006) + assert_type(book.published, Year) # N: Revealed type is "main.Year" + + def accepts_int(arg: int) -> None: ... + + accepts_int(book.published) + + +def test_ignores_renamed_field() -> None: + """ + Ref: https://github.com/typeddjango/django-stubs/issues/1261 + Django modifies the model so it doesn't have 'modelname', but we don't follow + along. But the 'name=' argument to a field isn't a documented feature. + """ + + class RenamedField(models.Model): + modelname = models.IntegerField(name="fieldname", choices=((1, "One"),)) + + instance = RenamedField() + assert_type(instance.modelname, int) + instance.fieldname # type: ignore[attr-defined] # pyrefly:ignore[missing-attribute] # ty:ignore[unresolved-attribute] # pyright:ignore[reportAttributeAccessIssue,reportUnknownMemberType] + instance.modelname = 1 + instance.fieldname = 1 # type: ignore[attr-defined] # pyrefly:ignore[missing-attribute] # ty:ignore[unresolved-attribute] # pyright:ignore[reportAttributeAccessIssue] + + +def nullable_field_with_strict_optional_true() -> None: + class MyModel(models.Model): + text_nullable = models.CharField(max_length=100, null=True) + text = models.CharField(max_length=100) + + MyModel().text = None # type: ignore[call-overload] # pyrefly:ignore[no-matching-overload] # ty:ignore[invalid-assignment] # pyright:ignore[reportAttributeAccessIssue] + MyModel().text_nullable = None diff --git a/tests/assert_type/db/models/fields/test_custom_fields.py b/tests/assert_type/db/models/fields/test_custom_fields.py new file mode 100644 index 000000000..d063b79bd --- /dev/null +++ b/tests/assert_type/db/models/fields/test_custom_fields.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +from typing import Any, Generic, Literal + +from django.db import models +from django.db.models.expressions import Combinable, F +from django.db.models.fields import _GT, _NT, _ST +from typing_extensions import TypeVar, assert_type + +T = TypeVar("T") + + +class CustomFieldValue: ... + + +# `bool` is not assignable to upper bound `Literal[False, True]` of type variable `_NT` +# TODO: ty should reject that too +class InvalidCustomField(models.Field[_ST, _GT, bool]): # type:ignore[type-var] # pyrefly: ignore[bad-specialization] # pyright: ignore[reportInvalidTypeArguments] + pass + + +def custom_generic_field_override_typevar_defaults() -> None: + class GenericField(models.Field[_ST, _GT, _NT]): ... + + class MyModel(models.Model): + field = GenericField[CustomFieldValue | int, CustomFieldValue]() + null_field = GenericField[CustomFieldValue | int, CustomFieldValue | None, Literal[True]](null=True) + conflict_field = GenericField[CustomFieldValue | int, CustomFieldValue, Literal[False]](null=True) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + conflict_null_field = GenericField[CustomFieldValue | int, CustomFieldValue, Literal[True]](null=False) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + + instance = MyModel() + assert_type(instance.field, CustomFieldValue) + assert_type(instance.null_field, CustomFieldValue | None) + + +def single_type_field() -> None: + class SingleTypeField(models.Field[T, T, _NT]): ... + + class MyModel(models.Model): + field = SingleTypeField[bool]() + explicit_null_field = SingleTypeField[bool | None, Literal[True]](null=True) + conflict_null_field = SingleTypeField[bool](null=True) # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + explicit_conflict_null_field = SingleTypeField[bool, Literal[False]](null=True) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + + instance = MyModel() + assert_type(instance.field, bool) + assert_type(instance.explicit_null_field, bool | None) + + +def custom_explicit_get_set_field() -> None: + class CustomValueField(models.Field[CustomFieldValue | int, CustomFieldValue, _NT]): ... + + class MyModel(models.Model): + field = CustomValueField() + null_field = CustomValueField(null=True) + + instance = MyModel() + assert_type(instance.field, CustomFieldValue) + assert_type(instance.null_field, CustomFieldValue | None) + instance.field = CustomFieldValue() + instance.field = 12 + instance.field = "NoNo" # type: ignore[call-overload] # pyrefly:ignore[no-matching-overload] # ty:ignore[invalid-assignment] # pyright:ignore[reportAttributeAccessIssue] + + +def custom_generic_field() -> None: + _ST_Int = TypeVar("_ST_Int", contravariant=True, default=float | int | str | Combinable) + _GT_Int = TypeVar("_GT_Int", covariant=True, default=int) + + class CustomSmallIntegerField(models.SmallIntegerField[_ST_Int, _GT_Int, _NT]): ... + + class MyModel(models.Model): + field = CustomSmallIntegerField() + null_field = CustomSmallIntegerField(null=True) + + instance = MyModel() + assert_type(instance.field, int) + assert_type(instance.null_field, int | None) + instance.field = 1.2 + instance.field = 12 + instance.field = "12" + instance.field = F("id") + instance.field = CustomFieldValue() # type: ignore[call-overload] # pyrefly:ignore[no-matching-overload] # ty:ignore[invalid-assignment] # pyright:ignore[reportAttributeAccessIssue] + + +def additional_typevar_field() -> None: + _ST_Custom = TypeVar("_ST_Custom", contravariant=True, default=CustomFieldValue | int) + _GT_Custom = TypeVar("_GT_Custom", covariant=True, default=CustomFieldValue) + + class AdditionalTypeVarField( + models.Field[_ST_Custom, _GT_Custom, _NT], Generic[T, _ST_Custom, _GT_Custom, _NT] + ): ... + + class MyModel(models.Model): + field = AdditionalTypeVarField[bool]() + null_field = AdditionalTypeVarField[bool, CustomFieldValue | int, CustomFieldValue, Literal[True]](null=True) + + instance = MyModel() + assert_type(instance.field, CustomFieldValue) + assert_type(instance.null_field, CustomFieldValue | None) + + +def field_implicit_any() -> None: + # This is inferred as models.Field[Any, Any, Literal[False]] + class FieldImplicitAny(models.Field): ... + + class MyModel(models.Model): + field = FieldImplicitAny() + null_field = FieldImplicitAny(null=True) # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + + instance = MyModel() + assert_type(instance.field, Any) + assert_type(instance.null_field, Any) # type:ignore[assert-type] # Mypy says `Any | None` which is a bit odd + + +def field_explicit_any() -> None: + class FieldExplicitAny(models.Field[Any, Any, Any]): ... + + class MyModel(models.Model): + field = FieldExplicitAny() + null_field = FieldExplicitAny(null=True) + + instance = MyModel() + assert_type(instance.field, Any) + assert_type(instance.null_field, Any) + + +def field_two_typevar_form_is_still_accepted() -> None: + class LegacyField(models.Field[CustomFieldValue | int, CustomFieldValue]): ... + + class MyModel(models.Model): + field = LegacyField() + null_field = LegacyField(null=True) # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + + instance = MyModel() + assert_type(instance.field, CustomFieldValue) + assert_type(instance.null_field, CustomFieldValue | None) # pyright: ignore[reportAssertTypeFailure] # pyrefly: ignore[assert-type] # ty: ignore[type-assertion-failure] + instance.field = CustomFieldValue() + instance.field = 12 + + +def field_two_typevar_form_in_user_annotation() -> None: + # Legacy `field: Field[A, B] = CustomField()` annotations with a legacy `CustomField` (without `_NT`) + class CustomField(models.Field[CustomFieldValue | int, CustomFieldValue]): ... + + class MyModel(models.Model): + field: models.Field[CustomFieldValue | int, CustomFieldValue] = CustomField() + implicit_null_field = CustomField(null=True) # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + explicit_null_field: models.Field[CustomFieldValue | int, CustomFieldValue | None, Literal[True]] = CustomField( # pyright: ignore[reportAssignmentType] # pyrefly: ignore[bad-assignment] # ty: ignore[invalid-assignment] + null=True # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + ) + null_field: models.Field[CustomFieldValue | int, CustomFieldValue | None] = CustomField(null=True) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + + instance = MyModel() + assert_type(instance.field, CustomFieldValue) + assert_type(instance.null_field, CustomFieldValue | None) + assert_type(instance.implicit_null_field, CustomFieldValue | None) # pyright: ignore[reportAssertTypeFailure] # pyrefly: ignore[assert-type] # ty: ignore[type-assertion-failure] + assert_type(instance.explicit_null_field, CustomFieldValue | None) + instance.field = CustomFieldValue() + instance.field = 12 + instance.field = "no" # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[invalid-assignment] # pyright: ignore[reportAttributeAccessIssue] + instance.null_field = None # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[invalid-assignment] # pyright: ignore[reportAttributeAccessIssue] + instance.null_field = CustomFieldValue() + + +def nullable_subclass_inherits_null_overload() -> None: + _ST_Text = TypeVar("_ST_Text", contravariant=True, default=str | int | Combinable) + _GT_Text = TypeVar("_GT_Text", covariant=True, default=str) + _ST_Int = TypeVar("_ST_Int", contravariant=True, default=float | int | str | Combinable) + _GT_Int = TypeVar("_GT_Int", covariant=True, default=int) + _NT = TypeVar("_NT", Literal[True], Literal[False], default=Literal[False]) + + class HtmlField(models.TextField[_ST_Text, _GT_Text, _NT]): ... + + class IntWrap(models.IntegerField[_ST_Int, _GT_Int, _NT]): ... + + class Article(models.Model): + body = HtmlField() + body_nullable = HtmlField(null=True) + count = IntWrap() + count_nullable = IntWrap(null=True) + + assert_type(Article().body, str) + assert_type(Article().body_nullable, str | None) + assert_type(Article().count, int) + assert_type(Article().count_nullable, int | None) + + +def nullable_field_subclass_without_explicit_type_vars() -> None: + """ + We auto add typevars in mypy plugin which avoid issues here. + + TODO: False positive pyrefly/ty/pyright + """ + + class HTMLField(models.TextField): ... + + class IntWrap(models.IntegerField): ... + + class FieldMixin: ... + + class MySlugField(models.SlugField, FieldMixin): ... + + class Article(models.Model): + body = HTMLField() + body_nullable = HTMLField(null=True) # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + count_nullable = IntWrap(null=True) # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + slug_nullable = MySlugField(null=True) # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + + assert_type(Article().body, str) + assert_type(Article().body_nullable, str | None) # pyrefly: ignore[assert-type] # ty: ignore[type-assertion-failure] # pyright: ignore[reportAssertTypeFailure] + assert_type(Article().count_nullable, int | None) # pyrefly: ignore[assert-type] # ty: ignore[type-assertion-failure] # pyright: ignore[reportAssertTypeFailure] + assert_type(Article().slug_nullable, str | None) # pyrefly: ignore[assert-type] # ty: ignore[type-assertion-failure] # pyright: ignore[reportAssertTypeFailure] + + +def test_custom_model_fields_override_init() -> None: + """ + TODO: False positive ty/mypy + """ + _ST_Int = TypeVar("_ST_Int", contravariant=True, default=float | int | str) + _GT_Int = TypeVar("_GT_Int", covariant=True, default=int) + _NT = TypeVar("_NT", Literal[True], Literal[False], default=Literal[False]) + + class MyIntegerField(models.IntegerField[_ST_Int, _GT_Int, _NT]): + def __init__(self, *args: Any, null: _NT = False, **kwargs: Any) -> None: # type:ignore[assignment] # ty:ignore[invalid-parameter-default] + kwargs["null"] = null + super().__init__(*args, **kwargs) + + class User(models.Model): + custom_int = MyIntegerField(null=False) + custom_int_nullable = MyIntegerField(null=True) + + assert_type(User().custom_int, int) + assert_type(User().custom_int_nullable, int | None) diff --git a/tests/assert_type/db/models/fields/test_do_not_add_id_if_field_with_primary_key_True_defined/models.py b/tests/assert_type/db/models/fields/test_do_not_add_id_if_field_with_primary_key_True_defined/models.py new file mode 100644 index 000000000..bb4cf847b --- /dev/null +++ b/tests/assert_type/db/models/fields/test_do_not_add_id_if_field_with_primary_key_True_defined/models.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from django.db import models +from typing_extensions import assert_type + + +class User(models.Model): + my_pk = models.IntegerField(primary_key=True) + + +def test_do_not_add_id_if_field_with_primary_key_True_defined() -> None: + assert_type(User().my_pk, int) + User().id # type: ignore[attr-defined] # pyrefly:ignore[missing-attribute] # ty:ignore[unresolved-attribute] # pyright:ignore[reportAttributeAccessIssue,reportUnknownMemberType] diff --git a/tests/assert_type/db/models/fields/test_dummy_migration.py b/tests/assert_type/db/models/fields/test_dummy_migration.py new file mode 100644 index 000000000..96ef16db0 --- /dev/null +++ b/tests/assert_type/db/models/fields/test_dummy_migration.py @@ -0,0 +1,135 @@ +# Regression test for false positives like this due to incomplete `Field` annotations. +# Type annotations must provide the type vars otherwise they get inferred as `Field[Any, Any, Literal[False]]` in some +# contexts, which is usually wrong. Annotations accepting any field should use `Field[Any, Any, Any]` +from __future__ import annotations + +from typing import TYPE_CHECKING + +import django.db.models.deletion +from django.db import migrations, models + +if TYPE_CHECKING: + from django.db.backends.base.schema import BaseDatabaseSchemaEditor + from django.db.migrations.state import StateApps + + +def forwards(apps: StateApps, schema_editor: BaseDatabaseSchemaEditor) -> None: + pass + + +def backwards(apps: StateApps, schema_editor: BaseDatabaseSchemaEditor) -> None: + pass + + +class Migration(migrations.Migration): + dependencies = [ + ("corporate", "0037_customerplanoffer"), + ] + + operations = [ + migrations.AddField( + model_name="customerplanoffer", + name="sent_invoice_id", + field=models.CharField(max_length=255, null=True), + ), + migrations.AlterField( + model_name="customerplanoffer", + name="sent_invoice_id", + field=models.CharField(max_length=512, null=True), + preserve_default=False, + ), + migrations.RenameField( + model_name="customerplanoffer", + old_name="sent_invoice_id", + new_name="invoice_id", + ), + migrations.RemoveField( + model_name="customerplanoffer", + name="invoice_id", + ), + migrations.CreateModel( + name="Invoice", + fields=[ + ( + "id", + models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID"), + ), + ("stripe_invoice_id", models.CharField(max_length=255, unique=True)), + ("status", models.SmallIntegerField()), + ( + "customer", + models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="corporate.customer"), + ), + ], + options={"db_table": "invoice"}, + ), + migrations.RenameModel(old_name="Invoice", new_name="StripeInvoice"), + migrations.AlterModelTable(name="stripeinvoice", table="billing_stripe_invoice"), + migrations.AlterModelTableComment(name="stripeinvoice", table_comment="Stripe invoices"), + migrations.AlterModelOptions( + name="stripeinvoice", + options={"ordering": ["-id"], "verbose_name": "Stripe invoice"}, + ), + migrations.AlterModelManagers( + name="stripeinvoice", + managers=[("objects", models.Manager())], + ), + migrations.AlterUniqueTogether( + name="stripeinvoice", + unique_together={("stripe_invoice_id", "customer")}, + ), + migrations.AlterIndexTogether( + name="stripeinvoice", + index_together={("status", "customer")}, + ), + migrations.AlterOrderWithRespectTo( + name="stripeinvoice", + order_with_respect_to="customer", + ), + migrations.AddIndex( + model_name="stripeinvoice", + index=models.Index(fields=["stripe_invoice_id"], name="invoice_stripe_idx"), + ), + migrations.RenameIndex( + model_name="stripeinvoice", + new_name="invoice_stripe_id_idx", + old_name="invoice_stripe_idx", + ), + migrations.RemoveIndex( + model_name="stripeinvoice", + name="invoice_stripe_id_idx", + ), + migrations.AddConstraint( + model_name="stripeinvoice", + constraint=models.UniqueConstraint(fields=["stripe_invoice_id"], name="invoice_stripe_id_uniq"), + ), + migrations.AlterConstraint( + model_name="stripeinvoice", + name="invoice_stripe_id_uniq", + constraint=models.UniqueConstraint(fields=["stripe_invoice_id", "customer"], name="invoice_stripe_id_uniq"), + ), + migrations.RemoveConstraint( + model_name="stripeinvoice", + name="invoice_stripe_id_uniq", + ), + migrations.RunSQL( + sql="UPDATE billing_stripe_invoice SET status = 0 WHERE status IS NULL", + reverse_sql=migrations.RunSQL.noop, + ), + migrations.RunPython( + code=forwards, + reverse_code=backwards, + ), + migrations.SeparateDatabaseAndState( + database_operations=[ + migrations.RunSQL("CREATE INDEX legacy_idx ON billing_stripe_invoice(status)"), + ], + state_operations=[ + migrations.AddIndex( + model_name="stripeinvoice", + index=models.Index(fields=["status"], name="legacy_idx"), + ), + ], + ), + migrations.DeleteModel(name="StripeInvoice"), + ] diff --git a/tests/assert_type/db/models/fields/test_field_set_type_honors_type_redefinition/models.py b/tests/assert_type/db/models/fields/test_field_set_type_honors_type_redefinition/models.py new file mode 100644 index 000000000..12196c0f8 --- /dev/null +++ b/tests/assert_type/db/models/fields/test_field_set_type_honors_type_redefinition/models.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, cast + +from django.contrib.postgres.fields import ArrayField +from django.db import models +from typing_extensions import assert_type + +if TYPE_CHECKING: + from collections.abc import Sequence + + +class FieldRedefinitionModel(models.Model): + redefined_set_type = cast("models.Field[int, int]", models.IntegerField()) + redefined_union_set_type = cast("models.Field[int | float, int]", models.IntegerField()) + redefined_array_set_type = cast( + "ArrayField[Sequence[int | float], int]", + ArrayField(base_field=models.IntegerField()), + ) + default_set_type = models.IntegerField() + unset_set_type = cast("models.Field", models.IntegerField()) + + +def test_field_set_type_honors_type_redefinition() -> None: + non_init = FieldRedefinitionModel() + assert_type(non_init.redefined_set_type, int) + assert_type(non_init.redefined_union_set_type, int) + assert_type(non_init.redefined_array_set_type, list[int]) + assert_type(non_init.default_set_type, int) + assert_type(non_init.unset_set_type, Any) + + non_init.redefined_set_type = "invalid" # type: ignore[call-overload] # pyright: ignore[reportAttributeAccessIssue] # pyrefly: ignore[no-matching-overload] # ty: ignore[invalid-assignment] + non_init.redefined_union_set_type = "invalid" # type: ignore[call-overload] # pyright: ignore[reportAttributeAccessIssue] # pyrefly: ignore[no-matching-overload] # ty: ignore[invalid-assignment] + array_val: list[str] = ["invalid"] + non_init.redefined_array_set_type = array_val # type: ignore[assignment] # pyright: ignore[reportAttributeAccessIssue] # pyrefly: ignore[no-matching-overload] # ty: ignore[invalid-assignment] + non_init.default_set_type = [] # type: ignore[call-overload] # pyright: ignore[reportAttributeAccessIssue] # pyrefly: ignore[no-matching-overload] # ty: ignore[invalid-assignment] + non_init.unset_set_type = [] + + FieldRedefinitionModel( # type: ignore[misc] + redefined_set_type="invalid", + redefined_union_set_type="invalid", + redefined_array_set_type=33, + default_set_type=[], + unset_set_type=[], + ) diff --git a/tests/assert_type/db/models/fields/test_files.py b/tests/assert_type/db/models/fields/test_files.py index 44a5be26b..3d63aaf31 100644 --- a/tests/assert_type/db/models/fields/test_files.py +++ b/tests/assert_type/db/models/fields/test_files.py @@ -9,7 +9,21 @@ class MyModel(models.Model): file = models.FileField() image = models.ImageField() + null_file = models.FileField(null=True) + null_image = models.ImageField(null=True) + instance = MyModel() -assert_type(instance.file, FieldFile) # pyrefly: ignore[assert-type] -assert_type(instance.image, ImageFieldFile) # pyrefly: ignore[assert-type] +# At runtime, FileDescriptor.__get__ ALWAYS returns a FieldFile even when the underlying database value is NULL. +# It wraps None in FieldFile(instance, field, name=None). +# For ex: +# In [4]: Page.objects.get(video__isnull=False).video +# Out[4]: +# +# In [5]: Page.objects.get(video__isnull=True).video +# Out[5]: + +assert_type(instance.file, FieldFile) +assert_type(instance.image, ImageFieldFile) +assert_type(instance.null_file, FieldFile) +assert_type(instance.null_image, ImageFieldFile) diff --git a/tests/assert_type/db/models/fields/test_generic_typevars.py b/tests/assert_type/db/models/fields/test_generic_typevars.py new file mode 100644 index 000000000..59dff280f --- /dev/null +++ b/tests/assert_type/db/models/fields/test_generic_typevars.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import Literal + +from django.db.models import AutoField, CharField, IntegerField + +# --- `_ST` is contravariant --- +# A field whose set-type is wider can stand in for one whose set-type is narrower. +wide_set: AutoField[int | str, int] = AutoField() # ty: ignore[invalid-assignment] +narrow_set: AutoField[int, int] = wide_set + +# Reverse direction is rejected: a narrower set-type cannot stand in for a wider one. +narrow_set2: AutoField[int, int] = AutoField[int, int]() +rejected_set: AutoField[int | str, int] = narrow_set2 # type: ignore[assignment] # pyright: ignore[reportAssignmentType] # pyrefly: ignore[bad-assignment] # ty: ignore[invalid-assignment] + + +# --- `_GT` is covariant --- +# A field whose get-type is narrower can stand in for one whose get-type is wider. +narrow_get: AutoField[int, int] = AutoField[int, int]() +wide_get: AutoField[int, int | str] = narrow_get + +# Reverse direction is rejected: a wider get-type cannot stand in for a narrower one. +wide_get2: AutoField[int, int | str] = AutoField[int, int | str]() +rejected_get: AutoField[int, int] = wide_get2 # type: ignore[assignment] # pyright: ignore[reportAssignmentType] # pyrefly: ignore[bad-assignment] # ty: ignore[invalid-assignment] + + +# --- `_NT` is invariant --- +# Nullable and non-nullable fields are not interchangeable in either direction. +not_null: IntegerField[int, int, Literal[False]] = IntegerField() +nullable: IntegerField[int, int, Literal[True]] = IntegerField(null=True) +bad_to_nullable: IntegerField[int, int, Literal[True]] = not_null # type: ignore[assignment] # pyright: ignore[reportAssignmentType] # pyrefly: ignore[bad-assignment] # ty: ignore[invalid-assignment] +bad_to_non_nullable: IntegerField[int, int, Literal[False]] = nullable # type: ignore[assignment] # pyright: ignore[reportAssignmentType] # pyrefly: ignore[bad-assignment] # ty: ignore[invalid-assignment] + + +# --- Variance with non-default TypeVar bounds --- +# CharField specializes the bounds (`_ST=str | int`, `_GT=str`); the same variance rules still apply. + +# ST contravariance: source set-type wider than target → OK. +char_wide_set: CharField[str | int, str] = CharField() # ty: ignore[invalid-assignment] +char_narrow_set: CharField[str, str] = char_wide_set + +# GT covariance: source get-type narrower than target → OK. +char_narrow_get: CharField[str, str] = CharField[str, str]() +char_wide_get: CharField[str, str | bytes] = char_narrow_get + +# ST narrowing in source → forbidden. +char_narrow: CharField[str, str] = CharField[str, str]() +char_rejected_st: CharField[str | int, str] = char_narrow # type: ignore[assignment] # pyright: ignore[reportAssignmentType] # pyrefly: ignore[bad-assignment] # ty: ignore[invalid-assignment] + +# GT widening in source → forbidden. +char_wide: CharField[str, str | bytes] = CharField[str, str | bytes]() +char_rejected_gt: CharField[str, str] = char_wide # type: ignore[assignment] # pyright: ignore[reportAssignmentType] # pyrefly: ignore[bad-assignment] # ty: ignore[invalid-assignment] + + +# --- Subclass relationships respect ST/GT variance --- +# A more concrete `Field` subtype is assignable to its base `Field` so long as ST/GT are compatible. +auto: AutoField[int, int] = AutoField[int, int]() +as_int: IntegerField[int, int] = auto + +# The reverse — base to derived — is not allowed. +as_auto: AutoField[int, int] = IntegerField[int, int]() # type: ignore[assignment] # pyright: ignore[reportAssignmentType] # pyrefly: ignore[bad-assignment] # ty: ignore[invalid-assignment] diff --git a/tests/assert_type/db/models/fields/test_nullable.py b/tests/assert_type/db/models/fields/test_nullable.py new file mode 100644 index 000000000..e93d89678 --- /dev/null +++ b/tests/assert_type/db/models/fields/test_nullable.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from typing import Any + +from django.db import models + + +def field_null_true_expression_does_not_trigger_nullability_check() -> None: + """ + Field[Any, Any, Any] as function type arg should accept both nullable and non-nullable fields + """ + + def take_field(f: models.Field[Any, Any, Any]) -> None: + return None + + take_field(models.IntegerField(null=True)) + take_field(models.IntegerField(null=False)) diff --git a/tests/assert_type/db/models/fields/test_postgres.py b/tests/assert_type/db/models/fields/test_postgres.py new file mode 100644 index 000000000..bad0142e8 --- /dev/null +++ b/tests/assert_type/db/models/fields/test_postgres.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from typing import Any + +from django.contrib.postgres import fields as pg_fields +from django.db import models +from typing_extensions import assert_type + + +class Booking(models.Model): + time_range = pg_fields.DateTimeRangeField(null=False) + null_time_range = pg_fields.DateTimeRangeField(null=True) + + +booking = Booking() +assert_type(booking.time_range, Any) # pyrefly: ignore[assert-type] # ty: ignore[type-assertion-failure] # pyright: ignore[reportAssertTypeFailure] +assert_type(booking.null_time_range, Any) # pyrefly: ignore[assert-type]# ty: ignore[type-assertion-failure]# pyright: ignore[reportAssertTypeFailure] diff --git a/tests/assert_type/db/models/fields/test_related.py b/tests/assert_type/db/models/fields/test_related.py deleted file mode 100644 index 3c2b9e7ab..000000000 --- a/tests/assert_type/db/models/fields/test_related.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - -from django.db import models - - -class Author(models.Model): - pass - - -class Book(models.Model): - author = models.ForeignKey(Author, on_delete=models.CASCADE, swappable=False) # pyright: ignore[reportUnknownVariableType] - - -class Profile(models.Model): - user = models.OneToOneField(Author, on_delete=models.CASCADE, swappable=False) # pyright: ignore[reportUnknownVariableType] diff --git a/tests/assert_type/db/models/fields/test_related_forward/__init__.py b/tests/assert_type/db/models/fields/test_related_forward/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/assert_type/db/models/fields/test_related_forward/models.py b/tests/assert_type/db/models/fields/test_related_forward/models.py new file mode 100644 index 000000000..8a0d9b235 --- /dev/null +++ b/tests/assert_type/db/models/fields/test_related_forward/models.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from django.db import models +from typing_extensions import assert_type + + +class Author(models.Model): + pass + + +class Book(models.Model): + author = models.ForeignKey(Author, on_delete=models.CASCADE, swappable=False) + + +class Profile(models.Model): + user = models.OneToOneField(Author, on_delete=models.CASCADE, swappable=False) + + +def test_related() -> None: + assert_type(Book().author, Author) # ty: ignore[type-assertion-failure] # pyright: ignore[reportAssertTypeFailure] + assert_type(Profile().user, Author) # ty: ignore[type-assertion-failure] # pyright: ignore[reportAssertTypeFailure] diff --git a/tests/assert_type/db/models/functions/test_func.py b/tests/assert_type/db/models/functions/test_func.py new file mode 100644 index 000000000..1adc9591c --- /dev/null +++ b/tests/assert_type/db/models/functions/test_func.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import Literal + +from django.db.models import BinaryField +from django.db.models.fields import CharField +from django.db.models.functions import Left, Right, Substr +from typing_extensions import assert_type + + +def func_resolve_output_field() -> None: + def expect_func_binary(func: Substr[BinaryField] | Left[BinaryField] | Right[BinaryField]) -> None: + return None + + bin_sub = Substr("username", 1, 100, output_field=BinaryField()) + str_sub = Substr("username", 1, 100) # Default to `CharField` per `Substr.output_field` + + bin_left = Left("username", 5, output_field=BinaryField()) + str_left = Left("username", 5) # Default to `CharField` per `Left.output_field` + + bin_right = Right("username", 5, output_field=BinaryField()) + str_right = Right("username", 5) # Default to `CharField` per `Right.output_field` + + assert_type( # False positive -> # ty: ignore[type-assertion-failure] + bin_sub, + Substr[BinaryField[bytes | bytearray | memoryview[int], bytes | memoryview[int], Literal[False]]], + ) + assert_type(str_sub, Substr[CharField[str | int, str, Literal[False]]]) + + assert_type( # False positive -> # ty: ignore[type-assertion-failure] + bin_left, + Left[BinaryField[bytes | bytearray | memoryview[int], bytes | memoryview[int], Literal[False]]], + ) + assert_type(str_left, Left[CharField[str | int, str, Literal[False]]]) + + assert_type( # False positive -> # ty: ignore[type-assertion-failure] + bin_right, + Right[BinaryField[bytes | bytearray | memoryview[int], bytes | memoryview[int], Literal[False]]], + ) + assert_type(str_right, Right[CharField[str | int, str, Literal[False]]]) + + expect_func_binary(bin_sub) # False positive -> # ty: ignore[invalid-argument-type] + expect_func_binary(str_sub) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + + expect_func_binary(bin_left) # False positive -> # ty: ignore[invalid-argument-type] + expect_func_binary(str_left) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + + expect_func_binary(bin_right) # False positive -> # ty: ignore[invalid-argument-type] + expect_func_binary(str_right) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] diff --git a/tests/assert_type/views/test_generic.py b/tests/assert_type/views/test_generic.py index 9140f5380..384112ac5 100644 --- a/tests/assert_type/views/test_generic.py +++ b/tests/assert_type/views/test_generic.py @@ -10,7 +10,8 @@ from typing_extensions import assert_type -class MyModel(models.Model): ... +class MyModel(models.Model): + pass class MyDetailView(SingleObjectMixin[MyModel]): ... diff --git a/tests/typecheck/fields/test_base.yml b/tests/typecheck/fields/test_base.yml deleted file mode 100644 index 1ffc566c7..000000000 --- a/tests/typecheck/fields/test_base.yml +++ /dev/null @@ -1,217 +0,0 @@ -- case: test_model_fields_classes_present_as_primitives - main: | - from typing_extensions import reveal_type - from myapp.models import User - user = User(small_int=1, name='user', slug='user', text='user') - reveal_type(user.id) # N: Revealed type is "int" - reveal_type(user.small_int) # N: Revealed type is "int" - reveal_type(user.name) # N: Revealed type is "str" - reveal_type(user.slug) # N: Revealed type is "str" - reveal_type(user.text) # N: Revealed type is "str" - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class User(models.Model): - id = models.AutoField(primary_key=True) - small_int = models.SmallIntegerField() - name = models.CharField(max_length=255) - slug = models.SlugField(max_length=255) - text = models.TextField() - -- case: test_model_field_classes_from_existing_locations - main: | - from typing_extensions import reveal_type - from myapp.models import Booking - booking = Booking() - reveal_type(booking.id) # N: Revealed type is "int" - reveal_type(booking.time_range) # N: Revealed type is "Any" - reveal_type(booking.some_decimal) # N: Revealed type is "decimal.Decimal" - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - from django.contrib.postgres import fields as pg_fields - from decimal import Decimal - - class Booking(models.Model): - id = models.AutoField(primary_key=True) - time_range = pg_fields.DateTimeRangeField(null=False) - some_decimal = models.DecimalField(max_digits=10, decimal_places=5) - -- case: test_add_id_field_if_no_primary_key_defined - disable_cache: true - main: | - from typing_extensions import reveal_type - from myapp.models import User - reveal_type(User().id) # N: Revealed type is "int" - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class User(models.Model): - pass - -- case: test_do_not_add_id_if_field_with_primary_key_True_defined - disable_cache: true - main: | - from typing_extensions import reveal_type - from myapp.models import User - reveal_type(User().my_pk) # N: Revealed type is "int" - User().id # E: "User" has no attribute "id" [attr-defined] - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class User(models.Model): - my_pk = models.IntegerField(primary_key=True) - -- case: blank_and_null_char_field_allows_none - main: | - from typing_extensions import reveal_type - from myapp.models import MyModel - MyModel(nulltext="") - MyModel(nulltext=None) - MyModel().nulltext=None - reveal_type(MyModel().nulltext) # N: Revealed type is "str | None" - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class MyModel(models.Model): - nulltext=models.CharField(max_length=1, blank=True, null=True) - -- case: blank_and_not_null_charfield_does_not_allow_none - main: | - from typing_extensions import reveal_type - from myapp.models import MyModel - MyModel(notnulltext=None) # E: Incompatible type for "notnulltext" of "MyModel" (got "None", expected "str | int | Combinable") [misc] - MyModel(notnulltext="") - MyModel().notnulltext = None # E: Incompatible types in assignment (expression has type "None", variable has type "str | int | Combinable") [assignment] - reveal_type(MyModel().notnulltext) # N: Revealed type is "str" - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class MyModel(models.Model): - notnulltext=models.CharField(max_length=1, blank=True, null=False) - -- case: if_field_called_on_class_return_field_itself - main: | - from typing_extensions import reveal_type - from myapp.models import MyUser - reveal_type(MyUser.name.field) # N: Revealed type is "django.db.models.fields.CharField[str | int | django.db.models.expressions.Combinable, str]" - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class MyUser(models.Model): - name = models.CharField(max_length=100) - -- case: fields_on_non_model_classes_resolve_to_field_type - main: | - from typing_extensions import reveal_type - from django.db import models - class MyClass: - myfield: models.IntegerField[int, int] - reveal_type(MyClass.myfield) # N: Revealed type is "django.db.models.fields._FieldDescriptor[django.db.models.fields.IntegerField[int, int]]" - reveal_type(MyClass.myfield.field) # N: Revealed type is "django.db.models.fields.IntegerField[int, int]" - reveal_type(MyClass().myfield) # N: Revealed type is "django.db.models.fields.IntegerField[int, int]" - -- case: fields_inside_mixins_used_in_model_subclasses_resolved_as_primitives - main: | - from typing_extensions import reveal_type - from myapp.models import MyModel, AuthMixin - reveal_type(MyModel().username) # N: Revealed type is "str" - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class AuthMixin(models.Model): - class Meta: - abstract = True - username = models.CharField(max_length=100) - - class MyModel(AuthMixin, models.Model): - pass -- case: can_narrow_field_type - main: | - from typing import cast, NewType - from typing_extensions import reveal_type - from django.db import models - Year = NewType("Year", int) - class Book(models.Model): - published = cast(models.Field[Year, Year], models.IntegerField()) - book = Book() - reveal_type(book.published) # N: Revealed type is "main.Year" - book.published = 2006 # E: Incompatible types in assignment (expression has type "int", variable has type "Year") [assignment] - book.published = Year(2006) - reveal_type(book.published) # N: Revealed type is "main.Year" - def accepts_int(arg: int) -> None: ... - accepts_int(book.published) - -- case: test_binary_field_return_types - main: | - from typing_extensions import reveal_type - from django.db import models - class EncodedMessage(models.Model): - message = models.BinaryField() - obj = EncodedMessage(b'\x010') - - reveal_type(obj.message) # N: Revealed type is "bytes | memoryview[int]" - -- case: test_small_auto_field_class_presents_as_int - main: | - from typing_extensions import reveal_type - from django.db import models - class MyModel(models.Model): - small = models.SmallAutoField(primary_key=True) - obj = MyModel() - - reveal_type(obj.small) # N: Revealed type is "int" - -- case: test_ignores_renamed_field - main: | - # Ref: https://github.com/typeddjango/django-stubs/issues/1261 - # Django modifies the model so it doesn't have 'modelname', but we don't follow - # along. But the 'name=' argument to a field isn't a documented feature. - from typing_extensions import reveal_type - from myapp.models import RenamedField - instance = RenamedField() - reveal_type(instance.modelname) # N: Revealed type is "int" - instance.fieldname # E: "RenamedField" has no attribute "fieldname" [attr-defined] - instance.modelname = 1 - instance.fieldname = 1 # E: "RenamedField" has no attribute "fieldname" [attr-defined] - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class RenamedField(models.Model): - modelname = models.IntegerField(name="fieldname", choices=((1, 'One'),)) diff --git a/tests/typecheck/fields/test_custom_fields.yml b/tests/typecheck/fields/test_custom_fields.yml deleted file mode 100644 index cf7455f7f..000000000 --- a/tests/typecheck/fields/test_custom_fields.yml +++ /dev/null @@ -1,88 +0,0 @@ -- case: test_custom_model_fields_with_generic_type - main: | - from typing_extensions import reveal_type - from myapp.models import User, CustomFieldValue - user = User() - reveal_type(user.id) # N: Revealed type is "int" - reveal_type(user.my_custom_field1) # N: Revealed type is "myapp.models.CustomFieldValue" - reveal_type(user.my_custom_field2) # N: Revealed type is "myapp.models.CustomFieldValue" - reveal_type(user.my_custom_field3) # N: Revealed type is "bool" - reveal_type(user.my_custom_field4) # N: Revealed type is "myapp.models.CustomFieldValue" - reveal_type(user.my_custom_field5) # N: Revealed type is "myapp.models.CustomFieldValue" - reveal_type(user.my_custom_field6) # N: Revealed type is "myapp.models.CustomFieldValue" - reveal_type(user.my_custom_field7) # N: Revealed type is "bool" - reveal_type(user.my_custom_field8) # N: Revealed type is "myapp.models.CustomFieldValue" - reveal_type(user.my_custom_field9) # N: Revealed type is "myapp.models.CustomFieldValue" - reveal_type(user.my_custom_field10) # N: Revealed type is "bool" - reveal_type(user.my_custom_field11) # N: Revealed type is "bool" - reveal_type(user.my_custom_field12) # N: Revealed type is "myapp.models.CustomFieldValue | None" - reveal_type(user.my_custom_field13) # N: Revealed type is "myapp.models.CustomFieldValue | None" - reveal_type(user.my_custom_field14) # N: Revealed type is "bool | None" - reveal_type(user.my_custom_field15) # N: Revealed type is "None" - - reveal_type(user.my_custom_field_any1) # N: Revealed type is "Any" - reveal_type(user.my_custom_field_any2) # N: Revealed type is "Any" - reveal_type(user.my_custom_field_any3) # N: Revealed type is "Any | None" - reveal_type(user.my_custom_field_any4) # N: Revealed type is "Any" - monkeypatch: true - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - from django.db.models import fields - - from typing import Any, Generic - from typing_extensions import TypeVar - - _ST = TypeVar("_ST", contravariant=True) - _GT = TypeVar("_GT", covariant=True) - - T = TypeVar("T") - - class CustomFieldValue: ... - - class GenericField(fields.Field[_ST, _GT]): ... - - class SingleTypeField(fields.Field[T, T]): ... - - class CustomValueField(fields.Field[CustomFieldValue | int, CustomFieldValue]): ... - - class AdditionalTypeVarField(fields.Field[_ST, _GT], Generic[_ST, _GT, T]): ... - - class CustomSmallIntegerField(fields.SmallIntegerField[_ST, _GT]): ... - - class FieldImplicitAny(fields.Field): ... - class FieldExplicitAny(fields.Field[Any, Any]): ... - - class User(models.Model): - id = models.AutoField(primary_key=True) - my_custom_field1 = GenericField[CustomFieldValue | int, CustomFieldValue]() - my_custom_field2 = CustomValueField() - my_custom_field3 = SingleTypeField[bool]() - my_custom_field4 = AdditionalTypeVarField[CustomFieldValue | int, CustomFieldValue, bool]() - my_custom_field_any1 = FieldImplicitAny() - my_custom_field_any2 = FieldExplicitAny() - - # test null=True on fields with non-optional generic types throw error - my_custom_field5 = GenericField[CustomFieldValue | int, CustomFieldValue](null=True) # E: GenericField is nullable but its generic get type parameter is not optional [misc] - my_custom_field6 = CustomValueField(null=True) # E: CustomValueField is nullable but its generic get type parameter is not optional [misc] - my_custom_field7 = SingleTypeField[bool](null=True) # E: SingleTypeField is nullable but its generic get type parameter is not optional [misc] - my_custom_field8 = AdditionalTypeVarField[CustomFieldValue | int, CustomFieldValue, bool](null=True) # E: AdditionalTypeVarField is nullable but its generic get type parameter is not optional [misc] - my_custom_field9 = fields.Field[CustomFieldValue | int, CustomFieldValue](null=True) # E: Field is nullable but its generic get type parameter is not optional [misc] - - # test overriding fields that set _pyi_private_set_type or _pyi_private_get_type - my_custom_field10 = fields.SmallIntegerField[bool, bool]() - my_custom_field11 = CustomSmallIntegerField[bool, bool]() - - # test null=True on fields with non-optional generic types throw no errors - my_custom_field12 = fields.Field[CustomFieldValue | int, CustomFieldValue | None](null=True) - my_custom_field13 = GenericField[CustomFieldValue | int, CustomFieldValue | None](null=True) - my_custom_field14 = SingleTypeField[bool | None](null=True) - my_custom_field15 = fields.Field[None, None](null=True) - - # test null=True on Any does not raise - my_custom_field_any3 = FieldImplicitAny(null=True) - my_custom_field_any4 = FieldExplicitAny(null=True) diff --git a/tests/typecheck/fields/test_nullable.yml b/tests/typecheck/fields/test_nullable.yml index cf00f3bf5..d264776e4 100644 --- a/tests/typecheck/fields/test_nullable.yml +++ b/tests/typecheck/fields/test_nullable.yml @@ -31,48 +31,18 @@ pass class MyModelExplicitPK(models.Model): id = models.AutoField(primary_key=True) -- case: nullable_field_with_strict_optional_true - main: | - from typing_extensions import reveal_type - from myapp.models import MyModel - reveal_type(MyModel().text) # N: Revealed type is "str" - reveal_type(MyModel().text_nullable) # N: Revealed type is "str | None" - MyModel().text = None # E: Incompatible types in assignment (expression has type "None", variable has type "str | int | Combinable") [assignment] - MyModel().text_nullable = None - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class MyModel(models.Model): - text_nullable = models.CharField(max_length=100, null=True) - text = models.CharField(max_length=100) - -- case: nullable_array_field - main: | - from typing_extensions import reveal_type - from myapp.models import MyModel - reveal_type(MyModel().lst) # N: Revealed type is "list[str] | None" - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - from django.contrib.postgres.fields import ArrayField - - class MyModel(models.Model): - lst = ArrayField(base_field=models.CharField(max_length=100), null=True) - case: nullable_foreign_key main: | from typing_extensions import reveal_type from myapp.models import Publisher, Book - reveal_type(Book().publisher) # N: Revealed type is "myapp.models.Publisher | None" - Book().publisher = 11 # E: Incompatible types in assignment (expression has type "int", variable has type "Publisher | Combinable | None") [assignment] + reveal_type(Book().publisher) + Book().publisher = 11 + out: | + main:3: note: Revealed type is "myapp.models.Publisher | None" + main:4: error: No overload variant of "__set__" of "ForeignObject" matches argument types "Book", "int" [call-overload] + main:4: note: Possible overload variants: + main:4: note: def __set__(self, instance: Any, value: Publisher | Combinable | None) -> None installed_apps: - myapp files: @@ -105,93 +75,3 @@ from django.db import models class Inventory(models.Model): parent = models.ForeignKey('self', on_delete=models.SET_NULL, null=True) - -- case: nullable_subclass_inherits_null_overload - main: | - from typing_extensions import reveal_type - from myapp.models import Article - - reveal_type(Article().body) # N: Revealed type is "str" - reveal_type(Article().body_nullable) # N: Revealed type is "str | None" - reveal_type(Article().count) # N: Revealed type is "int" - reveal_type(Article().count_nullable) # N: Revealed type is "int | None" - monkeypatch: true - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from typing_extensions import TypeVar - from django.db import models - - _ST = TypeVar("_ST", contravariant=True) - _GT = TypeVar("_GT", covariant=True) - - class HtmlField(models.TextField[_ST, _GT]): ... - class IntWrap(models.IntegerField[_ST, _GT]): ... - - class Article(models.Model): - body = HtmlField() - body_nullable = HtmlField(null=True) - count = IntWrap() - count_nullable = IntWrap(null=True) - -- case: direct_field_null_true_does_not_trigger_nullability_check - main: | - from typing_extensions import reveal_type - from django.db import models - from django.db.models import OuterRef, Subquery - from myapp.models import Article - - field = models.IntegerField(null=True) - reveal_type(field) # N: Revealed type is "django.db.models.fields.IntegerField[float | int | str | django.db.models.expressions.Combinable | None, int | None]" - Article.objects.annotate( - other_id=Subquery( - Article.objects.filter(id=OuterRef("id")).values_list("id", flat=True)[:1], - output_field=models.IntegerField(null=True), - ) - ) - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class Article(models.Model): - pass - -- case: field_null_true_expression_does_not_trigger_nullability_check - main: | - from typing_extensions import reveal_type - from django.db import models - - def take_field(f: models.Field) -> None: - return None - - take_field(models.IntegerField(null=True)) - take_field(models.IntegerField(null=False)) - -- case: nullable_field_subclass_without_explicit_type_vars - main: | - from typing_extensions import reveal_type - from myapp.models import Article - reveal_type(Article().body) # N: Revealed type is "str" - reveal_type(Article().body_nullable) # N: Revealed type is "str | None" - reveal_type(Article().count_nullable) # N: Revealed type is "int | None" - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - - class HTMLField(models.TextField): ... - class IntWrap(models.IntegerField): ... - - class Article(models.Model): - body = HTMLField() - body_nullable = HTMLField(null=True) - count_nullable = IntWrap(null=True) diff --git a/tests/typecheck/fields/test_postgres_fields.yml b/tests/typecheck/fields/test_postgres_fields.yml deleted file mode 100644 index bce86bb53..000000000 --- a/tests/typecheck/fields/test_postgres_fields.yml +++ /dev/null @@ -1,37 +0,0 @@ -- case: array_field_descriptor_access - main: | - from typing_extensions import reveal_type - from myapp.models import User - user = User(array=[]) - reveal_type(user.array) # N: Revealed type is "list[Any]" - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - from django.contrib.postgres.fields import ArrayField - - class User(models.Model): - array = ArrayField(base_field=models.Field()) - -- case: array_field_base_field_parsed_into_generic_typevar - main: | - from typing_extensions import reveal_type - from myapp.models import User - user = User() - reveal_type(user.members) # N: Revealed type is "list[int]" - reveal_type(user.members_as_text) # N: Revealed type is "list[str]" - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - from django.contrib.postgres.fields import ArrayField - - class User(models.Model): - members = ArrayField(base_field=models.IntegerField()) - members_as_text = ArrayField(base_field=models.CharField(max_length=255)) diff --git a/tests/typecheck/fields/test_related.yml b/tests/typecheck/fields/test_related.yml index 7565f9686..7dc6c3fac 100644 --- a/tests/typecheck/fields/test_related.yml +++ b/tests/typecheck/fields/test_related.yml @@ -48,11 +48,19 @@ from uuid import UUID book = Book() book.publisher = Publisher() - reveal_type(book.publisher_id) # N: Revealed type is "uuid.UUID" + reveal_type(book.publisher_id) book.publisher_id = '821850bb-c105-426f-b340-3974419d00ca' book.publisher_id = UUID('821850bb-c105-426f-b340-3974419d00ca') - book.publisher_id = [1] # E: Incompatible types in assignment (expression has type "list[int]", variable has type "str | UUID") [assignment] - book.publisher_id = Publisher() # E: Incompatible types in assignment (expression has type "Publisher", variable has type "str | UUID") [assignment] + book.publisher_id = [1] + book.publisher_id = Publisher() + out: | + main:6: note: Revealed type is "uuid.UUID" + main:9: error: No overload variant of "__set__" of "Field" matches argument types "Book", "list[int]" [call-overload] + main:9: note: Possible overload variants: + main:9: note: def __set__(self, instance: Any, value: str | UUID | Combinable) -> None + main:10: error: No overload variant of "__set__" of "Field" matches argument types "Book", "Publisher" [call-overload] + main:10: note: Possible overload variants: + main:10: note: def __set__(self, instance: Any, value: str | UUID | Combinable) -> None installed_apps: - myapp files: @@ -260,7 +268,7 @@ from myapp2.models import Profile reveal_type(Profile().user) # N: Revealed type is "myapp.models.user.User" reveal_type(Profile().user.profile) # N: Revealed type is "myapp2.models.Profile" - reveal_type(Profile.user.field) # N: Revealed type is "django.db.models.fields.related.OneToOneField[myapp.models.user.User | django.db.models.expressions.Combinable, myapp.models.user.User]" + reveal_type(Profile.user.field) # N: Revealed type is "django.db.models.fields.related.OneToOneField[myapp.models.user.User, myapp.models.user.User, Literal[False]]" installed_apps: - myapp - myapp2 @@ -655,9 +663,9 @@ reveal_type(Author.blogs) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[myapp.models.Blog, myapp.models.Author_blogs]" reveal_type(Author.blogs.through) # N: Revealed type is "type[myapp.models.Author_blogs]" reveal_type(Author().blogs) # N: Revealed type is "myapp.models.Blog_ManyRelatedManager[myapp.models.Author_blogs]" - reveal_type(Blog.publisher) # N: Revealed type is "django.db.models.fields.related_descriptors.ForwardManyToOneDescriptor[django.db.models.fields.related.ForeignKey[myapp.models.Publisher | django.db.models.expressions.Combinable, myapp.models.Publisher]]" - reveal_type(Publisher.profile) # N: Revealed type is "django.db.models.fields.related_descriptors.ForwardOneToOneDescriptor[django.db.models.fields.related.OneToOneField[myapp.models.Profile | django.db.models.expressions.Combinable, myapp.models.Profile]]" - reveal_type(Author.file) # N: Revealed type is "django.db.models.fields.files.FileDescriptor" + reveal_type(Blog.publisher) # N: Revealed type is "django.db.models.fields.related_descriptors.ForwardManyToOneDescriptor[django.db.models.fields.related.ForeignKey[myapp.models.Publisher, myapp.models.Publisher, Literal[False]]]" + reveal_type(Publisher.profile) # N: Revealed type is "django.db.models.fields.related_descriptors.ForwardOneToOneDescriptor[django.db.models.fields.related.OneToOneField[myapp.models.Profile, myapp.models.Profile, Literal[False]]]" + reveal_type(Author.file) # N: Revealed type is "django.db.models.fields._FieldDescriptor[django.db.models.fields.files.FileField[Any, django.db.models.fields.files.FieldFile, Literal[False]]]" installed_apps: - myapp @@ -842,14 +850,16 @@ - path: myapp/__init__.py - path: myapp/models.py content: | - from typing import Any + from typing import Any, Literal from typing_extensions import TypeVar from django.db import models + from django.db.models.expressions import Combinable - _ST = TypeVar("_ST", contravariant=True) - _GT = TypeVar("_GT", covariant=True) + _ST = TypeVar("_ST", contravariant=True, default=Any | Combinable) + _GT = TypeVar("_GT", covariant=True, default=Any) + _NT = TypeVar("_NT", Literal[True], Literal[False], default=Literal[False]) - class FK(models.ForeignKey[_ST, _GT]): + class FK(models.ForeignKey[_ST, _GT, _NT]): def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs.setdefault('on_delete', models.CASCADE) super().__init__(*args, **kwargs) @@ -1262,7 +1272,7 @@ main:17: note: Revealed type is "myapp.models.Other" main:18: note: Revealed type is "int" main:20: note: Revealed type is "type[myapp.models.MyModel_auto_through]" - main:21: note: Revealed type is "django.db.models.fields.related_descriptors.ForwardManyToOneDescriptor[django.db.models.fields.related.ForeignKey[myapp.models.MyModel | django.db.models.expressions.Combinable, myapp.models.MyModel]]" + main:21: note: Revealed type is "django.db.models.fields.related_descriptors.ForwardManyToOneDescriptor[django.db.models.fields.related.ForeignKey[myapp.models.MyModel, myapp.models.MyModel, Literal[False]]]" main:22: note: Revealed type is "django.db.models.manager.Manager[myapp.models.MyModel_auto_through]" main:24: note: Revealed type is "type[myapp.models.MyModel_other_again]" main:25: note: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[myapp.models.MyModel, myapp.models.MyModel_auto_through]" @@ -1275,13 +1285,13 @@ main:34: note: Revealed type is "django.db.models.manager.Manager[myapp.models.MyModel_auto_through]" main:36: note: Revealed type is "myapp.models.Other_ManyRelatedManager[myapp.models.MyModel_auto_through]" main:37: note: Revealed type is "type[myapp.models.MyModel_auto_through]" - main:38: note: Revealed type is "django.db.models.fields._FieldDescriptor[django.db.models.fields.BigAutoField[float | int | str | django.db.models.expressions.Combinable | None, int]]" + main:38: note: Revealed type is "django.db.models.fields._FieldDescriptor[django.db.models.fields.BigAutoField[int | str | None, int, Literal[False]]]" main:39: note: Revealed type is "myapp.models.Other_ManyRelatedManager[myapp.models.MyModel_other_again]" main:40: note: Revealed type is "type[myapp.models.MyModel_other_again]" - main:41: note: Revealed type is "django.db.models.fields._FieldDescriptor[django.db.models.fields.BigAutoField[float | int | str | django.db.models.expressions.Combinable | None, int]]" + main:41: note: Revealed type is "django.db.models.fields._FieldDescriptor[django.db.models.fields.BigAutoField[int | str | None, int, Literal[False]]]" main:42: note: Revealed type is "myapp.models.Other_ManyRelatedManager[myapp.models.CustomThrough]" main:43: note: Revealed type is "type[myapp.models.CustomThrough]" - main:44: note: Revealed type is "django.db.models.fields._FieldDescriptor[django.db.models.fields.BigAutoField[float | int | str | django.db.models.expressions.Combinable, int]]" + main:44: note: Revealed type is "django.db.models.fields._FieldDescriptor[django.db.models.fields.BigAutoField[int | str, int, Literal[False]]]" installed_apps: - myapp files: diff --git a/tests/typecheck/managers/querysets/test_annotate.yml b/tests/typecheck/managers/querysets/test_annotate.yml index 69610c929..f1f61e2a6 100644 --- a/tests/typecheck/managers/querysets/test_annotate.yml +++ b/tests/typecheck/managers/querysets/test_annotate.yml @@ -723,10 +723,11 @@ val_str = User.objects.annotate(val=char_subquery).get().val reveal_type(val_str) # N: Revealed type is "str" - # Subquery with nullable output_field → int | None + # Subquery with nullable output_field + # TODO: Should resolve to int | None, but Subquery doesn't propagate the nullable flag nullable_sub = Subquery(User.objects.filter(pk=OuterRef('pk')).values('id')[:1], output_field=IntegerField(null=True)) val_nullable = User.objects.annotate(val=nullable_sub).get().val - reveal_type(val_nullable) # N: Revealed type is "int | None" + reveal_type(val_nullable) # N: Revealed type is "int" # Subquery without output_field → Any plain_subquery = Subquery(User.objects.filter(pk=OuterRef('pk')).values('id')[:1]) diff --git a/tests/typecheck/managers/querysets/test_from_queryset.yml b/tests/typecheck/managers/querysets/test_from_queryset.yml index f95f9e123..8814f50bc 100644 --- a/tests/typecheck/managers/querysets/test_from_queryset.yml +++ b/tests/typecheck/managers/querysets/test_from_queryset.yml @@ -446,6 +446,7 @@ objects = NewManager() - case: from_queryset_with_manager_in_another_directory_and_imports + disable_cache: true main: | from typing_extensions import reveal_type from myapp.models import MyModel diff --git a/tests/typecheck/managers/querysets/test_prefetch_related.yml b/tests/typecheck/managers/querysets/test_prefetch_related.yml index cc71c3298..0eb597d06 100644 --- a/tests/typecheck/managers/querysets/test_prefetch_related.yml +++ b/tests/typecheck/managers/querysets/test_prefetch_related.yml @@ -560,9 +560,10 @@ - path: myapp/models.py content: | import typing + from typing_extensions import TypeVar from django.db import models - T = typing.TypeVar("T", bound="models.Model") + T = TypeVar("T", bound="models.Model") class BaseManager(models.Manager[T]): ... class BaseObject(models.Model): diff --git a/tests/typecheck/managers/querysets/test_values_list.yml b/tests/typecheck/managers/querysets/test_values_list.yml index 247280526..e36291f0e 100644 --- a/tests/typecheck/managers/querysets/test_values_list.yml +++ b/tests/typecheck/managers/querysets/test_values_list.yml @@ -64,7 +64,7 @@ from __future__ import annotations from django.db import models - class JSONField(models.TextField): pass # incomplete + class JSONField(models.Field): pass # incomplete class Concrete(models.Model): id = models.IntegerField() diff --git a/tests/typecheck/models/test_create.yml b/tests/typecheck/models/test_create.yml index f044dd85b..b7002e915 100644 --- a/tests/typecheck/models/test_create.yml +++ b/tests/typecheck/models/test_create.yml @@ -73,7 +73,7 @@ from django.db import models - class JSONField(models.TextField): pass # incomplete + class JSONField(models.Field): pass # incomplete class Base(models.Model): dct: models.Field[dict[str, str], dict[str, str]] = JSONField() @@ -87,7 +87,7 @@ Book.objects.create(id=None) # E: Incompatible type for "id" of "Book" (got "None", expected "float | int | str | Combinable") [misc] Book.objects.create(publisher=None) # E: Incompatible type for "publisher" of "Book" (got "None", expected "Publisher | Combinable") [misc] - Book.objects.create(publisher_id=None) # E: Incompatible type for "publisher_id" of "Book" (got "None", expected "float | int | str | Combinable") [misc] + Book.objects.create(publisher_id=None) # E: Incompatible type for "publisher_id" of "Book" (got "None", expected "int | str | Combinable") [misc] installed_apps: - myapp files: @@ -120,27 +120,43 @@ from typing_extensions import reveal_type from myapp.models import MyModel first = MyModel(id=None) - reveal_type(first.id) # N: Revealed type is "int" + reveal_type(first.id) first = MyModel.objects.create(id=None) - reveal_type(first.id) # N: Revealed type is "int" + reveal_type(first.id) first = MyModel() first.id = None - reveal_type(first.id) # N: Revealed type is "int" + reveal_type(first.id) from myapp.models import MyModel2 - MyModel2(id=None) # E: Incompatible type for "id" of "MyModel2" (got "None", expected "float | int | str | Combinable") [misc] - MyModel2.objects.create(id=None) # E: Incompatible type for "id" of "MyModel2" (got "None", expected "float | int | str | Combinable") [misc] + MyModel2(id=None) + MyModel2.objects.create(id=None) second = MyModel2() - second.id = None # E: Incompatible types in assignment (expression has type "None", variable has type "float | int | str | Combinable") [assignment] - reveal_type(second.id) # N: Revealed type is "int" + second.id = None + reveal_type(second.id) # default set but no primary key doesn't allow None from myapp.models import MyModel3 - MyModel3(default=None) # E: Incompatible type for "default" of "MyModel3" (got "None", expected "float | int | str | Combinable") [misc] - MyModel3.objects.create(default=None) # E: Incompatible type for "default" of "MyModel3" (got "None", expected "float | int | str | Combinable") [misc] + MyModel3(default=None) + MyModel3.objects.create(default=None) third = MyModel3() - third.default = None # E: Incompatible types in assignment (expression has type "None", variable has type "float | int | str | Combinable") [assignment] - reveal_type(third.default) # N: Revealed type is "int" + third.default = None + reveal_type(third.default) + out: | + main:4: note: Revealed type is "int" + main:6: note: Revealed type is "int" + main:9: note: Revealed type is "int" + main:12: error: Incompatible type for "id" of "MyModel2" (got "None", expected "float | int | str | Combinable") [misc] + main:13: error: Incompatible type for "id" of "MyModel2" (got "None", expected "float | int | str | Combinable") [misc] + main:15: error: No overload variant of "__set__" of "Field" matches argument types "MyModel2", "None" [call-overload] + main:15: note: Possible overload variants: + main:15: note: def __set__(self, instance: Any, value: float | int | str | Combinable) -> None + main:16: note: Revealed type is "Any" + main:20: error: Incompatible type for "default" of "MyModel3" (got "None", expected "float | int | str | Combinable") [misc] + main:21: error: Incompatible type for "default" of "MyModel3" (got "None", expected "float | int | str | Combinable") [misc] + main:23: error: No overload variant of "__set__" of "Field" matches argument types "MyModel3", "None" [call-overload] + main:23: note: Possible overload variants: + main:23: note: def __set__(self, instance: Any, value: float | int | str | Combinable) -> None + main:24: note: Revealed type is "Any" installed_apps: - myapp files: diff --git a/tests/typecheck/models/test_init.yml b/tests/typecheck/models/test_init.yml index e22290640..3d6e97dc3 100644 --- a/tests/typecheck/models/test_init.yml +++ b/tests/typecheck/models/test_init.yml @@ -127,7 +127,7 @@ from myapp.models import Publisher, PublisherDatetime, Book Book(publisher_id=1, publisher_dt_id=now) - Book(publisher_id=[], publisher_dt_id=now) # E: Incompatible type for "publisher_id" of "Book" (got "list[Any]", expected "float | int | str | Combinable") [misc] + Book(publisher_id=[], publisher_dt_id=now) # E: Incompatible type for "publisher_id" of "Book" (got "list[Any]", expected "int | str | Combinable") [misc] Book(publisher_id=1, publisher_dt_id=1) # E: Incompatible type for "publisher_dt_id" of "Book" (got "int", expected "str | datetime | date | Combinable") [misc] installed_apps: - myapp @@ -145,33 +145,6 @@ publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE) publisher_dt = models.ForeignKey(PublisherDatetime, on_delete=models.CASCADE) -- case: setting_value_to_an_array_of_ints - main: | - from myapp.models import MyModel - array_val: tuple[int, ...] = (1,) - MyModel(array=array_val) - array_val2: list[int] = [1] - MyModel(array=array_val2) - class NotAValid: - pass - array_val3: list[NotAValid] = [NotAValid()] - MyModel(array=array_val3) # E: Incompatible type for "array" of "MyModel" (got "list[NotAValid]", expected "Sequence[float | int | str] | Combinable") [misc] - non_init = MyModel() - non_init.array = array_val - non_init.array = array_val2 - non_init.array = array_val3 # E: Incompatible types in assignment (expression has type "list[NotAValid]", variable has type "Sequence[float | int | str] | Combinable") [assignment] - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - from django.contrib.postgres.fields import ArrayField - - class MyModel(models.Model): - array = ArrayField(base_field=models.IntegerField()) - - case: if_no_explicit_primary_key_id_can_be_passed main: | from myapp.models import MyModel @@ -259,65 +232,6 @@ class MyModel(AbstractModel): pass - -- case: field_set_type_honors_type_redefinition - main: | - from typing_extensions import reveal_type - from myapp.models import MyModel - non_init = MyModel() - reveal_type(non_init.redefined_set_type) - reveal_type(non_init.redefined_union_set_type) - reveal_type(non_init.redefined_array_set_type) - reveal_type(non_init.default_set_type) - reveal_type(non_init.unset_set_type) - non_init.redefined_set_type = "invalid" - non_init.redefined_union_set_type = "invalid" - array_val: list[str] = ["invalid"] - non_init.redefined_array_set_type = array_val - non_init.default_set_type = [] - non_init.unset_set_type = [] - MyModel( - redefined_set_type="invalid", - redefined_union_set_type="invalid", - redefined_array_set_type=33, - default_set_type=[], - unset_set_type=[], - ) - out: | - main:4: note: Revealed type is "int" - main:5: note: Revealed type is "int" - main:6: note: Revealed type is "list[int]" - main:7: note: Revealed type is "int" - main:8: note: Revealed type is "Any" - main:9: error: Incompatible types in assignment (expression has type "str", variable has type "int") [assignment] - main:10: error: Incompatible types in assignment (expression has type "str", variable has type "int | float") [assignment] - main:12: error: Incompatible types in assignment (expression has type "list[str]", variable has type "Sequence[int | float]") [assignment] - main:13: error: Incompatible types in assignment (expression has type "list[Never]", variable has type "float | int | str | Combinable") [assignment] - main:15: error: Incompatible type for "redefined_set_type" of "MyModel" (got "str", expected "int") [misc] - main:15: error: Incompatible type for "redefined_union_set_type" of "MyModel" (got "str", expected "int | float") [misc] - main:15: error: Incompatible type for "redefined_array_set_type" of "MyModel" (got "int", expected "Sequence[int | float]") [misc] - main:15: error: Incompatible type for "default_set_type" of "MyModel" (got "list[Any]", expected "float | int | str | Combinable") [misc] - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.contrib.postgres.fields import ArrayField - from django.db import models - from collections.abc import Sequence - from typing import cast - - class MyModel(models.Model): - redefined_set_type = cast("models.Field[int, int]", models.IntegerField()) - redefined_union_set_type = cast("models.Field[int | float, int]", models.IntegerField()) - redefined_array_set_type = cast( - "ArrayField[Sequence[int | float], list[int]]", - ArrayField(base_field=models.IntegerField()), - ) - default_set_type = models.IntegerField() - unset_set_type = cast("models.Field", models.IntegerField()) - - case: too_many_positional_arguments_on_init main: | from myapp.models import MyUser @@ -344,3 +258,18 @@ from django.db import models class MyUser(models.Model): pass + +- case: blank_and_not_null_charfield_does_not_allow_none + main: | + from typing_extensions import reveal_type + from myapp.models import MyModel + MyModel(notnulltext=None) # E: Incompatible type for "notnulltext" of "MyModel" (got "None", expected "str | int | Combinable") [misc] + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class MyModel(models.Model): + notnulltext=models.CharField(max_length=1, blank=True, null=False) diff --git a/tests/typecheck/models/test_meta_options.yml b/tests/typecheck/models/test_meta_options.yml index 242ed0ad9..deb6fe127 100644 --- a/tests/typecheck/models/test_meta_options.yml +++ b/tests/typecheck/models/test_meta_options.yml @@ -17,14 +17,14 @@ main: | from typing_extensions import reveal_type from myapp.models import MyUser - reveal_type(MyUser._meta.get_field('base_name')) # N: Revealed type is "django.db.models.fields.CharField[str | int | django.db.models.expressions.Combinable, str]" - reveal_type(MyUser.base_name.field) # N: Revealed type is "django.db.models.fields.CharField[str | int | django.db.models.expressions.Combinable, str]" - reveal_type(MyUser._meta.get_field('name')) # N: Revealed type is "django.db.models.fields.CharField[str | int | django.db.models.expressions.Combinable, str]" - reveal_type(MyUser.name.field) # N: Revealed type is "django.db.models.fields.CharField[str | int | django.db.models.expressions.Combinable, str]" - reveal_type(MyUser._meta.get_field('age')) # N: Revealed type is "django.db.models.fields.IntegerField[float | int | str | django.db.models.expressions.Combinable, int]" - reveal_type(MyUser.age.field) # N: Revealed type is "django.db.models.fields.IntegerField[float | int | str | django.db.models.expressions.Combinable, int]" - reveal_type(MyUser._meta.get_field('to_user')) # N: Revealed type is "django.db.models.fields.related.ForeignKey[myapp.models.MyUser | django.db.models.expressions.Combinable, myapp.models.MyUser]" - reveal_type(MyUser.to_user.field) # N: Revealed type is "django.db.models.fields.related.ForeignKey[myapp.models.MyUser | django.db.models.expressions.Combinable, myapp.models.MyUser]" + reveal_type(MyUser._meta.get_field('base_name')) # N: Revealed type is "django.db.models.fields.CharField[str | int | django.db.models.expressions.Combinable, str, Literal[False]]" + reveal_type(MyUser.base_name.field) # N: Revealed type is "django.db.models.fields.CharField[str | int, str, Literal[False]]" + reveal_type(MyUser._meta.get_field('name')) # N: Revealed type is "django.db.models.fields.CharField[str | int | django.db.models.expressions.Combinable, str, Literal[False]]" + reveal_type(MyUser.name.field) # N: Revealed type is "django.db.models.fields.CharField[str | int, str, Literal[False]]" + reveal_type(MyUser._meta.get_field('age')) # N: Revealed type is "django.db.models.fields.IntegerField[float | int | str | django.db.models.expressions.Combinable, int, Literal[False]]" + reveal_type(MyUser.age.field) # N: Revealed type is "django.db.models.fields.IntegerField[float | int | str, int, Literal[False]]" + reveal_type(MyUser._meta.get_field('to_user')) # N: Revealed type is "django.db.models.fields.related.ForeignKey[myapp.models.MyUser | django.db.models.expressions.Combinable, myapp.models.MyUser, Literal[False]]" + reveal_type(MyUser.to_user.field) # N: Revealed type is "django.db.models.fields.related.ForeignKey[myapp.models.MyUser, myapp.models.MyUser, Literal[False]]" MyUser._meta.get_field('unknown') # E: MyUser has no field named 'unknown' [misc] installed_apps: @@ -48,10 +48,11 @@ MyModel._meta.get_field('non_existant') # E: MyModel has no field named 'non_existant' [misc] - reveal_type(MyModel._meta.get_field('field')) # N: Revealed type is "django.contrib.postgres.fields.array.ArrayField[typing.Sequence[float | int | str] | django.db.models.expressions.Combinable, list[int]]" + reveal_type(MyModel._meta.get_field('field')) # N: Revealed type is "django.contrib.postgres.fields.array.ArrayField[typing.Sequence[float | int | str] | django.db.models.expressions.Combinable, list[int], Literal[False]]" + reveal_type(MyModel._meta.get_field('null_field')) # N: Revealed type is "django.contrib.postgres.fields.array.ArrayField[typing.Sequence[float | int | str] | django.db.models.expressions.Combinable | None, list[int] | None, Literal[True]]" field: str - reveal_type(MyModel._meta.get_field(field)) # N: Revealed type is "django.db.models.fields.Field[Any, Any] | django.db.models.fields.reverse_related.ForeignObjectRel" + reveal_type(MyModel._meta.get_field(field)) # N: Revealed type is "django.db.models.fields.Field[Any, Any, Literal[False]] | django.db.models.fields.reverse_related.ForeignObjectRel" installed_apps: - myapp files: @@ -67,6 +68,7 @@ class MyModel(AbstractModel): field = ArrayField(models.IntegerField(), default=[]) + null_field = ArrayField(models.IntegerField(), default=[], null=True) - case: get_field_reverse_fk_with_related_query_name main: | @@ -74,10 +76,10 @@ from myapp.models import ModelA reveal_type(ModelA._meta.get_field("model_b")) # N: Revealed type is "django.db.models.fields.reverse_related.ManyToOneRel" - reveal_type(ModelA.modelb_set.field) # N: Revealed type is "django.db.models.fields.related.ForeignKey[myapp.models.ModelB, myapp.models.ModelB]" + reveal_type(ModelA.modelb_set.field) # N: Revealed type is "django.db.models.fields.related.ForeignKey[myapp.models.ModelB, myapp.models.ModelB, Literal[False]]" reveal_type(ModelA._meta.get_field("model_b_bis")) # N: Revealed type is "django.db.models.fields.reverse_related.ManyToOneRel" - reveal_type(ModelA.model_b_bis.field) # N: Revealed type is "django.db.models.fields.related.ForeignKey[myapp.models.ModelB, myapp.models.ModelB]" + reveal_type(ModelA.model_b_bis.field) # N: Revealed type is "django.db.models.fields.related.ForeignKey[myapp.models.ModelB, myapp.models.ModelB, Literal[False]]" installed_apps: - myapp diff --git a/tests/typecheck/models/test_related_fields.yml b/tests/typecheck/models/test_related_fields.yml index deb41525e..e71439b75 100644 --- a/tests/typecheck/models/test_related_fields.yml +++ b/tests/typecheck/models/test_related_fields.yml @@ -94,7 +94,7 @@ from typing_extensions import reveal_type from app1.models import Model1, Model2 - reveal_type(Model2.model_1.field) # N: Revealed type is "django.db.models.fields.related.ForeignObject[app1.models.Model1, app1.models.Model1]" + reveal_type(Model2.model_1.field) # N: Revealed type is "django.db.models.fields.related.ForeignObject[app1.models.Model1, app1.models.Model1, Literal[False]]" reveal_type(Model2().model_1) # N: Revealed type is "app1.models.Model1" reveal_type(Model1.model_2s) # N: Revealed type is "django.db.models.fields.related_descriptors.ReverseManyToOneDescriptor[app1.models.Model2]" reveal_type(Model1().model_2s) # N: Revealed type is "django.db.models.fields.related_descriptors.RelatedManager[app1.models.Model2]"