diff --git a/AGENTS.md b/AGENTS.md index e588fef840f..50bb0dcc42c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -69,16 +69,28 @@ const { title$ } = strings; // title$() returns translated string ### ⚠️ API Calls via Resource Classes Only Use `Resource` from `kolibri/apiResource`. Define in `apiResources.js`. Never use raw `fetch` or `axios`. -### ⚠️ Backend APIs: Use ValuesViewset -Use `ValuesViewset` (or `ReadOnlyValuesViewset`) from `kolibri.core.api` for new API endpoints — not `ModelViewSet`, `ViewSet`, or `GenericViewSet`: +### ⚠️ Backend APIs: Use ValuesViewset with Serializer Derivation +Use `ValuesViewset` (or `ReadOnlyValuesViewset`) from `kolibri.core.api` for new API endpoints — not `ModelViewSet`, `ViewSet`, or `GenericViewSet`. Define a DRF serializer as the source of truth; the viewset derives the `values()` query automatically: ```python +from rest_framework import serializers from kolibri.core.api import ReadOnlyValuesViewset +class MySerializer(serializers.ModelSerializer): + class Meta: + model = MyModel + fields = ("id", "title", "description") + class MyViewSet(ReadOnlyValuesViewset): - values = ("id", "title", "description") - # Define values tuple and annotate_queryset for computed fields + serializer_class = MySerializer + queryset = MyModel.objects.all() ``` -Viewset permissions use `KolibriAuthPermissions` from `kolibri.core.auth.api`. See `docs/backend_architecture/api_patterns.rst`. +Do **not** define explicit `values` tuples or `field_map` dicts on new viewsets — these are legacy patterns being migrated away. + +The model should define a default `ordering` in its `Meta`, or the viewset's `queryset` should set an explicit `order_by()` — response ordering (and pagination) is nondeterministic otherwise. + +Viewset permissions use `KolibriAuthPermissions` from `kolibri.core.auth.api`, which delegates object-level checks to the model's declarative permissions (e.g. `RoleBasedPermissions`). It only works for models that participate in Kolibri's auth/permissions system — models without those declarations need a different permission class. + +See `docs/backend_architecture/api_patterns.rst`. ### ⚠️ Testing is Required - **Python:** pytest is the test runner. Django API tests extend `APITestCase` from `rest_framework.test`. Other Django tests extend `django.test.TestCase`. Only use bare pytest-style function tests for non-Django code. diff --git a/docs/backend_architecture/api_patterns.rst b/docs/backend_architecture/api_patterns.rst index 571fcb024c7..be085cabb22 100644 --- a/docs/backend_architecture/api_patterns.rst +++ b/docs/backend_architecture/api_patterns.rst @@ -34,130 +34,241 @@ Overview Basic Usage ~~~~~~~~~~~ -A minimal ``ValuesViewset`` requires defining the ``values`` tuple: +Define a DRF serializer as the single source of truth for the API shape. The viewset automatically derives the ``values()`` query and field transformations from the serializer's field definitions: .. code-block:: python + from rest_framework import serializers from kolibri.core.api import ValuesViewset from kolibri.core.auth.api import KolibriAuthPermissions from .models import Lesson - # Permissions are typically defined in the same file as the viewset - class LessonPermissions(KolibriAuthPermissions): - pass + class LessonSerializer(serializers.ModelSerializer): + class Meta: + model = Lesson + fields = ("id", "title", "description", "is_active", "created_by", "date_created") class LessonViewset(ValuesViewset): + serializer_class = LessonSerializer queryset = Lesson.objects.all() - permission_classes = (LessonPermissions,) - - # Tuple of fields to fetch from database - values = ( - "id", - "title", - "description", - "is_active", - "created_by", - "date_created", - ) - -Key Attributes and Methods + permission_classes = (KolibriAuthPermissions,) + +From this, the viewset automatically derives: + +- **values tuple**: ``("id", "title", "description", "is_active", "created_by", "date_created")`` +- **field transformations**: Each field's ``to_representation()`` method handles type coercion where needed + +The model should define a default ``ordering`` in its ``Meta``, or the viewset's ``queryset`` should set an explicit ``order_by()`` — response ordering (and pagination) is nondeterministic otherwise. + +How Derivation Works +~~~~~~~~~~~~~~~~~~~~ + +The viewset introspects the serializer's fields to build the values tuple and field mappings. The rules are: + +.. list-table:: + :header-rows: 1 + :widths: 40 60 + + * - Serializer Pattern + - Derived Behavior + * - ``field = CharField()`` + - Add ``'field'`` to values + * - ``field = CharField(source='other')`` + - Add ``'other'`` to values, rename to ``'field'`` in output + * - ``field = BooleanField(source='x.y')`` + - Add ``'x__y'`` to values, ``field.to_representation()`` handles coercion + * - ``field = CharField(write_only=True)`` + - Skip (not in read output) + * - ``nested = NestedSerializer(many=True)`` + - Flatten nested fields with prefix, auto-consolidate child rows into a list per parent + * - ``nested = NestedSerializer()`` + - Flatten nested fields with prefix, extract as dict per row + * - Custom field with ``to_representation()`` + - Custom transformation applied automatically + * - ``field = ValuesMethodField(sources=(...))`` + - Add declared sources to values; invoke ``get_`` per row over a proxy of those sources + * - ``field = SerializerMethodField()`` + - Rejected at viewset init — use ``ValuesMethodField`` so sources are explicit + +Computed and Derived Fields ~~~~~~~~~~~~~~~~~~~~~~~~~~~ -``values`` (required) -^^^^^^^^^^^^^^^^^^^^^ +When an output value isn't a direct column read, the table below covers the common shapes. ``ValuesMethodField`` is fine as the default for one-off per-row computation; promote to a custom ``Field`` subclass only when the same transform recurs across serializers. + +.. list-table:: + :header-rows: 1 + :widths: 50 50 + + * - Intent + - Do this + * - Expose a (possibly null) related attribute + - ``BooleanField(source="dataset.x", default=False)`` + * - Constant value + - ``ReadOnlyField(default=...)`` + * - M2M PK collection + - Nested ``many=True`` serializer, or ``ArrayAgg`` annotation + * - Count/aggregate over relation + - ``annotate_queryset`` + * - Per-row transform or computation (one-off) + - ``ValuesMethodField(sources=(...))`` + * - Per-row transform reused across serializers + - Custom ``Field`` subclass with ``to_representation`` (e.g. ``SplitTextField``) + * - Per-row computation that needs request context + - ``ValuesMethodField(sources=(...))`` + ``self.context["request"]`` + +ValuesMethodField +^^^^^^^^^^^^^^^^^ -Tuple of database field names to fetch. Supports foreign key lookups using ``__`` notation: +A plain ``SerializerMethodField`` is rejected at viewset init — the viewset cannot infer which columns the method reads. Declare them with ``ValuesMethodField(sources=(...))``: .. code-block:: python - values = ( - "id", - "title", - "collection__id", # FK lookup: classroom ID - "collection__name", # FK lookup: classroom name - "collection__parent_id", # FK lookup: facility ID - ) + from kolibri.core.api import ValuesMethodField + + class UserSerializer(serializers.ModelSerializer): + contact_label = ValuesMethodField(sources=("full_name", "email")) + + def get_contact_label(self, row): + return "{} <{}>".format(row.full_name, row.email) + +- ``sources`` are added to the ``values()`` call. Dotted sources (``"publisher.name"``) are walked: ``row.publisher.name`` reads the ``publisher__name`` column. +- ``row`` is a proxy exposing only the declared paths; anything else raises ``AttributeError``. +- Values are Python types after Django's coercion, not serialized strings — a ``DateTimeField`` source is a ``datetime``. +- Sources referenced *only* by the method are stripped from the output — method inputs, not outputs. +- ``self.context`` carries per-request state (``request``, ``view``, ``format``) for the duration of each ``serialize()`` call. -``field_map`` (optional) -^^^^^^^^^^^^^^^^^^^^^^^^ +Nested Serializers +~~~~~~~~~~~~~~~~~~ -Dictionary mapping API field names to database fields or transformation functions: +Nested serializers can be handled in two ways: **joined** (default) or **deferred**. + +Joined (Default) — Auto-Consolidated +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When a nested serializer is not listed in ``deferred_fields``, its fields are included in the main ``values()`` call with a prefix. The resulting flat rows are automatically consolidated back into nested structures: .. code-block:: python - # Simple string mapping (rename fields) - field_map = { - "active": "is_active", # API: active, DB: is_active - "classroom_id": "collection__id", # Rename FK field - } - - # Callable mapping (transform data) - def _transform_classroom(item): - """Restructure classroom data from flat to nested""" - return { - "id": item.pop("collection__id"), - "name": item.pop("collection__name"), - "parent": item.pop("collection__parent_id"), - } + class RoleSerializer(serializers.ModelSerializer): + class Meta: + model = Role + fields = ("id", "kind", "collection") - field_map = { - "classroom": _transform_classroom, # Returns nested object - } + class UserSerializer(serializers.ModelSerializer): + roles = RoleSerializer(many=True, read_only=True) -``annotate_queryset(queryset)`` (optional) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + class Meta: + model = FacilityUser + fields = ("id", "username", "roles") -Method to add computed/aggregated fields before serialization: + class UserViewSet(ReadOnlyValuesViewset): + serializer_class = UserSerializer + queryset = FacilityUser.objects.all() + +The viewset fetches ``("id", "username", "roles__id", "roles__kind", "roles__collection")`` and auto-consolidates: .. code-block:: python - from kolibri.core.query import annotate_array_aggregate + # Raw values() output (multiple rows per user due to LEFT JOIN): + [ + {"id": "user1", "username": "alice", "roles__id": "r1", "roles__kind": "admin", ...}, + {"id": "user1", "username": "alice", "roles__id": "r2", "roles__kind": "coach", ...}, + {"id": "user2", "username": "bob", "roles__id": "r3", "roles__kind": "learner", ...}, + ] - class MyViewset(ValuesViewset): - # ... + # After auto-consolidation (grouped by primary key): + [ + {"id": "user1", "username": "alice", "roles": [ + {"id": "r1", "kind": "admin", ...}, + {"id": "r2", "kind": "coach", ...}, + ]}, + {"id": "user2", "username": "bob", "roles": [ + {"id": "r3", "kind": "learner", ...}, + ]}, + ] - def annotate_queryset(self, queryset): - """Add aggregated learner IDs""" - return annotate_array_aggregate( - queryset, - learner_ids="lesson_assignments__collection__membership__user_id" - ) +Auto-consolidation handles: + +- Grouping rows by parent primary key +- Deduplicating nested items (e.g., from annotation JOINs) +- NULL handling for LEFT JOINs (null FK → ``None`` for single nested, empty list for ``many=True``) +- Preserving original queryset ordering + +**Constraints:** + +- Only one ``many=True`` nested serializer may be joined per viewset (multiple would create a cartesian product). Additional ``many=True`` fields must be deferred. +- Deep nesting (nested serializers within nested serializers) is not supported for joined fields. Use ``deferred_fields`` and a custom ``consolidate()`` method instead. -``consolidate(items, queryset)`` (optional) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +These constraints are checked at viewset instantiation time when ``DEBUG=True``. -Method to post-process the serialized items. Useful for adding related data that would be inefficient to fetch via ``values()`` (e.g., data that would cause complex subqueries worse for performance than a separate query). +Deferred — Fetched Separately in consolidate() +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -**Important:** Use ``__in`` lookups on the IDs of the fetched items, not on the original queryset, for efficient batch fetching: +For nested data that should be fetched with separate queries (for performance reasons, to avoid cartesian products, or when the relation is complex), list the field in ``deferred_fields`` and use ``serialize_queryset()`` in ``consolidate()``: .. code-block:: python - def consolidate(self, items, queryset): - """Add related assignment data""" - if items: - # Extract IDs from the already-fetched items - lesson_ids = [item["id"] for item in items] + class FileSerializer(serializers.ModelSerializer): + class Meta: + model = File + fields = ("id", "filename", "file_size") - # Fetch related data in a separate efficient query using __in - assignments = Assignment.objects.filter( - lesson_id__in=lesson_ids - ).select_related('collection') + class ContentNodeSerializer(serializers.ModelSerializer): + files = FileSerializer(many=True, read_only=True) + tags = TagSerializer(many=True, read_only=True) - assignments_by_lesson = { - a.lesson_id: a for a in assignments - } + class Meta: + model = ContentNode + fields = ("id", "title", "kind", "files", "tags") + + class ContentNodeViewSet(ReadOnlyValuesViewset): + serializer_class = ContentNodeSerializer + queryset = ContentNode.objects.all() + deferred_fields = ("files", "tags") + + def consolidate(self, items, queryset): + if not items: + return items + + node_ids = [item["id"] for item in items] + + files_map = self.serialize_queryset( + File.objects.filter(contentnode_id__in=node_ids), + "files", + group_by="contentnode_id", + ) + + tags_map = self.serialize_queryset( + ContentTag.objects.filter(tagged_content_id__in=node_ids), + "tags", + group_by="tagged_content_id", + ) for item in items: - item["assignment"] = assignments_by_lesson.get(item["id"]) - item["resources"] = item.get("resources") or [] + item["files"] = files_map.get(item["id"], []) + item["tags"] = tags_map.get(item["id"], []) - return items + return items + +The ``serialize_queryset()`` method applies the values pattern using the nested serializer's field definitions. It accepts a ``group_by`` parameter to return a dict mapping group keys to item lists, which is convenient for mapping back to parent items. + +Dev-Mode Validation +~~~~~~~~~~~~~~~~~~~~ + +When ``DEBUG=True``, ``serialize()`` validates that the output matches the serializer contract after ``consolidate()`` runs. This catches: + +- Missing fields (field in serializer but not in output) +- Extra fields (field in output but not in serializer) +- Nested field mismatches + +This validation only runs in development and has no production overhead. If your ``consolidate()`` modifies the output shape, the serializer must declare all output fields. Complete Example ~~~~~~~~~~~~~~~~ .. code-block:: python + from rest_framework import serializers from django_filters.rest_framework import DjangoFilterBackend from kolibri.core.api import ValuesViewset from kolibri.core.auth.api import KolibriAuthPermissions @@ -165,48 +276,37 @@ Complete Example from kolibri.core.auth.constants.collection_kinds import ADHOCLEARNERSGROUP from kolibri.core.query import annotate_array_aggregate from .models import Lesson, LessonAssignment - from .serializers import LessonSerializer - class LessonPermissions(KolibriAuthPermissions): - # Defined in the same file as the viewset (not a separate permissions module) - pass + class ClassroomSerializer(serializers.ModelSerializer): + class Meta: + model = Classroom + fields = ("id", "name", "parent_id") - def _map_lesson_classroom(item): - """Transform flat classroom fields to nested object""" - return { - "id": item.pop("collection__id"), - "name": item.pop("collection__name"), - "parent": item.pop("collection__parent_id"), - } + class LessonSerializer(serializers.ModelSerializer): + active = serializers.BooleanField(source="is_active") + classroom = ClassroomSerializer(source="collection", read_only=True) + learner_ids = serializers.ListField(read_only=True) + lesson_assignment_collections = serializers.ListField(read_only=True) + + class Meta: + model = Lesson + fields = ( + "id", "title", "description", "resources", + "active", "classroom", + "created_by", "date_created", + "learner_ids", "lesson_assignment_collections", + ) class LessonViewset(ValuesViewset): serializer_class = LessonSerializer queryset = Lesson.objects.all().order_by("-date_created") - permission_classes = (LessonPermissions,) + permission_classes = (KolibriAuthPermissions,) filter_backends = (KolibriAuthPermissionsFilter, DjangoFilterBackend) filterset_fields = ("collection", "id") - - values = ( - "id", - "title", - "description", - "resources", - "is_active", - "collection", # Classroom FK (as ID, used for filtering) - "collection__id", # Classroom ID (used by _map_lesson_classroom) - "collection__name", # Classroom name (used by _map_lesson_classroom) - "collection__parent_id",# Facility ID (used by _map_lesson_classroom) - "created_by", - "date_created", - ) - - field_map = { - "active": "is_active", # Rename field - "classroom": _map_lesson_classroom, # Transform to nested object - } + deferred_fields = ("classroom",) def annotate_queryset(self, queryset): """Add aggregated assignment collections""" @@ -216,29 +316,35 @@ Complete Example ) def consolidate(self, items, queryset): - """Add learner IDs for ad-hoc assignments""" - if items: - # Extract IDs from fetched items for efficient batch query - lesson_ids = [item["id"] for item in items] - - adhoc_assignments = LessonAssignment.objects.filter( - lesson_id__in=lesson_ids, - collection__kind=ADHOCLEARNERSGROUP - ) - adhoc_assignments = annotate_array_aggregate( - adhoc_assignments, - learner_ids="collection__membership__user_id" - ) - adhoc_map = { - a["lesson"]: a - for a in adhoc_assignments.values("lesson", "learner_ids") - } - - for item in items: - if item["id"] in adhoc_map: - item["learner_ids"] = adhoc_map[item["id"]]["learner_ids"] - else: - item["learner_ids"] = [] + """Add classroom data and learner IDs for ad-hoc assignments""" + if not items: + return items + + lesson_ids = [item["id"] for item in items] + + # Use serialize_queryset for deferred nested data + classroom_map = self.serialize_queryset( + Classroom.objects.filter(lesson__id__in=lesson_ids), + "classroom", + group_by="id", + ) + + adhoc_assignments = LessonAssignment.objects.filter( + lesson_id__in=lesson_ids, + collection__kind=ADHOCLEARNERSGROUP, + ) + adhoc_assignments = annotate_array_aggregate( + adhoc_assignments, + learner_ids="collection__membership__user_id", + ) + adhoc_map = { + a["lesson"]: a + for a in adhoc_assignments.values("lesson", "learner_ids") + } + + for item in items: + item["classroom"] = classroom_map.get(item["collection"], [None])[0] + item["learner_ids"] = adhoc_map.get(item["id"], {}).get("learner_ids", []) return items @@ -289,34 +395,71 @@ Full CRUD operations (Create, Retrieve, Update, Delete, List): Best Practices ~~~~~~~~~~~~~~ -1. **Only fetch needed fields**: Keep the ``values`` tuple minimal. Don't fetch fields you won't use. +1. **Serializer as source of truth**: Define the API shape in the serializer. Don't duplicate field definitions between serializer and viewset. -2. **Use field_map for clarity**: Rename fields in ``field_map`` rather than in templates/frontend to keep API consistent. +2. **Use source for renames**: Use ``source`` on serializer fields rather than ``field_map`` for renaming. -3. **Batch related queries in consolidate**: If you need related data, fetch it efficiently in ``consolidate`` using ``__in`` lookups on the IDs from already-fetched items. +3. **Defer wisely**: Use ``deferred_fields`` for ``many=True`` relations that would create large cartesian products, or for relations that require complex queries. Keep simple FK lookups as joined. -4. **Use annotate_queryset for aggregations**: Add computed fields via ``annotate_queryset`` rather than post-processing. +4. **Batch related queries in consolidate**: Fetch deferred data efficiently using ``serialize_queryset()`` with ``group_by`` and ``__in`` lookups on IDs from already-fetched items. -5. **Keep transformations simple**: Complex transformations in ``field_map`` callables can negate performance benefits. +5. **Use annotate_queryset for aggregations**: Add computed fields via ``annotate_queryset`` rather than post-processing. -6. **Test query performance**: Use Django Silk to profile your endpoints and verify query counts, execution time, and identify N+1 query issues. This is essential for ensuring your ValuesViewset implementation is actually performant. +6. **Test query performance**: Use Django Silk to profile your endpoints and verify query counts, execution time, and identify N+1 query issues. Common Pitfalls ~~~~~~~~~~~~~~~ -**Forgetting to include FK fields in values** +**Multiple many=True nested serializers without deferring** .. code-block:: python - # Wrong: field_map references collection__name but it's not in values - values = ("id", "title") - field_map = {"classroom": lambda x: x.pop("collection__name")} # KeyError! + # Wrong: cartesian product — two many=True JOINs multiply rows + class UserSerializer(serializers.ModelSerializer): + roles = RoleSerializer(many=True) + groups = GroupSerializer(many=True) + + class Meta: + model = FacilityUser + fields = ("id", "roles", "groups") - # Correct: include all referenced fields - values = ("id", "title", "collection__name") - field_map = {"classroom": lambda x: x.pop("collection__name")} + class UserViewSet(ReadOnlyValuesViewset): + serializer_class = UserSerializer # Raises TypeError in DEBUG -**Modifying items without returning them in consolidate** + # Correct: defer one of them + class UserViewSet(ReadOnlyValuesViewset): + serializer_class = UserSerializer + deferred_fields = ("groups",) + + def consolidate(self, items, queryset): + # Fetch groups separately + ... + +**Deep nesting without deferring** + +.. code-block:: python + + # Wrong: nested serializer within nested serializer + class GrandchildSerializer(serializers.ModelSerializer): + class Meta: + fields = ("id", "name") + + class ChildSerializer(serializers.ModelSerializer): + grandchildren = GrandchildSerializer(many=True) + class Meta: + fields = ("id", "grandchildren") + + class ParentSerializer(serializers.ModelSerializer): + children = ChildSerializer(many=True) + class Meta: + fields = ("id", "children") + + # Correct: defer deeply nested fields + class ParentViewSet(ReadOnlyValuesViewset): + serializer_class = ParentSerializer + deferred_fields = ("children",) # Fetch children (and grandchildren) in consolidate + +**Forgetting to return items from consolidate** .. code-block:: python @@ -332,18 +475,109 @@ Common Pitfalls item["foo"] = "bar" return items -**Using pop() in field_map callables without checking existence** +Migrating from Explicit Values +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Existing viewsets that use explicit ``values`` tuples and ``field_map`` dicts continue to work. To migrate to the serializer-derived pattern: + +1. **Ensure API tests exist** for the viewset. Write them if missing — they must pass before and after migration. + +2. **Capture a performance baseline** before making any changes. The benchmark script measures serialization timing, memory usage, and query count: + + .. code-block:: bash + + python integration_testing/scripts/viewset_serialization_benchmark.py \ + kolibri.core.auth.api.FacilityUserViewSet \ + -o baseline.json + + This saves timing, memory, query count, and a data hash to ``baseline.json``. + +3. **Update the serializer** to declare all read fields with correct ``source`` attributes: + + .. code-block:: python + + # Before: separate values/field_map + class MyViewSet(ValuesViewset): + serializer_class = MySerializer # may be write-only + values = ("id", "full_name", "devicepermissions__is_superuser") + field_map = { + "is_superuser": lambda x: bool(x.pop("devicepermissions__is_superuser")), + } + + # After: serializer defines everything + class MySerializer(serializers.ModelSerializer): + is_superuser = serializers.BooleanField( + source="devicepermissions.is_superuser", + read_only=True, + ) + + class Meta: + model = FacilityUser + fields = ("id", "full_name", "is_superuser") + + class MyViewSet(ValuesViewset): + serializer_class = MySerializer + # No values or field_map needed + +4. **Convert ``field_map`` callables** to one of: + + - A serializer field with ``source`` (for simple renames) + - A custom field class with ``to_representation()`` (for transforms repeated across serializers) + - A ``ValuesMethodField(sources=(...))`` (for one-off computation from one or more columns) + - Deferred field handling in ``consolidate()`` (for complex restructuring) + +5. **Convert manual consolidation** of nested data: + + - If the viewset manually does ``groupby`` to build nested lists, define a nested serializer with ``many=True`` and let auto-consolidation handle it + - If the nested data is fetched separately, add it to ``deferred_fields`` and use ``serialize_queryset()`` + +6. **Remove** the explicit ``values`` tuple and ``field_map`` dict. + +7. **Run tests** and verify output is identical. + +8. **Compare performance against the baseline**: + + .. code-block:: bash + + python integration_testing/scripts/viewset_serialization_benchmark.py \ + kolibri.core.auth.api.FacilityUserViewSet \ + --compare baseline.json + + The script compares timing and memory against the baseline and flags regressions that exceed configurable thresholds (default: 5% timing, 10% memory). It also compares data hashes to confirm output equivalence. + + If a regression is detected, investigate before proceeding — the serializer-derived path should be at least as fast as the explicit pattern. Common causes include unnecessary ``to_representation`` calls on fields that could use inferred types, or missing ``select_related``/``prefetch_related`` on the queryset. + +Explicit Values (Legacy) +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. note:: + + The explicit ``values``/``field_map`` pattern described below is being replaced by the serializer-derived pattern above. Existing viewsets using this pattern continue to work, but new viewsets should use serializer derivation. + +A ``ValuesViewset`` can define an explicit ``values`` tuple and ``field_map`` dict: .. code-block:: python - # Wrong: KeyError if field doesn't exist - field_map = {"classroom": lambda x: x.pop("collection__name")} + class LessonViewset(ValuesViewset): + queryset = Lesson.objects.all() + values = ("id", "title", "is_active", "collection__name") + field_map = { + "active": "is_active", + "classroom": lambda x: x.pop("collection__name"), + } + +``values`` +^^^^^^^^^^^ + +Tuple of database field names to fetch. Supports foreign key lookups using ``__`` notation. + +``field_map`` +^^^^^^^^^^^^^^ - # Better: check existence - def _map_classroom(item): - return item.pop("collection__name", None) +Dictionary mapping output field names to either: - field_map = {"classroom": _map_classroom} +- **String**: simple rename (``"api_name": "db_field"``) +- **Callable**: transformation function receiving the item dict Related Documentation ~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/dataflow/index.rst b/docs/dataflow/index.rst index e6572bab2f7..1f72235e15d 100644 --- a/docs/dataflow/index.rst +++ b/docs/dataflow/index.rst @@ -11,7 +11,7 @@ In Django, Model data are usually exposed to users through webpages that are gen In the ``api.py`` files, Django REST framework ViewSets are defined which describe how the data is made available through the REST API. Each ViewSet also requires a defined Serializer, which describes the way in which the data from the Django model is serialized into JSON and returned through the REST API. Additionally, optional filters can be applied to the ViewSet which will allow queries to filter by particular features of the data (for example by a field) or by more complex constraints, such as which group the user associated with the data belongs to. Permissions can be applied to a ViewSet, allowing the API to implicitly restrict the data that is returned, based on the currently logged in user. -The default DRF use of Serializers for serialization to JSON tends to encourage the adoption of non-performant patterns of code, particularly ones that use DRF Serializer Method Fields, which then do further queries on a per model basis inside the method. This can easily result in the N + 1 query problem, whereby the number of queries required scales with the number of entities requested in the query. To make this and other performance issues less of a concern, we have created a special ValuesViewset class defined at :code:`kolibri/core/api.py`, which relies on queryset annotation and post query processing in order to serialize all the relevant data. In addition, to prevent the inflation of full Django models into memory, all queries are done with a `values` call resulting in lower memory overhead. +The default DRF use of Serializers for serialization to JSON tends to encourage the adoption of non-performant patterns of code, particularly ones that use DRF Serializer Method Fields, which then do further queries on a per model basis inside the method. This can easily result in the N + 1 query problem, whereby the number of queries required scales with the number of entities requested in the query. To make this and other performance issues less of a concern, we have created a special ValuesViewset class defined at :code:`kolibri/core/api.py`, which relies on queryset annotation and post query processing in order to serialize all the relevant data. In addition, to prevent the inflation of full Django models into memory, all queries are done with a `values` call resulting in lower memory overhead. The ValuesViewset derives its query configuration automatically from standard DRF serializer field definitions — see :doc:`/backend_architecture/api_patterns` for full details. Finally, in the ``api_urls.py`` file, the ViewSets are given a name (through the :code:`basename` keyword argument), which sets a particular URL namespace, which is then registered and exposed when the Django server runs. Sometimes, a more complex URL scheme is used, as in the content core app, where every query is required to be prefixed by a channel id (hence the :code:`` placeholder in that route's regex pattern) diff --git a/integration_testing/scripts/viewset_serialization_benchmark.py b/integration_testing/scripts/viewset_serialization_benchmark.py new file mode 100644 index 00000000000..1c2948b32a9 --- /dev/null +++ b/integration_testing/scripts/viewset_serialization_benchmark.py @@ -0,0 +1,860 @@ +#!/usr/bin/env python +""" +Benchmark any BaseValuesViewset serialization performance. + +Benchmarks the core serialization path of a given viewset, outputs results as +JSON, and optionally compares against a previous baseline to detect regressions. + +Usage: + python integration_testing/scripts/viewset_serialization_benchmark.py VIEWSET_PATH [options] + +Examples: + # Baseline run (uses existing data from KOLIBRI_HOME) + python .../viewset_serialization_benchmark.py kolibri.core.auth.api.FacilityUserViewSet \\ + --inherit-kolibri-home -o baseline.json + + # Comparison run + python .../viewset_serialization_benchmark.py kolibri.core.auth.api.FacilityUserViewSet \\ + --inherit-kolibri-home --compare baseline.json +""" +import argparse +import gc +import hashlib +import importlib +import json +import logging +import math +import os +import platform +import statistics +import sys +import time +import tracemalloc +from datetime import datetime +from unittest.mock import MagicMock + +from django.conf import settings +from django.db import connection +from rest_framework import serializers as drf_serializers +from rest_framework.request import Request + +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark a BaseValuesViewset's serialization performance." + ) + parser.add_argument( + "viewset", + nargs="?", + default=None, + help="Dotted import path (e.g. kolibri.core.auth.api.FacilityUserViewSet)", + ) + parser.add_argument( + "--synthetic", + action="store_true", + help="Run with a synthetic viewset and mock data (no DB needed). " + "Autoscales at sizes 10, 20, 50, 100.", + ) + parser.add_argument( + "-o", + "--output", + default=None, + help="JSON report output path (default: _benchmark.json)", + ) + parser.add_argument( + "--compare", + default=None, + metavar="PATH", + help="Compare current run against a baseline JSON report", + ) + parser.add_argument( + "--iterations", + type=int, + default=10000, + help="Timing iterations (default: 10000)", + ) + parser.add_argument( + "--memory-iterations", + type=int, + default=100, + help="Memory measurement iterations (default: 100)", + ) + parser.add_argument( + "--warmup", + type=int, + default=5, + help="Warmup iterations (default: 5)", + ) + parser.add_argument( + "--time-threshold", + type=float, + default=5.0, + help="Acceptable time regression %% (default: 5.0)", + ) + parser.add_argument( + "--memory-threshold", + type=float, + default=10.0, + help="Acceptable memory regression %% (default: 10.0)", + ) + parser.add_argument( + "--inherit-kolibri-home", + action="store_true", + help="Use KOLIBRI_HOME from environment instead of /tmp/kolibri_benchmark", + ) + parser.add_argument( + "--quiet", + action="store_true", + help="Suppress stdout, only write JSON", + ) + return parser.parse_args() + + +def setup_kolibri(inherit_kolibri_home=False): + if not inherit_kolibri_home: + os.environ.setdefault("KOLIBRI_HOME", "/tmp/kolibri_benchmark") + + from kolibri.utils.main import initialize + + initialize() + + +def import_viewset_class(dotted_path): + from kolibri.core.api import BaseValuesViewset + + module_path, _, class_name = dotted_path.rpartition(".") + if not module_path: + logger.error( + "Invalid viewset path '%s'. Expected format: module.ClassName", + dotted_path, + ) + sys.exit(1) + + try: + module = importlib.import_module(module_path) + except ImportError as e: + logger.error("Could not import module '%s': %s", module_path, e) + sys.exit(1) + + cls = getattr(module, class_name, None) + if cls is None: + logger.error("Module '%s' has no attribute '%s'", module_path, class_name) + sys.exit(1) + + if not (isinstance(cls, type) and issubclass(cls, BaseValuesViewset)): + logger.error("'%s' is not a subclass of BaseValuesViewset", dotted_path) + sys.exit(1) + + return cls + + +def get_queryset_for_viewset(viewset_class): + queryset = getattr(viewset_class, "queryset", None) + if queryset is not None: + return queryset.all() + + try: + viewset = viewset_class() + return viewset.get_queryset() + except Exception as e: + logger.error("Could not obtain queryset for %s: %s", viewset_class.__name__, e) + sys.exit(1) + + +def _build_synthetic_viewset(): + """Build a viewset class with a serializer exercising all serialization paths.""" + + from kolibri.core.api import BaseValuesViewset + from kolibri.core.api import ListModelMixin + from kolibri.core.api import ValuesMethodField + + class TagSerializer(drf_serializers.Serializer): + id = drf_serializers.CharField() + label = drf_serializers.CharField() + + class DepartmentSerializer(drf_serializers.Serializer): + id = drf_serializers.CharField() + name = drf_serializers.CharField() + + class SyntheticSerializer(drf_serializers.Serializer): + id = drf_serializers.CharField() + # Flat field with rename (exercises simple_renames path) + display_name = drf_serializers.CharField(source="full_name") + email = drf_serializers.CharField() + score = drf_serializers.IntegerField() + # many=True nested (exercises _auto_consolidate groupby) + tags = TagSerializer(many=True, source="tag_set") + # Single nested FK (exercises _joined_single dict consolidation) + department = DepartmentSerializer(source="dept") + # Method field over multiple sources (exercises _SourcesProxy + + # invoker callable in field_map). + contact_label = ValuesMethodField(sources=("full_name", "email")) + + def get_contact_label(self, row): + return "{} <{}>".format(row.full_name, row.email) + + class SyntheticViewset(BaseValuesViewset, ListModelMixin): + serializer_class = SyntheticSerializer + + return SyntheticViewset + + +SYNTHETIC_SIZES = (10, 20, 50, 100) + + +def _generate_synthetic_data(n): + """ + Generate n parent records as flat dicts simulating QuerySet.values() output. + + Each parent has 2 tags (many=True join) and 1 department (FK join). + The flat output has n*2 rows because of the tag join expansion. + """ + rows = [] + for i in range(n): + for t in range(2): + rows.append( + { + "id": f"user-{i:04d}", + "full_name": f"User {i}", + "email": f"user{i}@example.com", + "score": 100 + i, + "tag_set__id": f"tag-{i}-{t}", + "tag_set__label": f"label-{i}-{t}", + "dept__id": f"dept-{i % 5:04d}", + "dept__name": f"Department {i % 5}", + } + ) + return rows + + +def _make_synthetic_queryset(flat_items): + """Wrap flat dict list in a mock queryset compatible with serialize().""" + + class StubMeta: + class pk: + name = "id" + + class StubModel: + _meta = StubMeta() + + mock_qs = MagicMock() + mock_qs.model = StubModel + mock_qs.values.side_effect = lambda *a, **kw: [item.copy() for item in flat_items] + mock_qs.count.return_value = len(set(row["id"] for row in flat_items)) + return mock_qs + + +def _make_viewset(viewset_class, queryset): + """Create a viewset instance with a DRF Request for standalone use.""" + from rest_framework.test import APIRequestFactory + + factory = APIRequestFactory() + django_request = factory.get("/") + drf_request = Request(django_request) + + viewset = viewset_class() + viewset.queryset = queryset + viewset.request = drf_request + viewset.kwargs = {} + viewset.format_kwarg = None + return viewset + + +def calculate_confidence_interval(data): + """ + Calculate a 95% confidence interval for the mean using the t-distribution. + + Uses a hardcoded table of critical t-values for common small sample sizes + (rather than adding scipy as a dependency) and falls back to the + large-sample normal approximation (z=1.96) for n not in the table. + + Returns (lower_bound, upper_bound, margin_of_error). + """ + n = len(data) + if n < 2: + mean = data[0] if data else 0 + return mean, mean, 0 + + mean = statistics.mean(data) + std_err = statistics.stdev(data) / math.sqrt(n) + + # Critical t-values at 95% CI, indexed by sample size (degrees of freedom n-1). + t_values_95 = { + 2: 12.71, + 3: 4.30, + 4: 3.18, + 5: 2.78, + 6: 2.57, + 7: 2.45, + 8: 2.36, + 9: 2.31, + 10: 2.26, + 11: 2.23, + 12: 2.20, + 15: 2.14, + 20: 2.09, + } + t_val = t_values_95.get(n, 1.96) + + margin = t_val * std_err + return mean - margin, mean + margin, margin + + +def benchmark_timing(viewset_class, queryset, iterations, warmup): + """ + Benchmark serialize() + JSON encoding. + + Returns dict with timing stats and json_size_bytes. + """ + from rest_framework.renderers import JSONRenderer + + viewset = _make_viewset(viewset_class, queryset) + renderer = JSONRenderer() + + for _ in range(warmup): + result = viewset.serialize(queryset) + renderer.render(result) + + gc.collect() + gc.disable() + times = [] + json_output = None + try: + for _ in range(iterations): + start = time.perf_counter() + result = viewset.serialize(queryset) + json_output = renderer.render(result) + end = time.perf_counter() + times.append(end - start) + finally: + gc.enable() + + ci_lower, ci_upper, ci_margin = calculate_confidence_interval(times) + + return { + "mean": statistics.mean(times), + "min": min(times), + "max": max(times), + "std": statistics.stdev(times) if len(times) > 1 else 0, + "ci_lower": ci_lower, + "ci_upper": ci_upper, + "ci_margin": ci_margin, + "json_size_bytes": len(json_output) if json_output else 0, + } + + +def benchmark_memory(viewset_class, queryset, iterations, warmup): + """ + Benchmark memory usage of serialize(). + + Returns dict with mean_bytes, peak_bytes, std_bytes. + """ + viewset = _make_viewset(viewset_class, queryset) + + for _ in range(warmup): + viewset.serialize(queryset) + gc.collect() + + peak_samples = [] + for _ in range(iterations): + gc.collect() + tracemalloc.start() + + result = viewset.serialize(queryset) + + _, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + peak_samples.append(peak) + + del result + + gc.collect() + + return { + "mean_bytes": statistics.mean(peak_samples), + "peak_bytes": max(peak_samples), + "std_bytes": statistics.stdev(peak_samples) if len(peak_samples) > 1 else 0, + } + + +def count_queries(viewset_class, queryset): + """Count the number of database queries for one serialize() call.""" + viewset = _make_viewset(viewset_class, queryset) + + old_debug = settings.DEBUG + settings.DEBUG = True + + try: + connection.queries_log.clear() + viewset.serialize(queryset) + query_count = len(connection.queries) + finally: + settings.DEBUG = old_debug + + return query_count + + +def capture_data_snapshot(viewset_class, queryset): + """ + Serialize once, compute SHA-256 hash of normalized output, extract sample. + + Returns {"output_hash": "sha256:...", "sample": [...]} + """ + viewset = _make_viewset(viewset_class, queryset) + result = viewset.serialize(queryset) + + result_json = json.dumps(result, default=str) + hash_hex = hashlib.sha256(result_json.encode("utf-8")).hexdigest() + + sample = result[:5] if isinstance(result, list) else [] + + return { + "output_hash": f"sha256:{hash_hex}", + "sample": sample, + } + + +def build_report( + viewset_class, + dotted_path, + record_count, + iterations, + memory_iterations, + warmup, + timing, + memory, + queries, + data_snapshot, + time_threshold, + memory_threshold, +): + has_explicit_values = "values" in viewset_class.__dict__ and isinstance( + viewset_class.__dict__["values"], tuple + ) + has_derived = viewset_class._cached_serializer is not None + + return { + "schema_version": 1, + "metadata": { + "viewset_class": dotted_path, + "has_explicit_values": has_explicit_values, + "has_derived_field_info": has_derived, + "timestamp": datetime.now().isoformat(timespec="seconds"), + "python_version": platform.python_version(), + "record_count": record_count, + "iterations": iterations, + "memory_iterations": memory_iterations, + "warmup_iterations": warmup, + }, + "timing": { + "mean_ms": timing["mean"] * 1000, + "min_ms": timing["min"] * 1000, + "max_ms": timing["max"] * 1000, + "std_ms": timing["std"] * 1000, + "ci_lower_ms": timing["ci_lower"] * 1000, + "ci_upper_ms": timing["ci_upper"] * 1000, + "ci_margin_ms": timing["ci_margin"] * 1000, + "json_size_bytes": timing["json_size_bytes"], + }, + "memory": { + "mean_bytes": memory["mean_bytes"], + "peak_bytes": memory["peak_bytes"], + "std_bytes": memory["std_bytes"], + }, + "queries": { + "count": queries, + }, + "data": data_snapshot, + "thresholds": { + "time_regression_pct": time_threshold, + "memory_regression_pct": memory_threshold, + }, + } + + +def write_report(report, path): + with open(path, "w") as f: + json.dump(report, f, indent=2, default=str) + f.write("\n") + + +def load_report(path): + with open(path) as f: + report = json.load(f) + if report.get("schema_version") != 1: + logger.warning( + "Baseline report has schema_version=%s, expected 1", + report.get("schema_version"), + ) + return report + + +def compare_reports(baseline, current, time_threshold, memory_threshold): + """Compare two reports and return a verdict dict.""" + b_time = baseline["timing"]["mean_ms"] + c_time = current["timing"]["mean_ms"] + time_diff_pct = ((c_time - b_time) / b_time * 100) if b_time > 0 else 0 + + # Sub-2ms timings are dominated by system noise — only check the + # percentage threshold when both baseline and current exceed 2ms. + if b_time < 2.0 and c_time < 2.0: + time_pass = True + else: + time_pass = time_diff_pct <= time_threshold + + b_mem = baseline["memory"]["mean_bytes"] + c_mem = current["memory"]["mean_bytes"] + mem_diff_pct = ((c_mem - b_mem) / b_mem * 100) if b_mem > 0 else 0 + mem_pass = mem_diff_pct <= memory_threshold + + b_queries = baseline["queries"]["count"] + c_queries = current["queries"]["count"] + if b_queries is not None and c_queries is not None: + queries_pass = c_queries <= b_queries + query_diff = c_queries - b_queries + else: + queries_pass = True + query_diff = 0 + + b_hash = baseline["data"]["output_hash"] + c_hash = current["data"]["output_hash"] + data_match = b_hash == c_hash + + overall_pass = time_pass and mem_pass and queries_pass and data_match + + time_below_floor = b_time < 2.0 and c_time < 2.0 + + return { + "time_diff_pct": time_diff_pct, + "time_below_floor": time_below_floor, + "time_pass": time_pass, + "mem_diff_pct": mem_diff_pct, + "mem_pass": mem_pass, + "query_diff": query_diff, + "queries_pass": queries_pass, + "data_match": data_match, + "overall_pass": overall_pass, + } + + +def _fmt_bytes(b): + """Format bytes as human-readable KB.""" + return f"{b / 1024:.1f} KB" + + +def _pattern_label(metadata): + if metadata.get("has_derived_field_info"): + return "derived" + elif metadata.get("has_explicit_values"): + return "explicit" + return "unknown" + + +def print_comparison(baseline, current, verdict): + """Print human-readable comparison table.""" + b = baseline + c = current + + def row(label, b_val, c_val, diff, verdict_str, detail=""): + line = f" {label:<18} {b_val:>14} {c_val:>14} {diff:>10} {verdict_str:>6}" + if detail: + line += f" {detail}" + logger.info(line) + + logger.info("\n[Comparison: current vs baseline]") + logger.info("-" * 70) + row("Metric", "Baseline", "Current", "Diff", "Verdict") + logger.info(" %s", "-" * 66) + + # Time + if verdict["time_below_floor"]: + time_v, time_detail = "SKIP", "(< 2ms)" + elif not verdict["time_pass"]: + time_v = "FAIL" + time_detail = f"(> {c['thresholds']['time_regression_pct']}%)" + else: + time_v, time_detail = "PASS", "" + row( + "Time (mean)", + f"{b['timing']['mean_ms']:.3f} ms", + f"{c['timing']['mean_ms']:.3f} ms", + f"{verdict['time_diff_pct']:+.1f}%", + time_v, + time_detail, + ) + + # Memory + mem_pass = verdict["mem_pass"] + row( + "Memory (mean)", + _fmt_bytes(b["memory"]["mean_bytes"]), + _fmt_bytes(c["memory"]["mean_bytes"]), + f"{verdict['mem_diff_pct']:+.1f}%", + "PASS" if mem_pass else "FAIL", + "" if mem_pass else f"(> {c['thresholds']['memory_regression_pct']}%)", + ) + + # Queries + b_q = b["queries"]["count"] + c_q = c["queries"]["count"] + if b_q is not None and c_q is not None: + row( + "DB Queries", + str(b_q), + str(c_q), + f"{verdict['query_diff']:+d}", + "PASS" if verdict["queries_pass"] else "FAIL", + ) + else: + row("DB Queries", "N/A", "N/A", "", "SKIP") + + # Data output + data_match = verdict["data_match"] + row( + "Data output", + b["data"]["output_hash"][:20] + "...", + c["data"]["output_hash"][:20] + "...", + "match" if data_match else "differ", + "PASS" if data_match else "FAIL", + ) + + # Info rows + b_size = b["timing"]["json_size_bytes"] + c_size = c["timing"]["json_size_bytes"] + row("JSON size", f"{b_size} B", f"{c_size} B", f"{c_size - b_size:+d} B", "INFO") + + b_rec = b["metadata"]["record_count"] + c_rec = c["metadata"]["record_count"] + row("Records", str(b_rec), str(c_rec), f"{c_rec - b_rec:+d}", "INFO") + + row( + "Pattern", + _pattern_label(b["metadata"]), + _pattern_label(c["metadata"]), + "", + "INFO", + ) + + logger.info(" %s", "-" * 66) + + if not verdict["data_match"]: + logger.info( + " NOTE: Data hashes differ. This may be expected if data changed between runs." + ) + b_sample = b["data"].get("sample", []) + c_sample = c["data"].get("sample", []) + if b_sample and c_sample: + for i, (bs, cs) in enumerate(zip(b_sample, c_sample)): + if bs != cs: + logger.info(" First sample difference at index %d:", i) + logger.info(" Baseline: %s", bs) + logger.info(" Current: %s", cs) + break + + overall = "PASS" if verdict["overall_pass"] else "FAIL" + logger.info("OVERALL VERDICT: %s", overall) + + +def _run_synthetic(args): + """Run benchmark with synthetic viewset at multiple data sizes.""" + setup_kolibri(inherit_kolibri_home=args.inherit_kolibri_home) + + viewset_class = _build_synthetic_viewset() + sizes_report = {} + + for size in SYNTHETIC_SIZES: + if not args.quiet: + logger.info("\n--- Size: %d ---", size) + flat_items = _generate_synthetic_data(size) + mock_qs = _make_synthetic_queryset(flat_items) + + if not args.quiet: + logger.info("Running timing benchmark...") + timing = benchmark_timing(viewset_class, mock_qs, args.iterations, args.warmup) + + if not args.quiet: + logger.info("Running memory benchmark...") + memory = benchmark_memory( + viewset_class, mock_qs, args.memory_iterations, args.warmup + ) + + if not args.quiet: + logger.info("Capturing data snapshot...") + data_snapshot = capture_data_snapshot(viewset_class, mock_qs) + + sizes_report[str(size)] = build_report( + viewset_class=viewset_class, + dotted_path="", + record_count=size, + iterations=args.iterations, + memory_iterations=args.memory_iterations, + warmup=args.warmup, + timing=timing, + memory=memory, + queries=None, + data_snapshot=data_snapshot, + time_threshold=args.time_threshold, + memory_threshold=args.memory_threshold, + ) + + if not args.quiet: + r = sizes_report[str(size)] + logger.info(" Time: %.3f ms (mean)", r["timing"]["mean_ms"]) + logger.info(" Memory: %s (mean)", _fmt_bytes(r["memory"]["mean_bytes"])) + logger.info(" Data hash: %s...", r["data"]["output_hash"][:30]) + + report = { + "schema_version": 1, + "synthetic": True, + "sizes": sizes_report, + } + + output_path = args.output or "synthetic_benchmark.json" + write_report(report, output_path) + + if not args.quiet: + logger.info("\nReport written to: %s", output_path) + + if args.compare: + baseline = load_report(args.compare) + return _compare_synthetic(baseline, report, args) + + return 0 + + +def _compare_synthetic(baseline, current, args): + """Compare two synthetic reports size-by-size.""" + overall_pass = True + + for size in SYNTHETIC_SIZES: + key = str(size) + if key not in baseline.get("sizes", {}): + if not args.quiet: + logger.warning("Size %d not in baseline, skipping", size) + continue + if key not in current.get("sizes", {}): + if not args.quiet: + logger.warning("Size %d not in current, skipping", size) + continue + + b = baseline["sizes"][key] + c = current["sizes"][key] + verdict = compare_reports(b, c, args.time_threshold, args.memory_threshold) + + if not args.quiet: + logger.info("\n--- Size: %d ---", size) + print_comparison(b, c, verdict) + + if not verdict["overall_pass"]: + overall_pass = False + + return 0 if overall_pass else 1 + + +def _run_real_viewset(args): + """Run benchmark against a real viewset with database data.""" + setup_kolibri(inherit_kolibri_home=args.inherit_kolibri_home) + + viewset_class = import_viewset_class(args.viewset) + + queryset = get_queryset_for_viewset(viewset_class) + record_count = queryset.count() + + if record_count == 0: + logger.warning( + "No records found for %s. " + "Use --inherit-kolibri-home with a populated KOLIBRI_HOME.", + viewset_class.__name__, + ) + + has_explicit = "values" in viewset_class.__dict__ and isinstance( + viewset_class.__dict__["values"], tuple + ) + has_derived = viewset_class._cached_serializer is not None + pattern = "derived" if has_derived else ("explicit" if has_explicit else "unknown") + + if not args.quiet: + logger.info("Viewset: %s", args.viewset) + logger.info(" Pattern: %s", pattern) + logger.info(" Records: %d", record_count) + logger.info( + " Iterations: %d (timing), %d (memory)", + args.iterations, + args.memory_iterations, + ) + logger.info(" Warmup: %d", args.warmup) + + # Benchmarks + if not args.quiet: + logger.info("Running timing benchmark...") + timing = benchmark_timing(viewset_class, queryset, args.iterations, args.warmup) + + if not args.quiet: + logger.info("Running memory benchmark...") + memory = benchmark_memory( + viewset_class, queryset, args.memory_iterations, args.warmup + ) + + if not args.quiet: + logger.info("Counting queries...") + queries = count_queries(viewset_class, queryset) + + if not args.quiet: + logger.info("Capturing data snapshot...") + data_snapshot = capture_data_snapshot(viewset_class, queryset) + + report = build_report( + viewset_class=viewset_class, + dotted_path=args.viewset, + record_count=record_count, + iterations=args.iterations, + memory_iterations=args.memory_iterations, + warmup=args.warmup, + timing=timing, + memory=memory, + queries=queries, + data_snapshot=data_snapshot, + time_threshold=args.time_threshold, + memory_threshold=args.memory_threshold, + ) + + output_path = args.output or f"{viewset_class.__name__}_benchmark.json" + write_report(report, output_path) + + if not args.quiet: + logger.info("\nReport written to: %s", output_path) + logger.info(" Time: %.3f ms (mean)", report["timing"]["mean_ms"]) + logger.info(" Memory: %s (mean)", _fmt_bytes(report["memory"]["mean_bytes"])) + logger.info(" Queries: %s", report["queries"]["count"]) + logger.info(" JSON size: %d bytes", report["timing"]["json_size_bytes"]) + logger.info(" Data hash: %s...", report["data"]["output_hash"][:30]) + + if args.compare: + baseline = load_report(args.compare) + verdict = compare_reports( + baseline, report, args.time_threshold, args.memory_threshold + ) + if not args.quiet: + print_comparison(baseline, report, verdict) + return 0 if verdict["overall_pass"] else 1 + + return 0 + + +def main(): + args = parse_args() + if not args.viewset and not args.synthetic: + logger.error("Provide a viewset path or use --synthetic") + return 1 + + if args.synthetic: + return _run_synthetic(args) + + return _run_real_viewset(args) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s") + sys.exit(main()) diff --git a/kolibri/core/api.py b/kolibri/core/api.py index 2e2a3adc57f..fa2b3a12d3d 100644 --- a/kolibri/core/api.py +++ b/kolibri/core/api.py @@ -9,8 +9,18 @@ For more information, see: docs/backend_architecture/index.rst """ +import operator +import threading import uuid - +from collections import defaultdict +from contextlib import contextmanager +from itertools import groupby +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +from django.conf import settings from django.http import Http404 from django.http.request import QueryDict from rest_framework import viewsets @@ -34,6 +44,9 @@ from kolibri.core.discovery.utils.network.client import NetworkClient from kolibri.core.discovery.utils.network.errors import NetworkLocationConnectionFailure from kolibri.core.discovery.utils.network.errors import NetworkLocationResponseFailure +from kolibri.core.utils.serializer_introspection import derive_values_from_serializer +from kolibri.core.utils.serializer_introspection import normalize_field_map +from kolibri.core.utils.serializer_introspection import ValuesMethodField # noqa: F401 from kolibri.utils import conf @@ -88,14 +101,12 @@ def get_default_valid_fields(self, queryset, view, context=None): if context is None: context = {} default_fields = set() - # All the fields that we have field maps defined for - this only allows for simple mapped fields - # where the field is essentially a rename, as we have no good way of doing ordering on a field that - # that is doing more complex function based mapping. - mapped_fields = {v: k for k, v in view.field_map.items() if isinstance(v, str)} + # Invert to source -> target for looking up values by their DB name + mapped_fields = {v: k for k, v in view._field_map.source_map().items()} # All the fields of the model model_fields = {f.name for f in queryset.model._meta.get_fields()} # Loop through every value in the view's values tuple - for field in view.values: + for field in view._values: # If the value is for a foreign key lookup, we split it here to make sure that the first relation key # exists on the model - it's unlikely this would ever not be the case, as otherwise the viewset would # be returning 500s. @@ -121,21 +132,20 @@ def remove_invalid_fields(self, queryset, fields, view, request): Modified from https://github.com/encode/django-rest-framework/blob/version-3.12.2/rest_framework/filters.py#L259 to do filtering based on valuesviewset setup """ - # We filter the mapped fields to ones that do simple string mappings here, any functional maps are excluded. - mapped_fields = {k: v for k, v in view.field_map.items() if isinstance(v, str)} + mapped_fields = view._field_map.source_map() valid_fields = [ item[0] for item in self.get_valid_fields(queryset, view, {"request": request}) ] ordering = [] for term in fields: - if term.lstrip("-") in valid_fields: - if term.lstrip("-") in mapped_fields: + field_name = term.lstrip("-") + if field_name in valid_fields: + if field_name in mapped_fields: # In the case that the ordering field is a mapped field on the values viewset # we substitute the serialized name of the field for the database name. prefix = "-" if term[0] == "-" else "" - new_term = prefix + mapped_fields[term.lstrip("-")] - ordering.append(new_term) + ordering.append(prefix + mapped_fields[field_name]) else: ordering.append(term) if len(ordering) > 1: @@ -143,30 +153,250 @@ def remove_invalid_fields(self, queryset, fields, view, request): return ordering +class _ThreadLocalContext(threading.local): + """ + A dict-like context whose contents are thread-local — writes by one + thread don't leak to others. Used as the ``context`` on a shared cached + serializer so the same instance can safely carry per-request context + (``request``, ``view``, ``format``) on each worker thread without + per-request allocation. + + Inheriting from ``threading.local`` gives every thread its own + ``self.__dict__``; the dict-like protocol proxies straight to it. + + ``BaseValuesViewset.serialize()`` populates this from + ``get_serializer_context()`` before running the pipeline and clears it + on exit. + """ + + def __getitem__(self, key): + return self.__dict__[key] + + def __setitem__(self, key, value): + self.__dict__[key] = value + + def __delitem__(self, key): + del self.__dict__[key] + + def __contains__(self, key): + return key in self.__dict__ + + def __iter__(self): + return iter(self.__dict__) + + def __len__(self): + return len(self.__dict__) + + def __repr__(self): + return repr(self.__dict__) + + def get(self, key, default=None): + return self.__dict__.get(key, default) + + def keys(self): + return self.__dict__.keys() + + def values(self): + return self.__dict__.values() + + def items(self): + return self.__dict__.items() + + def update(self, *args, **kwargs): + self.__dict__.update(*args, **kwargs) + + def setdefault(self, key, default=None): + return self.__dict__.setdefault(key, default) + + def pop(self, key, *args): + return self.__dict__.pop(key, *args) + + def clear(self): + self.__dict__.clear() + + class BaseValuesViewset(viewsets.GenericViewSet): """ A viewset that uses a values call to get all model/queryset data in a single database query, rather than delegating serialization to a DRF ModelSerializer. + + Values can be specified explicitly via the `values` attribute, or derived + automatically from the serializer_class field definitions. + + To use serializer-derived values: + 1. Define serializer_class with proper field source attributes + 2. Do NOT define a `values` attribute (or set it to None) + 3. Optionally set `deferred_fields` for nested serializers to fetch separately """ - # A tuple of values to get from the queryset - # values = None + # A tuple of values to get from the queryset. + # If not defined, values will be derived from serializer_class on first + # instantiation via _ensure_initialized. + # A map of target_key, source_key where target_key is the final target_key that will be set # and source_key is the key on the object retrieved from the values call. # Alternatively, the source_key can be a callable that will be passed the object and return # the value for the target_key. This callable can also pop unwanted values from the obj # to remove unneeded keys from the object as a side effect. - field_map = {} + # For derived pattern, this is built automatically from serializer renames. + + # Tuple of nested serializer field names that should be fetched separately + # rather than joined in the main query. These fields are handled in consolidate(). + deferred_fields = () + + # Cached itemgetter for pk_field, used in _auto_consolidate for fast groupby + _pk_getter = None + # Cached many=True nested info for groupby consolidation + _joined_many = () + # Cached serializer instance used for introspection and for invoking + # ``ValuesMethodField`` bound methods. The instance is shared across + # requests; its ``context`` is a ``_ThreadLocalContext`` so per-request + # values don't leak between threads. + _cached_serializer = None + # Thread-local context object attached to ``_cached_serializer.context``. + # ``serialize()`` populates it from ``get_serializer_context()`` and + # clears it on exit. + _serializer_context = None + # Cached derived info for deferred fields, keyed by serializer_path. + # Defaults to None; set per-class by _ensure_initialized. + _nested_derived_cache = None + # Cached validation schema for DEBUG mode: (expected_fields, nested_schemas) + # Built once during _ensure_initialized to avoid per-request recomputation. + _validation_schema = None + # Whether _ensure_initialized has run for this class + _initialized = False + # Guards _ensure_initialized for this class; each subclass gets its own + # lock via __init_subclass__ so it isn't shared through the MRO. + _initialization_lock = threading.Lock() + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls._initialization_lock = threading.Lock() + + @classmethod + def _get_own(cls, attr, default=None): + """ + Get attr from this class's own dict, ignoring MRO inheritance. + + Prevents dynamically cached class attributes (e.g. serializer_class + set by get_serializer_class()) from leaking to child classes. + """ + return cls.__dict__.get(attr, default) + + @classmethod + def _ensure_initialized(cls): + """ + Run once per concrete subclass on first instantiation. + + Deferred from __init_subclass__ to avoid instantiating serializers + (which may reference querysets) at class definition / import time. + + Double-checked locking: the unlocked fast path avoids lock overhead + on every instantiation once initialized; the re-check inside the + lock ensures the initialization work runs exactly once per class, + including under free-threaded Python (3.13+ no-GIL). + """ + if cls._get_own("_initialized", False): + return + + with cls._initialization_lock: + if cls._get_own("_initialized", False): + return + cls._do_initialize() + + @classmethod + def _do_initialize(cls): + cls._serializer_context = _ThreadLocalContext() + + has_explicit_values = isinstance(getattr(cls, "values", None), tuple) + serializer_class = getattr(cls, "serializer_class", None) + + if has_explicit_values: + cls._values = tuple(cls.values) + if not hasattr(cls, "field_map"): + cls.field_map = {} + # Normalize legacy str/callable entries to canonical entry + # objects (SourceFieldEntry/CallableFieldEntry). Produces a + # fresh _LegacyFieldMap so post-init mutation of cls.field_map + # doesn't leak into instance serialization. + cls._field_map = normalize_field_map(cls.field_map) + elif serializer_class is not None: + cls._cached_serializer = serializer_class(context=cls._serializer_context) + ( + cls._values, + cls._field_map, + cls._joined_many, + cls._nested_derived_cache, + ) = derive_values_from_serializer( + cls._cached_serializer, + deferred_fields=cls.deferred_fields, + check_constraints=settings.DEBUG, + ) + # Auto-derived: keep _values/_field_map only on cls. Writing to + # ``cls.values`` here would expose the tuple to subclasses via MRO + # so ``has_explicit_values`` would mis-detect it as user-supplied, + # routing the child into the explicit-values path and skipping + # serializer derivation against its own ``serializer_class``. + cls._values = tuple(cls._values) + else: + raise TypeError( + "Either 'values' tuple or 'serializer_class' must be defined" + ) + + # Cache pk itemgetter from queryset + queryset = getattr(cls, "queryset", None) + if queryset is not None and hasattr(queryset, "model"): + cls._pk_getter = operator.itemgetter(queryset.model._meta.pk.name) + + # Cache validation schema for DEBUG mode + if settings.DEBUG: + serializer = cls._get_own("_cached_serializer") + if ( + serializer is None + and getattr(cls, "serializer_class", None) is not None + ): + serializer = cls.serializer_class() + if serializer is not None: + cls._validation_schema = cls._build_validation_schema(serializer) + + cls._initialized = True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if not hasattr(self, "values") or not isinstance(self.values, tuple): - raise TypeError("values must be defined as a tuple") - self._values = tuple(self.values) - if not isinstance(self.field_map, dict): - raise TypeError("field_map must be defined as a dict") - self._field_map = self.field_map.copy() + self.__class__._ensure_initialized() + + def get_serializer_context(self): + """ + Return the serializer context for this request. + + Overrides DRF's default to tolerate programmatic invocation outside + the request cycle (tests, inline usage): ``request``, ``view``, and + ``format`` default to ``None`` if the viewset hasn't been dispatched. + """ + return { + "request": getattr(self, "request", None), + "view": self, + "format": getattr(self, "format_kwarg", None), + } + + @contextmanager + def _serializer_context_scope(self): + """ + Populate the thread-local serializer context for the duration of a + serialization pipeline, clearing it on exit. Re-entrant: nested + scopes (e.g. ``serialize_queryset`` invoked from ``consolidate``) + are no-ops, so the outer scope's context survives until its own + exit. + """ + already_set = bool(self._serializer_context) + if not already_set: + self._serializer_context.update(self.get_serializer_context()) + try: + yield + finally: + if not already_set: + self._serializer_context.clear() def generate_serializer(self): queryset = getattr(self, "queryset", None) @@ -178,10 +408,12 @@ def generate_serializer(self): model = getattr(queryset, "model", None) if model is None: return Serializer - mapped_fields = {v: k for k, v in self.field_map.items() if isinstance(v, str)} + # {source: target} for plain renames, so values can be exposed + # under the declared name. + mapped_fields = self._field_map.plain_renames() if self._field_map else {} fields = [] extra_kwargs = {} - for value in self.values: + for value in self._values: try: model._meta.get_field(value) if value in mapped_fields: @@ -212,9 +444,15 @@ def generate_serializer(self): def get_serializer_class(self): if self.serializer_class is not None: return self.serializer_class - # Hack to prevent the renderer logic from breaking completely. - self.__class__.serializer_class = self.generate_serializer() - return self.__class__.serializer_class + # Generate a serializer for DRF schema/renderer compatibility. + # Cached on _generated_serializer_class (not serializer_class) to + # avoid leaking to child classes via MRO. + cls = self.__class__ + generated = cls._get_own("_generated_serializer_class") + if generated is None: + generated = self.generate_serializer() + cls._generated_serializer_class = generated + return generated def _get_lookup_filter(self): lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field @@ -232,25 +470,302 @@ def _get_lookup_filter(self): def annotate_queryset(self, queryset): return queryset - def _map_fields(self, item): - for key, value in self._field_map.items(): - if callable(value): - item[key] = value(item) - elif value in item: - item[key] = item.pop(value) + def get_nested_serializer(self, path: str): + """ + Resolve a dotted path to a nested serializer. + + Args: + path: Dotted path like 'roles' or 'children__grandchildren' + + Returns: + The nested serializer instance + + Raises: + KeyError: If path doesn't resolve to a valid nested serializer + """ + if self._cached_serializer is None: + raise AttributeError( + "get_nested_serializer requires serializer-derived values" + ) + serializer = self._cached_serializer + for part in path.split("__"): + field = serializer.fields[part] + # Handle many=True (ListSerializer wraps the actual serializer) + if hasattr(field, "child"): + serializer = field.child else: - item[key] = value - return item + serializer = field + + return serializer + + def _group_items( + self, items: List[Dict[str, Any]], group_by: str + ) -> Dict[Any, List[Dict[str, Any]]]: + """ + Group items by a field value. + + Args: + items: List of dictionaries + group_by: Field name to group by + + Returns: + Dict mapping group_by values to lists of items + """ + result: Dict[Any, List[Dict[str, Any]]] = defaultdict(list) + for item in items: + result[item.get(group_by)].append(item) + return dict(result) + + @staticmethod + def _serialize_flat(queryset, values, field_map): + """Base serialization: values() call + field mapping.""" + items = queryset.values(*values) + if field_map: + return [field_map.map_row(item) for item in items] + return list(items) + + def serialize_queryset( + self, + queryset, + serializer_path: Optional[str] = None, + *, + group_by: Optional[str] = None, + ): + """ + Serialize any queryset using a serializer's field definitions. + + Args: + queryset: Any Django queryset + serializer_path: Dotted path to nested serializer (e.g., 'roles' + or 'files'); ``None`` selects the viewset's own top-level + serializer. + group_by: Optional field to group results by (returns dict of key -> [items]) + + Returns: + List of serialized items, or Dict {group_key: [items]} if group_by specified + """ + if serializer_path is None: + values = self._values + field_map = self._field_map + joined_many = self._joined_many + pk_getter = self._pk_getter + if pk_getter is None: + # Class-level queryset wasn't resolvable at init time + # (viewset uses get_queryset()); cache it now so subsequent + # calls skip the lookup. + pk_getter = operator.itemgetter(queryset.model._meta.pk.name) + self.__class__._pk_getter = pk_getter + else: + if self._nested_derived_cache is None: + raise AttributeError( + "serialize_queryset requires serializer-derived values" + ) + values, field_map, joined_many = self._nested_derived_cache[serializer_path] + pk_getter = operator.itemgetter(queryset.model._meta.pk.name) + + with self._serializer_context_scope(): + items = self._serialize_flat(queryset, values, field_map) + items = self._auto_consolidate(items, joined_many, pk_getter) + + # Group if requested + if group_by is not None: + return self._group_items(items, group_by) + + return items + + @staticmethod + def _build_validation_schema(serializer): + """ + Build a cached validation schema from a serializer. + + Returns (expected_fields, nested_schemas) where: + - expected_fields: frozenset of field names (excluding write_only) + - nested_schemas: dict mapping field_name to nested schema tuples + """ + expected_fields = set() + nested_schemas = {} + + for field_name, field in serializer.fields.items(): + if getattr(field, "write_only", False): + continue + expected_fields.add(field_name) + if hasattr(field, "child") and isinstance(field.child, Serializer): + nested_schemas[field_name] = BaseValuesViewset._build_validation_schema( + field.child + ) + elif isinstance(field, Serializer): + nested_schemas[field_name] = BaseValuesViewset._build_validation_schema( + field + ) + + return (frozenset(expected_fields), nested_schemas) + + def _validate_output(self, items: List[Dict[str, Any]]) -> None: + """ + Validate serialized output matches serializer contract. + + Only intended for use in DEBUG mode to catch drift between + consolidate() implementation and serializer declarations. + + Uses the cached _validation_schema when available (built during + _ensure_initialized), falling back to building from the serializer. + """ + if not items: + return + + schema = self._validation_schema + if schema is None: + # Fallback for viewsets without a cached schema + if self._cached_serializer is not None: + schema = self._build_validation_schema(self._cached_serializer) + elif self.serializer_class is not None: + schema = self._build_validation_schema(self.serializer_class()) + else: + return + + self._validate_items_against_schema(items, schema) + + @staticmethod + def _validate_items_against_schema( + items: List[Dict[str, Any]], + schema, + ) -> None: + """ + Validate items against a cached validation schema. + + Only checks the first item since all rows from values() have + uniform keys — one item is enough to catch schema drift. + Recurses into nested schemas. + """ + if not items: + return + + expected_fields, nested_schemas = schema + item = items[0] + item_keys = set(item.keys()) + + missing = expected_fields - item_keys + if missing: + raise ValueError( + "Missing fields in output: {}. " + "Expected: {}, Got: {}".format(missing, expected_fields, item_keys) + ) + + extra = item_keys - expected_fields + if extra: + raise ValueError( + "Unexpected fields in output: {}. " + "Expected: {}, Got: {}".format(extra, expected_fields, item_keys) + ) + + for field_name, nested_schema in nested_schemas.items(): + nested_value = item.get(field_name) + if nested_value is None: + continue + if isinstance(nested_value, dict): + nested_value = [nested_value] + if isinstance(nested_value, list): + BaseValuesViewset._validate_items_against_schema( + nested_value, nested_schema + ) + + @staticmethod + def _get_nested_child_pk(field_name, nested_pk, val): + """Extract a nested child's PK, raising on missing keys. + + When nested_pk is None the field is a scalar from a one-to-many + relation (e.g. roles__kind); the value itself is the dedup key. + """ + if nested_pk is None: + return val + try: + return val[nested_pk] + except KeyError: + raise KeyError( + "_auto_consolidate: nested field '{}' has no key " + "'{}' for deduplication. Available keys: {}. " + "Check that _resolve_nested_pk_output_name matches " + "the field_map output.".format(field_name, nested_pk, list(val.keys())) + ) + + def _auto_consolidate( + self, + items: List[Dict[str, Any]], + joined_many, + pk_getter, + ) -> List[Dict[str, Any]]: + """ + Consolidate many=True nested fields using groupby. + + Nested extraction is already done by field_map callables during + _serialize_flat. This method only handles groupby + list collection + for many=True fields (converting per-row dicts to lists and deduplicating). + + Items must be sorted by PK for groupby, but original queryset + ordering is restored afterwards. + + ``joined_many`` and ``pk_getter`` are passed by the caller so the + same routine handles top-level and nested-path consolidation + (``serialize_queryset`` reads them from the cache entry / nested + queryset's model). + """ + if not items or not joined_many: + return items + + # dict.fromkeys deduplicates while preserving insertion order (a set + # would not), so consolidated items can be returned in the original + # queryset order. + original_pk_order = list(dict.fromkeys(pk_getter(item) for item in items)) + # groupby only groups *consecutive* equal keys, so items must be + # sorted by PK first or a custom queryset ordering could split one + # PK's rows into multiple groups. + items = sorted(items, key=pk_getter) + consolidated: Dict[Any, Dict[str, Any]] = {} + + for pk, group in groupby(items, pk_getter): + group_iter = iter(group) + consolidated_item = next(group_iter) + seen_child_pks: Dict[str, set] = {fn: set() for fn, _ in joined_many} + + # Convert per-row nested dicts to lists for the first item + for field_name, nested_pk in joined_many: + val = consolidated_item[field_name] + if val is not None: + child_pk = self._get_nested_child_pk(field_name, nested_pk, val) + seen_child_pks[field_name].add(child_pk) + consolidated_item[field_name] = [val] + else: + consolidated_item[field_name] = [] + + for item in group_iter: + for field_name, nested_pk in joined_many: + val = item[field_name] + if val is not None: + child_pk = self._get_nested_child_pk(field_name, nested_pk, val) + if child_pk not in seen_child_pks[field_name]: + seen_child_pks[field_name].add(child_pk) + consolidated_item[field_name].append(val) + consolidated[pk] = consolidated_item + + return [consolidated[pk] for pk in original_pk_order] def consolidate(self, items, queryset): + """ + Override point for custom consolidation logic. + """ return items def serialize(self, queryset): queryset = self.annotate_queryset(queryset) - values_queryset = queryset.values(*self._values) - return self.consolidate( - list(map(self._map_fields, values_queryset or [])), queryset - ) + with self._serializer_context_scope(): + items = self.serialize_queryset(queryset) + result = self.consolidate(items, queryset) + + # Dev-mode validation: check output matches serializer contract + if settings.DEBUG: + self._validate_output(result) + + return result def serialize_object(self, **filter_kwargs): try: diff --git a/kolibri/core/auth/api.py b/kolibri/core/auth/api.py index 9107f9f109a..21f5d01eada 100644 --- a/kolibri/core/auth/api.py +++ b/kolibri/core/auth/api.py @@ -68,6 +68,7 @@ from .models import Role from .serializers import ClassroomSerializer from .serializers import CreateFacilitySerializer +from .serializers import DeletedFacilityUserSerializer from .serializers import ExtraFieldsSerializer from .serializers import FacilityDatasetSerializer from .serializers import FacilitySerializer @@ -75,6 +76,7 @@ from .serializers import LearnerGroupSerializer from .serializers import MembershipSerializer from .serializers import PublicFacilitySerializer +from .serializers import PublicFacilityUserSerializer from .serializers import RoleSerializer from kolibri.core import error_constants from kolibri.core.api import ReadOnlyValuesViewset @@ -533,23 +535,10 @@ class Meta: class PublicFacilityUserViewSet(ReadOnlyValuesViewset): - queryset = FacilityUser.objects.all() + queryset = FacilityUser.objects.all().order_by("id") + serializer_class = PublicFacilityUserSerializer authentication_classes = [BasicMultiArgumentAuthentication] permission_classes = [IsAuthenticated] - values = ( - "id", - "username", - "full_name", - "facility", - "roles__kind", - "devicepermissions__is_superuser", - "id_number", - "gender", - "birth_year", - ) - field_map = { - "is_superuser": lambda x: bool(x.pop("devicepermissions__is_superuser")), - } def get_queryset(self): if self.request.user.is_anonymous: @@ -572,60 +561,8 @@ def get_queryset(self): return queryset - def consolidate(self, items, queryset): - output = [] - items = sorted(items, key=lambda x: x["id"]) - for key, group in groupby(items, lambda x: x["id"]): - roles = [] - for item in group: - role = item.pop("roles__kind") - if role is not None: - roles.append(role) - item["roles"] = roles - output.append(item) - return output - - -class FacilityUserConsolidateMixin: - """ - Mixin for FacilityUser ViewSets to handle consolidate logic - """ - - def consolidate(self, items, queryset): - output = [] - items = sorted(items, key=lambda x: x["id"]) - ordering_param = self.request.query_params.get("ordering", self.order_by_field) - reverse = False - for key, group in groupby(items, lambda x: x["id"]): - roles = [] - for item in group: - role = { - "collection": item.pop("roles__collection"), - "kind": item.pop("roles__kind"), - "id": item.pop("roles__id"), - } - if role["collection"]: - # Our values call will return null for users with no assigned roles - # So filter them here. - roles.append(role) - item["roles"] = roles - output.append(item) - if ordering_param.startswith("-"): - ordering_param = ordering_param[1:] - reverse = True - output = sorted( - output, - key=lambda x: ( - x[ordering_param].lower() - if isinstance(x[ordering_param], str) - else x[ordering_param] - ), - reverse=reverse, - ) - return output - -class FacilityUserViewSet(FacilityUserConsolidateMixin, ValuesViewset, BulkDeleteMixin): +class FacilityUserViewSet(ValuesViewset, BulkDeleteMixin): permission_classes = (KolibriAuthPermissions,) pagination_class = OptionalPageNumberPagination filter_backends = ( @@ -642,23 +579,6 @@ class FacilityUserViewSet(FacilityUserConsolidateMixin, ValuesViewset, BulkDelet search_fields = ("username", "full_name") - values = ( - "id", - "username", - "full_name", - "facility", - "roles__kind", - "roles__collection", - "roles__id", - "devicepermissions__is_superuser", - "id_number", - "gender", - "birth_year", - "extra_demographics", - "date_joined", - "picture_password", - ) - ordering_fields = ( "id", "username", @@ -669,10 +589,6 @@ class FacilityUserViewSet(FacilityUserConsolidateMixin, ValuesViewset, BulkDelet "date_joined", ) - field_map = { - "is_superuser": lambda x: bool(x.pop("devicepermissions__is_superuser")) - } - def destroy(self, request, *args, **kwargs): if kwargs.get("pk"): # Single object deletion @@ -701,7 +617,6 @@ def perform_update(self, serializer): class DeletedFacilityUserViewSet( - FacilityUserConsolidateMixin, ReadOnlyValuesViewset, DestroyModelMixin, BulkDeleteMixin, @@ -719,13 +634,11 @@ class DeletedFacilityUserViewSet( order_by_field = "date_deleted" queryset = FacilityUser.soft_deleted_objects.all().order_by(order_by_field) - serializer_class = FacilityUserSerializer + serializer_class = DeletedFacilityUserSerializer filterset_class = FacilityUserFilter search_fields = FacilityUserViewSet.search_fields - values = FacilityUserViewSet.values + ("date_deleted",) ordering_fields = FacilityUserViewSet.ordering_fields + ("date_deleted",) - field_map = FacilityUserViewSet.field_map @decorators.action(detail=False, methods=["post"]) def restore(self, request): diff --git a/kolibri/core/auth/serializers.py b/kolibri/core/auth/serializers.py index 6cdaea89a37..7f1c4ccc018 100644 --- a/kolibri/core/auth/serializers.py +++ b/kolibri/core/auth/serializers.py @@ -144,8 +144,22 @@ def validate(self, attrs): return attrs +class FacilityUserRoleSerializer(serializers.ModelSerializer): + """Read-only role serializer for FacilityUser API responses. + + Excludes 'user' since it's redundant when nested inside a user response. + """ + + class Meta: + model = Role + fields = ("id", "kind", "collection") + + class FacilityUserSerializer(serializers.ModelSerializer): - roles = RoleSerializer(many=True, read_only=True) + roles = FacilityUserRoleSerializer(many=True, read_only=True) + is_superuser = serializers.BooleanField( + source="devicepermissions.is_superuser", default=False, read_only=True + ) facility = serializers.PrimaryKeyRelatedField( queryset=Facility.objects.all(), default=Facility.get_default_facility, @@ -170,6 +184,7 @@ class Meta: "birth_year", "extra_demographics", "picture_password", + "date_joined", ) read_only_fields = ("is_superuser", "picture_password") @@ -249,6 +264,34 @@ def validate(self, attrs): ) +class DeletedFacilityUserSerializer(FacilityUserSerializer): + class Meta(FacilityUserSerializer.Meta): + fields = FacilityUserSerializer.Meta.fields + ("date_deleted",) + + +class PublicFacilityUserSerializer(serializers.ModelSerializer): + """Read-only serializer for the public (device-to-device) user API.""" + + roles = serializers.CharField(source="roles.kind", read_only=True) + is_superuser = serializers.BooleanField( + source="devicepermissions.is_superuser", default=False, read_only=True + ) + + class Meta: + model = FacilityUser + fields = ( + "id", + "username", + "full_name", + "facility", + "roles", + "is_superuser", + "id_number", + "gender", + "birth_year", + ) + + class MembershipListSerializer(serializers.ListSerializer): def validate(self, attrs): lg_items = [] diff --git a/kolibri/core/auth/test/test_api.py b/kolibri/core/auth/test/test_api.py index a64181bef05..d61b8dbcf39 100644 --- a/kolibri/core/auth/test/test_api.py +++ b/kolibri/core/auth/test/test_api.py @@ -673,6 +673,42 @@ def test_public_facilityuser_endpoint(self): item["facility"], ) + def test_public_facilityuser_roles_are_flat_strings(self): + """Roles must be a flat list of kind strings, not nested objects. + + Consumers (RemoteFacilityUserAuthenticatedViewset, peer import + validation, frontend JS) do ``role in roles`` checks against plain + strings like "admin", so the shape must stay ["admin", ...]. + """ + # Give user1 an admin role on the facility + models.Role.objects.create( + user=self.user1, collection=self.facility1, kind="admin" + ) + credentials = base64.b64encode( + str.encode( + "username={}&{}={}:{}".format( + self.superuser.username, + FACILITY_CREDENTIAL_KEY, + self.facility1.id, + DUMMY_PASSWORD, + ) + ) + ).decode("ascii") + self.client.credentials(HTTP_AUTHORIZATION="Basic {}".format(credentials)) + response = self.client.get( + reverse("kolibri:core:publicuser-list"), + {"facility_id": self.facility1.id}, + format="json", + ) + user1_data = next(u for u in response.data if u["id"] == self.user1.id) + self.assertEqual(user1_data["roles"], ["admin"]) + # Regular user with no roles gets an empty list + user2_data = next( + (u for u in response.data if u["roles"] == []), + None, + ) + self.assertIsNotNone(user2_data) + def test_create_new_facility_non_superuser_permission_denied(self): self.client.login( username=self.user1.username, diff --git a/kolibri/core/test/test_api.py b/kolibri/core/test/test_api.py new file mode 100644 index 00000000000..54260b23b70 --- /dev/null +++ b/kolibri/core/test/test_api.py @@ -0,0 +1,1640 @@ +import datetime +from typing import Type +from unittest.mock import MagicMock + +from django.db.models import Model +from django.test import override_settings +from django.test import TestCase +from django.utils import timezone +from parameterized import parameterized +from rest_framework import serializers + +from kolibri.core.api import BaseValuesViewset +from kolibri.core.api import ListModelMixin +from kolibri.core.api import ValuesMethodField +from kolibri.core.api import ValuesViewsetOrderingFilter +from kolibri.core.test.test_app.models import Author +from kolibri.core.test.test_app.models import Book +from kolibri.core.test.test_app.models import Classroom +from kolibri.core.test.test_app.models import DateTimeTzModel +from kolibri.core.test.test_app.models import Enrollment +from kolibri.core.test.test_app.models import Profile +from kolibri.core.test.test_app.models import Publisher +from kolibri.core.test.test_app.models import Tag + + +def create_mock_queryset(flat_items, model: Type[Model] = Author): + """Mock queryset that returns flat_items from .values().""" + mock_qs = MagicMock() + mock_qs.model = model + mock_qs.values.return_value = flat_items + return mock_qs + + +def make_serializer(model: Type[Model] = Author, **fields): + """Create a ModelSerializer class dynamically. Returns a CLASS.""" + meta = type("Meta", (), {"model": model, "fields": list(fields.keys())}) + attrs: dict = dict(fields) + attrs["Meta"] = meta + return type("DynamicSerializer", (serializers.ModelSerializer,), attrs) + + +def make_nested( + model: Type[Model] = Author, + many=False, + source=None, + allow_null=False, + **fields, +): + """Create a nested serializer INSTANCE for embedding in another serializer.""" + child_cls = make_serializer(model=model, **fields) + kwargs: dict = {} + if many: + kwargs["many"] = True + if source: + kwargs["source"] = source + if allow_null: + kwargs["allow_null"] = True + return child_cls(**kwargs) + + +def make_viewset( + serializer_class=None, + model: Type[Model] = Author, + queryset=None, + deferred_fields=(), + **fields, +): + """Create a viewset INSTANCE. Builds serializer from **fields if none provided.""" + if serializer_class is None: + serializer_class = make_serializer(model=model, **fields) + if queryset is None: + queryset = model.objects.none() + attrs: dict = {"queryset": queryset, "serializer_class": serializer_class} + if deferred_fields: + attrs["deferred_fields"] = deferred_fields + cls = type("DynamicViewset", (BaseValuesViewset, ListModelMixin), attrs) + return cls() + + +BookSerializer = make_serializer( + model=Book, id=serializers.CharField(), title=serializers.CharField() +) +TagSerializer = make_serializer( + model=Tag, id=serializers.CharField(), name=serializers.CharField() +) +ClassroomSerializer = make_serializer( + model=Classroom, id=serializers.CharField(), name=serializers.CharField() +) + + +def author_books_viewset(deferred=False, **extra_author_fields): + """Author(id) + books(many=True) reverse FK with id+title.""" + return make_viewset( + id=serializers.CharField(), + books=make_nested( + model=Book, + many=True, + id=serializers.CharField(), + title=serializers.CharField(), + ), + deferred_fields=("books",) if deferred else (), + **extra_author_fields, + ) + + +def _serialize(viewset, flat_items, **kwargs): + """Shortcut: create mock queryset and serialize in one call.""" + mock_qs = create_mock_queryset(flat_items, **kwargs) + return viewset.serialize(mock_qs) + + +def _assert_serialize_raises(test_case, viewset, flat_items, expected_substr): + """Assert that serialize() raises ValueError containing expected_substr.""" + mock_qs = create_mock_queryset(flat_items) + with test_case.assertRaises(ValueError) as ctx: + viewset.serialize(mock_qs) + test_case.assertIn(expected_substr, str(ctx.exception)) + + +class TestDataSerialization(TestCase): + """Integration: ``viewset.serialize()`` over real Django querysets. + + Covers the field-level contract (rename, type inference, default, dot + notation, write-only, PK-related, choice, FK traversal, + ``ValuesMethodField``), every relation shape (FK, OneToOne fwd/rev, + reverse FK, direct M2M fwd/rev, M2M-through, scalar-many across each), + and the row-merging invariants (grouping, dedup, ordering preservation, + null-join handling, scalar-many collection) against real database rows. + + A shared ``setUpTestData`` fixture covers every relation type: + + - ``alice`` — publisher + profile + 3 books (``book_a3`` has null + description) + 2 classrooms via Enrollment + - ``bob`` — publisher + profile + 1 book, no classrooms + - ``carol`` — orphan: no publisher, no profile, no books, no classrooms + - Tags ``fiction`` + ``classic`` (``book_a1`` has both, ``book_a2`` has + ``fiction``) + - Alice's books × classrooms produces the cartesian needed for dedup tests + """ + + @classmethod + def setUpTestData(cls): + cls.main_publisher = Publisher.objects.create(name="Main House") + + cls.alice = Author.objects.create( + name="Alice", + email="alice@example.com", + publisher=cls.main_publisher, + ) + cls.bob = Author.objects.create( + name="Bob", + email="bob@example.com", + publisher=cls.main_publisher, + ) + cls.carol = Author.objects.create( + name="Carol", + email="carol@example.com", + publisher=None, + ) + + cls.alice_profile = Profile.objects.create( + author=cls.alice, bio="SF writer", is_verified=True + ) + cls.bob_profile = Profile.objects.create( + author=cls.bob, bio="Poet", is_verified=False + ) + + cls.book_a1 = Book.objects.create(author=cls.alice, title="Alice Book 1") + cls.book_a2 = Book.objects.create(author=cls.alice, title="Alice Book 2") + cls.book_a3 = Book.objects.create( + author=cls.alice, title="Alice Book 3", description=None + ) + cls.book_b1 = Book.objects.create(author=cls.bob, title="Bob Book 1") + + cls.tag_fiction = Tag.objects.create(name="fiction") + cls.tag_classic = Tag.objects.create(name="classic") + cls.book_a1.tags.add(cls.tag_fiction, cls.tag_classic) + cls.book_a2.tags.add(cls.tag_fiction) + + cls.classroom_101 = Classroom.objects.create(name="Room 101") + cls.classroom_102 = Classroom.objects.create(name="Room 102") + Enrollment.objects.create(author=cls.alice, classroom=cls.classroom_101) + Enrollment.objects.create(author=cls.alice, classroom=cls.classroom_102) + + def _run(self, viewset): + """Run the viewset's own queryset through serialize().""" + return viewset.serialize(viewset.get_queryset()) + + # Field contract + + def test_flat_field_rename(self): + """Field with source != name: output uses declared name, source is removed.""" + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + author_id=serializers.UUIDField(source="id"), + display_name=serializers.CharField(source="name"), + ) + result = self._run(viewset) + self.assertEqual( + result[0], + {"author_id": self.alice.pk, "display_name": "Alice"}, + ) + + def test_matching_field_type_skips_transform(self): + """Declared CharField on CharField model — simple rename, no to_representation.""" + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + display_name=serializers.CharField(source="name"), + ) + result = self._run(viewset) + self.assertEqual(result[0], {"display_name": "Alice"}) + + def test_mismatched_field_type_applies_transform(self): + """Declared field type differs from inferred — to_representation is called.""" + + class UppercaseField(serializers.CharField): + def to_representation(self, value): + return value.upper() if value else value + + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + display_name=UppercaseField(source="name"), + ) + result = self._run(viewset) + self.assertEqual(result[0]["display_name"], "ALICE") + + def test_custom_field_in_nested_child_applies_transform(self): + """Custom to_representation on a nested child field is applied.""" + + class UppercaseField(serializers.CharField): + def to_representation(self, value): + return value.upper() if value else value + + BookSer = make_serializer( + model=Book, + id=serializers.IntegerField(), + loud_title=UppercaseField(source="title"), + ) + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + id=serializers.UUIDField(), + books=BookSer(many=True), + ) + result = self._run(viewset) + loud_titles = {b["loud_title"] for b in result[0]["books"]} + self.assertEqual( + loud_titles, + {"ALICE BOOK 1", "ALICE BOOK 2", "ALICE BOOK 3"}, + ) + + def test_default_used_when_value_is_none(self): + """Field with a declared default substitutes for None raw values. + + Covers LEFT-JOIN misses on OneToOne fields: alice has a Profile + (``is_verified=True``); carol has none, so the joined column comes + back as ``None`` and the declared ``default=False`` kicks in. + """ + viewset = make_viewset( + queryset=Author.objects.filter(pk__in=[self.alice.pk, self.carol.pk]), + id=serializers.UUIDField(), + is_verified=serializers.BooleanField( + source="profile.is_verified", + default=False, + ), + ) + result = self._run(viewset) + by_id = {r["id"]: r["is_verified"] for r in result} + self.assertEqual(by_id[self.alice.pk], True) + self.assertEqual(by_id[self.carol.pk], False) + + def test_dot_notation_source_converted_to_underscore(self): + """DRF dot-notation source is converted to ``__`` at the Django + boundary, for both flat and nested-child fields. A successful + result with declared-name keys implies the conversion worked.""" + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + id=serializers.UUIDField(), + publisher_name=serializers.CharField(source="publisher.name"), + books=make_nested( + model=Book, + many=True, + id=serializers.IntegerField(), + author_name=serializers.CharField(source="author.name"), + ), + ) + result = self._run(viewset) + self.assertEqual(result[0]["publisher_name"], "Main House") + for book in result[0]["books"]: + self.assertEqual(book["author_name"], "Alice") + + def test_write_only_field_excluded_from_output(self): + """write_only fields are neither fetched nor present in output.""" + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + id=serializers.UUIDField(), + email=serializers.CharField(write_only=True), + ) + result = self._run(viewset) + self.assertEqual(result[0], {"id": self.alice.pk}) + + def test_none_value_not_passed_to_transform(self): + """None raw values short-circuit both rename and in-place transform paths. + + ``book_a3.description`` is ``None`` — a custom ``to_representation`` + that would raise on ``None`` must never be called. + """ + + class FailOnNoneField(serializers.Field): + def to_representation(self, value): + if value is None: + raise ValueError("received None!") + return str(value).upper() + + # Rename path: source != declared name. + rename_viewset = make_viewset( + queryset=Book.objects.filter(pk=self.book_a3.pk), + id=serializers.IntegerField(), + loud_desc=FailOnNoneField(source="description"), + ) + result = self._run(rename_viewset) + self.assertIsNone(result[0]["loud_desc"]) + + # In-place path: source == declared name. + in_place_viewset = make_viewset( + queryset=Book.objects.filter(pk=self.book_a3.pk), + id=serializers.IntegerField(), + description=FailOnNoneField(), + ) + result = self._run(in_place_viewset) + self.assertIsNone(result[0]["description"]) + + def test_primary_key_related_field_on_fk_passes_through(self): + """PrimaryKeyRelatedField on a FK model field passes the raw PK through. + + ``values()`` returns the raw FK value (e.g. a UUID string). + ``PrimaryKeyRelatedField.to_representation`` expects a model + instance, so the introspection must recognize the match and skip + the transform. + """ + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + id=serializers.UUIDField(), + publisher=serializers.PrimaryKeyRelatedField( + queryset=Publisher.objects.all(), + ), + ) + result = self._run(viewset) + self.assertEqual(result[0]["publisher"], self.main_publisher.pk) + + def test_choice_field_serialization(self): + """ChoiceField with declared choices passes raw values through.""" + retired = Author.objects.create( + name="Retired", email="retired@example.com", status="retired" + ) + blank = Author.objects.create( + name="Blank", email="blank@example.com", status="" + ) + viewset = make_viewset( + queryset=Author.objects.filter( + pk__in=[self.alice.pk, retired.pk, blank.pk] + ), + id=serializers.UUIDField(), + author_status=serializers.ChoiceField( + source="status", + choices=[("", ""), ("active", "Active"), ("retired", "Retired")], + ), + ) + result = self._run(viewset) + by_id = {r["id"]: r["author_status"] for r in result} + self.assertEqual(by_id[self.alice.pk], "active") + self.assertEqual(by_id[retired.pk], "retired") + self.assertEqual(by_id[blank.pk], "") + + def test_boolean_field_on_boolean_model_field_matches(self): + """BooleanField on a BooleanField model field — no transform applied.""" + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + active=serializers.BooleanField(source="is_active"), + ) + result = self._run(viewset) + self.assertEqual(result[0]["active"], True) + + def test_uuid_field_on_uuid_model_field_passes_through(self): + """UUIDField on Author.id (UUIDField) — raw UUID passes through.""" + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + author_id=serializers.UUIDField(source="id"), + name=serializers.CharField(), + ) + result = self._run(viewset) + self.assertEqual(result[0]["author_id"], self.alice.pk) + + def test_plain_serializer_method_field_rejected(self): + """Plain ``SerializerMethodField`` is not supported on ValuesViewset; + class init raises ``TypeError`` pointing at ``ValuesMethodField``.""" + + class S(serializers.ModelSerializer): + label = serializers.SerializerMethodField() + + def get_label(self, instance): + return "x" + + class Meta: + model = Author + fields = ("id", "label") + + V = type( + "V", + (BaseValuesViewset, ListModelMixin), + {"serializer_class": S, "queryset": Author.objects.none()}, + ) + with self.assertRaises(TypeError) as ctx: + V() + self.assertIn("ValuesMethodField", str(ctx.exception)) + + def test_method_field_sees_python_value_not_serialized_form(self): + """The proxy exposes the Python value a bound method would see in + vanilla DRF (e.g. a ``datetime``) — not the post-``to_representation`` + form (ISO-8601 string). The method can only return a year int if it + received an actual ``datetime``. + """ + aware = timezone.get_current_timezone().localize( + datetime.datetime(2026, 4, 23, 10, 30, 0) + ) + dtm = DateTimeTzModel.objects.create(timestamp=aware) + + class S(serializers.ModelSerializer): + id = serializers.IntegerField() + timestamp = serializers.DateTimeField() + timestamp_year = ValuesMethodField(sources=("timestamp",)) + + def get_timestamp_year(self, obj): + return obj.timestamp.year + + class Meta: + model = DateTimeTzModel + fields = ("id", "timestamp", "timestamp_year") + + viewset = make_viewset( + serializer_class=S, + queryset=DateTimeTzModel.objects.filter(pk=dtm.pk), + ) + result = self._run(viewset) + self.assertEqual(result[0]["timestamp_year"], 2026) + + def test_method_field_proxy_raises_on_undeclared_attribute(self): + """Proxy access to an attribute not in ``sources`` raises + ``AttributeError``; the message names the requested attribute and + surfaces the declared sources so the boundary is discoverable.""" + + class S(serializers.ModelSerializer): + id = serializers.UUIDField() + label = ValuesMethodField(sources=("name",)) + + def get_label(self, obj): + return obj.email # not declared + + class Meta: + model = Author + fields = ("id", "label") + + viewset = make_viewset( + serializer_class=S, + queryset=Author.objects.filter(pk=self.alice.pk), + ) + with self.assertRaises(AttributeError) as ctx: + self._run(viewset) + message = str(ctx.exception) + self.assertIn("email", message) + self.assertIn("name", message) + + def test_method_field_empty_sources_invokes_method(self): + """``sources=()`` still invokes the bound method — useful for + constant-returning computations (e.g. reading a global setting).""" + + class S(serializers.ModelSerializer): + id = serializers.UUIDField() + constant = ValuesMethodField(sources=()) + + def get_constant(self, obj): + return "always-same" + + class Meta: + model = Author + fields = ("id", "constant") + + viewset = make_viewset( + serializer_class=S, + queryset=Author.objects.filter(pk=self.alice.pk), + ) + result = self._run(viewset) + self.assertEqual(result[0]["constant"], "always-same") + + def test_nested_non_model_serializer_treated_as_regular_field(self): + """A nested plain ``Serializer`` over a JSONField column is treated + as a regular field: its ``to_representation`` runs on the raw dict, + so undeclared keys are dropped from the output.""" + + class MetadataSerializer(serializers.Serializer): + a = serializers.CharField() + b = serializers.CharField() + c = serializers.CharField() + + self.alice.metadata = {"a": "alpha", "b": "beta", "c": "gamma", "d": "delta"} + self.alice.save() + + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + id=serializers.UUIDField(), + metadata=MetadataSerializer(), + ) + result = self._run(viewset) + self.assertEqual( + result[0]["metadata"], + {"a": "alpha", "b": "beta", "c": "gamma"}, + ) + + # Relation shapes + + def test_fk_single_nested(self): + viewset = make_viewset( + queryset=Book.objects.filter(pk=self.book_a1.pk), + id=serializers.IntegerField(), + title=serializers.CharField(), + author=make_nested( + model=Author, + id=serializers.UUIDField(), + name=serializers.CharField(), + ), + ) + result = self._run(viewset) + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["author"]["name"], "Alice") + + def test_one_to_one_forward_single_nested(self): + viewset = make_viewset( + model=Profile, + queryset=Profile.objects.filter(pk=self.alice_profile.pk), + id=serializers.IntegerField(), + bio=serializers.CharField(), + author=make_nested( + model=Author, + id=serializers.UUIDField(), + name=serializers.CharField(), + ), + ) + result = self._run(viewset) + self.assertEqual(result[0]["author"]["name"], "Alice") + + def test_one_to_one_reverse_single_nested(self): + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + id=serializers.UUIDField(), + profile=make_nested( + model=Profile, + allow_null=True, + id=serializers.IntegerField(), + bio=serializers.CharField(), + ), + ) + result = self._run(viewset) + self.assertEqual(result[0]["profile"]["bio"], "SF writer") + + def test_reverse_fk_many_nested(self): + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + id=serializers.UUIDField(), + books=make_nested( + model=Book, + many=True, + id=serializers.IntegerField(), + title=serializers.CharField(), + ), + ) + result = self._run(viewset) + titles = sorted(b["title"] for b in result[0]["books"]) + self.assertEqual(titles, ["Alice Book 1", "Alice Book 2", "Alice Book 3"]) + + def test_m2m_direct_forward_many_nested(self): + viewset = make_viewset( + model=Book, + queryset=Book.objects.filter(pk=self.book_a1.pk), + id=serializers.IntegerField(), + tags=make_nested( + model=Tag, + many=True, + id=serializers.IntegerField(), + name=serializers.CharField(), + ), + ) + result = self._run(viewset) + names = sorted(t["name"] for t in result[0]["tags"]) + self.assertEqual(names, ["classic", "fiction"]) + + def test_m2m_direct_reverse_many_nested(self): + viewset = make_viewset( + model=Tag, + queryset=Tag.objects.filter(pk=self.tag_fiction.pk), + id=serializers.IntegerField(), + books=make_nested( + model=Book, + many=True, + id=serializers.IntegerField(), + title=serializers.CharField(), + ), + ) + result = self._run(viewset) + titles = sorted(b["title"] for b in result[0]["books"]) + self.assertEqual(titles, ["Alice Book 1", "Alice Book 2"]) + + def test_m2m_through_many_nested(self): + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + id=serializers.UUIDField(), + classrooms=make_nested( + model=Classroom, + many=True, + id=serializers.IntegerField(), + name=serializers.CharField(), + ), + ) + result = self._run(viewset) + names = sorted(c["name"] for c in result[0]["classrooms"]) + self.assertEqual(names, ["Room 101", "Room 102"]) + + def test_scalar_many_via_reverse_fk(self): + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + id=serializers.UUIDField(), + book_titles=serializers.CharField(source="books.title"), + ) + result = self._run(viewset) + self.assertEqual( + sorted(result[0]["book_titles"]), + ["Alice Book 1", "Alice Book 2", "Alice Book 3"], + ) + + def test_scalar_many_via_m2m_direct(self): + """Exercises the ``many_to_many`` branch of _source_crosses_many_relation.""" + viewset = make_viewset( + model=Book, + queryset=Book.objects.filter(pk=self.book_a1.pk), + id=serializers.IntegerField(), + tag_names=serializers.CharField(source="tags.name"), + ) + result = self._run(viewset) + self.assertEqual(sorted(result[0]["tag_names"]), ["classic", "fiction"]) + + def test_scalar_many_via_m2m_through(self): + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + id=serializers.UUIDField(), + classroom_names=serializers.CharField(source="classrooms.name"), + ) + result = self._run(viewset) + self.assertEqual(sorted(result[0]["classroom_names"]), ["Room 101", "Room 102"]) + + # Consolidation invariants + + def test_many_rows_same_parent_merge(self): + """Multiple books for one author collapse into a single output row.""" + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + id=serializers.UUIDField(), + books=make_nested( + model=Book, + many=True, + id=serializers.IntegerField(), + title=serializers.CharField(), + ), + ) + result = self._run(viewset) + self.assertEqual(len(result), 1) + self.assertEqual(len(result[0]["books"]), 3) + + def test_null_many_nested_produces_empty_list(self): + """Author with no related books yields [] — LEFT JOIN miss.""" + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.carol.pk), + id=serializers.UUIDField(), + books=make_nested( + model=Book, + many=True, + id=serializers.IntegerField(), + title=serializers.CharField(), + ), + ) + result = self._run(viewset) + self.assertEqual(result[0]["books"], []) + + def test_null_single_fk_produces_null(self): + """Author with no publisher — single-nested FK yields None.""" + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.carol.pk), + id=serializers.UUIDField(), + publisher_info=make_nested( + model=Publisher, + source="publisher", + allow_null=True, + id=serializers.IntegerField(), + name=serializers.CharField(), + ), + ) + result = self._run(viewset) + self.assertIsNone(result[0]["publisher_info"]) + + def test_nullable_first_field_in_nested_not_dropped(self): + """Nested row with null first declared field but non-null PK is kept. + + ``book_a3`` has ``description=None`` but a valid id. Declaring + ``description`` as the first nested field must not cause the nested + dict to be dropped — the null-check uses the PK, not field order. + """ + BookSer = make_serializer( + model=Book, + description=serializers.CharField(allow_null=True), + id=serializers.IntegerField(), + ) + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + id=serializers.UUIDField(), + books=BookSer(many=True), + ) + result = self._run(viewset) + self.assertEqual(len(result[0]["books"]), 3) + book_a3 = next(b for b in result[0]["books"] if b["id"] == self.book_a3.pk) + self.assertIsNone(book_a3["description"]) + + def test_duplicate_child_rows_deduplicated(self): + """Cartesian rows (books × classrooms) collapse to actual child counts. + + Alice has 3 books × 2 classrooms = 6 rows via the cartesian, but + the nested list must dedupe to 3 books and the scalar-many to 2 + classroom names. + """ + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + id=serializers.UUIDField(), + books=make_nested( + model=Book, + many=True, + id=serializers.IntegerField(), + title=serializers.CharField(), + ), + classroom_names=serializers.CharField(source="classrooms.name"), + ) + result = self._run(viewset) + self.assertEqual(len(result[0]["books"]), 3) + self.assertEqual(len(result[0]["classroom_names"]), 2) + + def test_queryset_ordering_preserved(self): + """Output order matches queryset order, not PK order (groupby sorts by PK).""" + viewset = make_viewset( + queryset=Author.objects.order_by("-email"), + id=serializers.UUIDField(), + email=serializers.CharField(), + ) + result = self._run(viewset) + self.assertEqual( + [item["email"] for item in result], + ["carol@example.com", "bob@example.com", "alice@example.com"], + ) + + def test_scalar_many_deduplicates_values(self): + """Duplicates from a cartesian collapse to unique scalar entries. + + Joining both books and classrooms for Alice creates a cartesian + where each book title appears twice (once per classroom). Scalar-many + dedup collapses back to 3 unique titles. + """ + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.alice.pk), + id=serializers.UUIDField(), + book_titles=serializers.CharField(source="books.title"), + classroom_names=serializers.CharField(source="classrooms.name"), + ) + result = self._run(viewset) + self.assertEqual(len(result[0]["book_titles"]), 3) + self.assertEqual(len(result[0]["classroom_names"]), 2) + + def test_scalar_many_null_produces_empty_list(self): + """Scalar-many with no related rows yields [].""" + viewset = make_viewset( + queryset=Author.objects.filter(pk=self.carol.pk), + id=serializers.UUIDField(), + book_titles=serializers.CharField(source="books.title"), + ) + result = self._run(viewset) + self.assertEqual(result[0]["book_titles"], []) + + def test_method_field_excludes_unshared_source(self): + """A source referenced only by ``ValuesMethodField`` is fetched into + ``values()`` but does not appear in the serialized output row.""" + + class S(serializers.ModelSerializer): + id = serializers.UUIDField() + label = ValuesMethodField(sources=("name",)) + + def get_label(self, obj): + return "label: {}".format(obj.name) + + class Meta: + model = Author + fields = ("id", "label") + + viewset = make_viewset( + serializer_class=S, + queryset=Author.objects.filter(pk=self.alice.pk), + ) + result = self._run(viewset) + self.assertEqual(set(result[0].keys()), {"id", "label"}) + self.assertEqual(result[0]["label"], "label: Alice") + + def test_method_field_keeps_shared_source_under_declared_name(self): + """When the method's source is also a declared field, it stays in + output under its declared name.""" + + class S(serializers.ModelSerializer): + id = serializers.UUIDField() + name = serializers.CharField() + label = ValuesMethodField(sources=("name",)) + + def get_label(self, obj): + return "label: {}".format(obj.name) + + class Meta: + model = Author + fields = ("id", "name", "label") + + viewset = make_viewset( + serializer_class=S, + queryset=Author.objects.filter(pk=self.alice.pk), + ) + result = self._run(viewset) + self.assertEqual(result[0]["name"], "Alice") + self.assertEqual(result[0]["label"], "label: Alice") + + def test_method_field_reads_dotted_source_from_fk(self): + """``sources=('publisher.name',)`` fetches ``publisher__name`` and the + proxy walks it as ``obj.publisher.name``.""" + + class S(serializers.ModelSerializer): + id = serializers.UUIDField() + publisher_label = ValuesMethodField(sources=("publisher.name",)) + + def get_publisher_label(self, obj): + return "pub: {}".format(obj.publisher.name) + + class Meta: + model = Author + fields = ("id", "publisher_label") + + viewset = make_viewset( + serializer_class=S, + queryset=Author.objects.filter(pk=self.alice.pk), + ) + result = self._run(viewset) + self.assertEqual(set(result[0].keys()), {"id", "publisher_label"}) + self.assertEqual(result[0]["publisher_label"], "pub: Main House") + + def test_method_field_reads_context_from_request(self): + """The bound method's ``self.context`` is populated per-request from + ``viewset.get_serializer_context()``.""" + + class S(serializers.ModelSerializer): + id = serializers.UUIDField() + ctx_label = ValuesMethodField(sources=("name",)) + + def get_ctx_label(self, obj): + hint = self.context.get("hint", "missing") + return "{}/{}".format(obj.name, hint) + + class Meta: + model = Author + fields = ("id", "ctx_label") + + viewset = make_viewset( + serializer_class=S, + queryset=Author.objects.filter(pk=self.alice.pk), + ) + viewset.get_serializer_context = lambda: {"hint": "yo"} + result = self._run(viewset) + self.assertEqual(result[0]["ctx_label"], "Alice/yo") + + def test_serialize_queryset_group_by_returns_dict(self): + """``serialize_queryset`` with ``group_by`` returns a dict keyed by + the group column's value. The grouping column must be declared on + the nested serializer — only declared fields reach the output.""" + viewset = make_viewset( + id=serializers.UUIDField(), + books=make_nested( + model=Book, + many=True, + id=serializers.IntegerField(), + title=serializers.CharField(), + author=serializers.PrimaryKeyRelatedField(read_only=True), + ), + deferred_fields=("books",), + ) + result = viewset.serialize_queryset( + Book.objects.all(), "books", group_by="author" + ) + self.assertEqual(set(result.keys()), {self.alice.pk, self.bob.pk}) + self.assertEqual(len(result[self.alice.pk]), 3) + self.assertEqual(len(result[self.bob.pk]), 1) + + def test_serialize_queryset_consolidates_grand_nested_many(self): + """``serialize_queryset`` for a path whose nested serializer itself + has a ``many=True`` child must merge the JOIN-multiplied rows into + per-parent lists, mirroring ``serialize()``'s consolidation. Without + consolidation, ``book_a1`` (two tags) would appear twice in output. + """ + TagSer = make_serializer( + model=Tag, + id=serializers.IntegerField(), + name=serializers.CharField(), + ) + BookSer = make_serializer( + model=Book, + id=serializers.IntegerField(), + title=serializers.CharField(), + tags=TagSer(many=True), + ) + viewset = make_viewset( + serializer_class=make_serializer( + id=serializers.UUIDField(), + books=BookSer(many=True), + ), + deferred_fields=("books",), + ) + result = viewset.serialize_queryset(Book.objects.all(), "books") + by_id = {r["id"]: r for r in result} + # One row per book — book_a1 must not be duplicated by its 2 tags. + self.assertEqual(len(result), 4) + self.assertEqual( + set(by_id), + {self.book_a1.pk, self.book_a2.pk, self.book_a3.pk, self.book_b1.pk}, + ) + self.assertEqual( + sorted(t["name"] for t in by_id[self.book_a1.pk]["tags"]), + ["classic", "fiction"], + ) + self.assertEqual( + [t["name"] for t in by_id[self.book_a2.pk]["tags"]], ["fiction"] + ) + self.assertEqual(by_id[self.book_a3.pk]["tags"], []) + self.assertEqual(by_id[self.book_b1.pk]["tags"], []) + + def test_serialize_queryset_passes_context_to_nested_method_field(self): + """A ``ValuesMethodField`` on a nested serializer must read + per-request context via ``self.context`` when reached through + ``serialize_queryset``. The context flows in through the cached + parent's threading-local ``_context``: ``Field.context`` walks + ``self.root._context``, so the nested bound method sees the same + dict the scope manager populated for this request. + """ + + class BookSer(serializers.ModelSerializer): + id = serializers.IntegerField() + title_with_hint = ValuesMethodField(sources=("title",)) + + def get_title_with_hint(self, obj): + hint = self.context.get("hint", "missing") + return "{}/{}".format(obj.title, hint) + + class Meta: + model = Book + fields = ("id", "title_with_hint") + + class S(serializers.ModelSerializer): + id = serializers.UUIDField() + books = BookSer(many=True) + + class Meta: + model = Author + fields = ("id", "books") + + viewset = make_viewset(serializer_class=S, deferred_fields=("books",)) + viewset.get_serializer_context = lambda: {"hint": "yo"} + result = viewset.serialize_queryset( + Book.objects.filter(pk=self.book_a1.pk), "books" + ) + self.assertEqual(result[0]["title_with_hint"], "Alice Book 1/yo") + + def test_nested_path_deferred_with_consolidate(self): + """Full pipeline: ``Publisher`` → ``authors`` (deferred at top), + ``AuthorSer`` has ``books`` (joined inside the authors query) and + ``enrollments`` (deferred deeper via ``authors__enrollments``). + ``consolidate`` batches both deferred layers via ``group_by`` + so the whole result is two extra queries — no N+1. + """ + BookSer = make_serializer( + model=Book, + id=serializers.IntegerField(), + title=serializers.CharField(), + ) + EnrollmentSer = make_serializer( + model=Enrollment, + id=serializers.IntegerField(), + classroom=serializers.PrimaryKeyRelatedField(read_only=True), + author=serializers.PrimaryKeyRelatedField(read_only=True), + ) + AuthorSer = make_serializer( + model=Author, + id=serializers.UUIDField(), + publisher=serializers.PrimaryKeyRelatedField(read_only=True), + books=BookSer(many=True), + enrollments=EnrollmentSer(many=True), + ) + Ser = make_serializer( + model=Publisher, + id=serializers.IntegerField(), + name=serializers.CharField(), + authors=AuthorSer(many=True), + ) + + class V(BaseValuesViewset, ListModelMixin): + queryset = Publisher.objects.filter(pk=self.main_publisher.pk) + serializer_class = Ser + deferred_fields = ("authors", "authors__enrollments") + + def consolidate(self, items, queryset): + if not items: + return items + pub_ids = [p["id"] for p in items] + authors_by_pub = self.serialize_queryset( + Author.objects.filter(publisher_id__in=pub_ids), + "authors", + group_by="publisher", + ) + author_ids = [ + a["id"] for authors in authors_by_pub.values() for a in authors + ] + enrollments_by_author = self.serialize_queryset( + Enrollment.objects.filter(author_id__in=author_ids), + "authors__enrollments", + group_by="author", + ) + for pub in items: + pub_authors = authors_by_pub.get(pub["id"], []) + for author in pub_authors: + author["enrollments"] = enrollments_by_author.get( + author["id"], [] + ) + pub["authors"] = pub_authors + return items + + result = V().serialize(V().queryset) + self.assertEqual(len(result), 1) + pub = result[0] + self.assertEqual(pub["name"], "Main House") + + authors_by_id = {a["id"]: a for a in pub["authors"]} + self.assertEqual(set(authors_by_id), {self.alice.pk, self.bob.pk}) + + # alice: 3 books joined inside authors, 2 enrollments deferred deeper. + alice = authors_by_id[self.alice.pk] + self.assertEqual( + sorted(b["title"] for b in alice["books"]), + ["Alice Book 1", "Alice Book 2", "Alice Book 3"], + ) + self.assertEqual( + sorted(e["classroom"] for e in alice["enrollments"]), + [self.classroom_101.pk, self.classroom_102.pk], + ) + + # bob: 1 book, 0 enrollments. + bob = authors_by_id[self.bob.pk] + self.assertEqual([b["title"] for b in bob["books"]], ["Bob Book 1"]) + self.assertEqual(bob["enrollments"], []) + + +class TestLegacyViewset(TestCase): + """Viewsets using the pre-serializer-derivation pattern (explicit + ``values`` tuple + ``field_map`` dict) must continue to work, including + inheritance semantics and MRO isolation between parent and child classes. + """ + + def test_explicit_values_and_string_field_map(self): + """Explicit values tuple + string field_map renames source → declared key.""" + + class V(BaseValuesViewset, ListModelMixin): + queryset = Author.objects.none() + values = ("id", "name") + field_map = {"display_name": "name"} + + result = _serialize(V(), [{"id": "a1", "name": "Alice"}]) + self.assertEqual(result[0], {"id": "a1", "display_name": "Alice"}) + + def test_callable_field_map(self): + """Callable field_map entries receive the full item and can pop/transform.""" + + def upper(item): + return item.pop("name", "").upper() + + class V(BaseValuesViewset, ListModelMixin): + queryset = Author.objects.none() + values = ("id", "name") + field_map = {"loud_name": upper} + + result = _serialize(V(), [{"id": "a1", "name": "alice"}]) + self.assertEqual(result[0]["loud_name"], "ALICE") + + def test_field_map_mutation_after_init_does_not_leak(self): + """Mutating the class-level field_map after init must not affect the instance.""" + + class V(BaseValuesViewset, ListModelMixin): + queryset = Author.objects.none() + values = ("id", "name") + field_map = {"display_name": "name"} + + inst = V() + V.field_map["injected"] = "id" # post-init mutation + result = _serialize(inst, [{"id": "a1", "name": "Alice"}]) + self.assertNotIn("injected", result[0]) + + def test_child_inherits_parent_explicit_values(self): + """A subclass without overrides serializes using parent's values + field_map.""" + + class Parent(BaseValuesViewset, ListModelMixin): + queryset = Author.objects.none() + values = ("id", "name") + field_map = {"display_name": "name"} + + class Child(Parent): + pass + + result = _serialize(Child(), [{"id": "a1", "name": "Alice"}]) + self.assertEqual(result[0], {"id": "a1", "display_name": "Alice"}) + + def test_subclass_serializer_does_not_reuse_parent_derived_info(self): + """A subclass declaring its own serializer_class uses its own derived fields.""" + ParentSer = make_serializer( + display_name=serializers.CharField(source="name"), + ) + ChildSer = make_serializer( + loud_name=serializers.CharField(source="name"), + ) + + class Parent(BaseValuesViewset, ListModelMixin): + queryset = Author.objects.none() + serializer_class = ParentSer + + class Child(Parent): + serializer_class = ChildSer + + result = _serialize(Child(), [{"id": "a1", "name": "Alice"}]) + self.assertIn("loud_name", result[0]) + self.assertNotIn("display_name", result[0]) + + def test_child_not_confused_by_parent_auto_derived_values(self): + """Parent's auto-derived ``values`` (set on the class during + ``_ensure_initialized`` to support the ordering filter) must not be + treated as explicit when a child subclasses it. Otherwise the child + falls into the explicit-values path and serializes with parent's + fields rather than deriving from its own ``serializer_class``. + """ + ParentSer = make_serializer( + display_name=serializers.CharField(source="name"), + ) + ChildSer = make_serializer( + loud_name=serializers.CharField(source="name"), + ) + + class Parent(BaseValuesViewset, ListModelMixin): + queryset = Author.objects.none() + serializer_class = ParentSer + + Parent() # Force init so cls.values is auto-set on Parent + + class Child(Parent): + serializer_class = ChildSer + + result = _serialize(Child(), [{"id": "a1", "name": "Alice"}]) + self.assertIn("loud_name", result[0]) + self.assertNotIn("display_name", result[0]) + + def test_child_not_confused_by_parent_generated_serializer_class(self): + """Parent's auto-generated serializer must not leak to the child. + + generate_serializer() drops FK-traversal entries from values (they + aren't direct model fields), so re-deriving from a parent's cached + auto-generated serializer would lose those entries on the child. + The child must still serialize rows that include the FK traversal. + """ + + class Parent(BaseValuesViewset, ListModelMixin): + queryset = Author.objects.none() + values = ("id", "name", "publisher__name") + field_map = { + "display_name": "name", + "publisher_name": "publisher__name", + } + + Parent().get_serializer_class() # triggers the lossy auto-gen cache + + class Child(Parent): + pass + + result = _serialize( + Child(), + [{"id": "a1", "name": "Alice", "publisher__name": "Main House"}], + ) + self.assertEqual(result[0]["display_name"], "Alice") + self.assertEqual(result[0]["publisher_name"], "Main House") + + def test_ordering_filter_over_explicit_field_map(self): + """Ordering filter exposes explicit field_map keys as valid ordering fields.""" + + class V(BaseValuesViewset, ListModelMixin): + queryset = Author.objects.all() + values = ("id", "name") + field_map = {"display_name": "name"} + + valid = ValuesViewsetOrderingFilter().get_default_valid_fields( + V().queryset, V() + ) + self.assertIn(("display_name", "display_name"), valid) + + def test_ordering_filter_over_derived_field_map(self): + """Ordering filter exposes declared serializer names when field_map is derived.""" + viewset = make_viewset( + queryset=Author.objects.all(), + id=serializers.UUIDField(), + display_name=serializers.CharField(source="name"), + ) + valid = ValuesViewsetOrderingFilter().get_default_valid_fields( + viewset.queryset, viewset + ) + self.assertIn(("display_name", "display_name"), valid) + + def test_ordering_filter_translates_declared_name_to_source(self): + """Ordering by a declared name translates to the source column for the DB.""" + viewset = make_viewset( + queryset=Author.objects.all(), + id=serializers.UUIDField(), + display_name=serializers.CharField(source="name"), + ) + filter_backend = ValuesViewsetOrderingFilter() + request = MagicMock() + self.assertEqual( + filter_backend.remove_invalid_fields( + viewset.queryset, ["display_name"], viewset, request + ), + ["name"], + ) + self.assertEqual( + filter_backend.remove_invalid_fields( + viewset.queryset, ["-display_name"], viewset, request + ), + ["-name"], + ) + + +class TestDevModeSafeguards(TestCase): + """DEBUG-only contracts catch the configs developers are most likely to + get wrong, plus errors for misconfigurations that would otherwise fail + silently. The goal is surfacing problems at a useful boundary with + identifying info in the error message. + """ + + @override_settings(DEBUG=True) + def test_multiple_joined_many_nested_raises_error(self): + """Two many=True nested serializers without deferring raise (cartesian product).""" + Ser = make_serializer( + id=serializers.CharField(), + books=BookSerializer(many=True), + classrooms=ClassroomSerializer(many=True), + ) + with self.assertRaises(TypeError) as ctx: + make_viewset(serializer_class=Ser) + self.assertIn("books", str(ctx.exception)) + self.assertIn("classrooms", str(ctx.exception)) + + def test_multiple_joined_many_with_one_deferred_is_fine(self): + """Deferring one of two many-nested serializers avoids the cartesian error.""" + Ser = make_serializer( + id=serializers.CharField(), + books=BookSerializer(many=True), + classrooms=ClassroomSerializer(many=True), + ) + with override_settings(DEBUG=True): + viewset = make_viewset( + serializer_class=Ser, deferred_fields=("classrooms",) + ) + result = _serialize( + viewset, + [{"id": "a1", "books__id": "b1", "books__title": "B1"}], + ) + self.assertEqual(len(result[0]["books"]), 1) + self.assertNotIn("classrooms", result[0]) + + @override_settings(DEBUG=True) + def test_multiple_many_inside_deferred_raises_error(self): + """A deferred nested serializer whose own children include 2+ + un-deferred many=True nested serializers must raise: otherwise + ``serialize_queryset`` for that path would silently emit cartesian + rows (auto-consolidate dedupes but the SQL is over-fetched). + """ + GcA = make_serializer(id=serializers.CharField()) + GcB = make_serializer(id=serializers.CharField()) + InnerSer = make_serializer( + id=serializers.CharField(), + tags=GcA(many=True), + co_authors=GcB(many=True), + ) + Ser = make_serializer( + id=serializers.CharField(), + books=InnerSer(many=True), + ) + with self.assertRaises(TypeError) as ctx: + make_viewset(serializer_class=Ser, deferred_fields=("books",)) + self.assertIn("tags", str(ctx.exception)) + self.assertIn("co_authors", str(ctx.exception)) + + @parameterized.expand( + [ + ("non_many_child_many_gc", False, True, "book"), + ("many_child_many_gc", True, True, "books_outer"), + ("non_many_child_non_many_gc", False, False, "book"), + ("many_child_non_many_gc", True, False, "books_outer"), + ] + ) + @override_settings(DEBUG=True) + def test_deep_nesting_raises_error( + self, _name, child_many, grandchild_many, expected_field + ): + """All deep-nesting shapes (nested-in-nested) raise at viewset instantiation.""" + GC = make_serializer(id=serializers.CharField()) + gc_field = "grandchildren" if grandchild_many else "grandchild" + gc_kwargs = {"many": True} if grandchild_many else {} + ChildSer = make_serializer( + id=serializers.CharField(), **{gc_field: GC(**gc_kwargs)} + ) + child_field = "books_outer" if child_many else "book" + child_kwargs = ( + {"many": True, "source": "books"} if child_many else {"source": "books"} + ) + ParentSer = make_serializer( + id=serializers.CharField(), + **{child_field: ChildSer(**child_kwargs)}, + ) + with self.assertRaises(TypeError) as ctx: + make_viewset(serializer_class=ParentSer) + self.assertIn(expected_field, str(ctx.exception)) + + def test_deep_nesting_with_deferred_avoids_error(self): + """Deferring the deeply-nested field lets derivation succeed.""" + GC = make_serializer(id=serializers.CharField()) + ChildSer = make_serializer( + id=serializers.CharField(), grandchildren=GC(many=True) + ) + ParentSer = make_serializer( + id=serializers.CharField(), + books_outer=ChildSer(many=True, source="books"), + ) + with override_settings(DEBUG=True): + viewset = make_viewset( + serializer_class=ParentSer, + deferred_fields=("books_outer",), + ) + result = _serialize(viewset, [{"id": "a1"}]) + self.assertNotIn("books_outer", result[0]) + + @override_settings(DEBUG=True) + def test_validate_raises_on_drift_in_flat_output(self): + """consolidate() adding a field not on the serializer raises ValueError.""" + Ser = make_serializer(id=serializers.CharField(), name=serializers.CharField()) + + class V(BaseValuesViewset, ListModelMixin): + queryset = Author.objects.none() + serializer_class = Ser + + def consolidate(self, items, queryset): + for item in items: + item["unexpected"] = "oops" + return items + + _assert_serialize_raises( + self, V(), [{"id": "a1", "name": "Alice"}], "unexpected" + ) + + @override_settings(DEBUG=True) + def test_validate_raises_on_drift_in_nested_many_output(self): + """consolidate() producing a nested many item missing a field raises.""" + Ser = make_serializer( + id=serializers.CharField(), books=BookSerializer(many=True) + ) + + class V(BaseValuesViewset, ListModelMixin): + queryset = Author.objects.none() + serializer_class = Ser + + def consolidate(self, items, queryset): + for item in items: + item["books"] = [{"id": "b1"}] # missing 'title' + return items + + _assert_serialize_raises( + self, + V(), + [{"id": "a1", "books__id": "b1", "books__title": "B1"}], + "title", + ) + + @override_settings(DEBUG=True) + def test_validate_raises_on_drift_in_nested_single_output(self): + """consolidate() producing a nested dict missing a field raises.""" + Ser = make_serializer( + model=Book, + id=serializers.CharField(), + author=make_nested( + model=Author, + id=serializers.CharField(), + name=serializers.CharField(), + ), + ) + + class V(BaseValuesViewset, ListModelMixin): + queryset = Book.objects.none() + serializer_class = Ser + + def consolidate(self, items, queryset): + for item in items: + item["author"] = {"id": "a1"} # missing 'name' + return items + + _assert_serialize_raises( + self, + V(), + [{"id": "b1", "author__id": "a1", "author__name": "Alice"}], + "name", + ) + + @override_settings(DEBUG=True) + def test_validate_catches_consolidate_deleting_a_field(self): + """consolidate() removing a declared field raises.""" + Ser = make_serializer(id=serializers.CharField(), name=serializers.CharField()) + + class V(BaseValuesViewset, ListModelMixin): + queryset = Author.objects.none() + serializer_class = Ser + + def consolidate(self, items, queryset): + for item in items: + del item["name"] + return items + + _assert_serialize_raises(self, V(), [{"id": "a1", "name": "Alice"}], "name") + + @override_settings(DEBUG=True) + def test_validate_ignores_write_only_fields(self): + """write_only fields missing from output don't trigger validation errors.""" + viewset = make_viewset( + id=serializers.CharField(), + email=serializers.CharField(write_only=True), + ) + result = _serialize(viewset, [{"id": "a1"}]) + self.assertEqual(result[0], {"id": "a1"}) + + @override_settings(DEBUG=True) + def test_validate_does_not_crash_on_listfield_child(self): + """ListField(child=CharField()) in serializer doesn't crash validation. + + ListField's child is a plain Field, not a Serializer — the validator + must not recurse into it as a nested schema. + """ + Ser = make_serializer( + id=serializers.CharField(), + book_titles=serializers.ListField(child=serializers.CharField()), + ) + + class V(BaseValuesViewset, ListModelMixin): + queryset = Author.objects.none() + serializer_class = Ser + deferred_fields = ("book_titles",) + + def consolidate(self, items, queryset): + for item in items: + item["book_titles"] = ["B1", "B2"] + return items + + result = _serialize(V(), [{"id": "a1"}]) + self.assertEqual(result[0]["book_titles"], ["B1", "B2"]) + + @override_settings(DEBUG=False) + def test_validation_skipped_when_debug_false(self): + """DEBUG=False — drifting output passes silently.""" + Ser = make_serializer(id=serializers.CharField(), name=serializers.CharField()) + + class V(BaseValuesViewset, ListModelMixin): + queryset = Author.objects.none() + serializer_class = Ser + + def consolidate(self, items, queryset): + for item in items: + item["extra"] = "ignored" + return items + + result = _serialize(V(), [{"id": "a1", "name": "Alice"}]) + self.assertEqual(result[0]["extra"], "ignored") + + @override_settings(DEBUG=True) + def test_scalar_many_passes_validation(self): + """Scalar-many fields should not trip DEBUG validation (flat list, not dict).""" + viewset = make_viewset( + id=serializers.CharField(), + book_titles=serializers.CharField(source="books.title"), + ) + result = _serialize( + viewset, + [ + {"id": "a1", "books__title": "B1"}, + {"id": "a2", "books__title": None}, + ], + ) + self.assertEqual(len(result), 2) + + def test_missing_nested_pk_raises_descriptive_error(self): + """Missing nested-pk key in a row raises with field + key identification.""" + viewset = author_books_viewset() + viewset.__class__._joined_many = (("books", "nonexistent_pk"),) + + flat_items = [ + {"id": "a1", "books__id": "b1", "books__title": "B1"}, + {"id": "a1", "books__id": "b2", "books__title": "B2"}, + ] + with self.assertRaises(KeyError) as ctx: + viewset.serialize(create_mock_queryset(flat_items)) + msg = str(ctx.exception) + self.assertIn("books", msg) + self.assertIn("nonexistent_pk", msg) + self.assertIn("_auto_consolidate", msg) + + def test_missing_source_key_raises_key_error(self): + """A field_map entry pointing at a source absent from the row fails fast. + + Misconfigured mappings raise KeyError during serialize() rather than + silently producing None — so a typo in ``field_map`` surfaces at the + first request rather than propagating as bad output. + """ + + class V(BaseValuesViewset, ListModelMixin): + queryset = Author.objects.none() + values = ("id",) + field_map = {"display_name": "nonexistent"} + + with self.assertRaises(KeyError): + _serialize(V(), [{"id": "a1"}]) + + +class TestAuxiliaryAPIs(TestCase): + """Surfaces beyond ``serialize()``: nested serializer lookup, + separate-queryset serialization (``serialize_queryset``), deferred-field + filtering, and lazy queryset resolution when no class-level ``queryset`` + is defined. + """ + + def test_get_nested_serializer_direct_path(self): + """Direct nested path returns the corresponding child serializer.""" + viewset = author_books_viewset() + nested = viewset.get_nested_serializer("books") + self.assertIn("id", nested.fields) + self.assertIn("title", nested.fields) + + def test_get_nested_serializer_doubly_nested_path(self): + """Dotted path (e.g., 'books__tags') resolves through deferred nested serializers.""" + TagSer = make_serializer( + model=Tag, + id=serializers.CharField(), + name=serializers.CharField(), + ) + BookSer = make_serializer( + model=Book, + id=serializers.CharField(), + tags=TagSer(many=True), + ) + viewset = make_viewset( + serializer_class=make_serializer( + id=serializers.CharField(), books=BookSer(many=True) + ), + deferred_fields=("books",), + ) + nested = viewset.get_nested_serializer("books__tags") + self.assertEqual(set(nested.fields), {"id", "name"}) + + def test_get_nested_serializer_invalid_path_raises(self): + """A path that doesn't resolve raises KeyError.""" + viewset = make_viewset(id=serializers.CharField()) + with self.assertRaises(KeyError): + viewset.get_nested_serializer("nonexistent") + + def test_serialize_queryset_returns_list_of_items(self): + """serialize_queryset without group_by returns a flat list.""" + viewset = author_books_viewset(deferred=True) + qs = MagicMock() + qs.values.return_value = [ + {"id": "b1", "title": "B1"}, + {"id": "b2", "title": "B2"}, + ] + result = viewset.serialize_queryset(qs, "books") + self.assertEqual(len(result), 2) + self.assertEqual(result[0]["title"], "B1") + + def test_deferred_field_excluded_from_values_call(self): + """Fields in deferred_fields are not requested in the main values() call.""" + viewset = author_books_viewset(deferred=True, email=serializers.CharField()) + mock_qs = create_mock_queryset([{"id": "a1", "email": "alice@example.com"}]) + result = viewset.serialize(mock_qs) + values_args = mock_qs.values.call_args[0] + self.assertIn("id", values_args) + self.assertIn("email", values_args) + self.assertNotIn("books__id", values_args) + self.assertNotIn("books__title", values_args) + self.assertNotIn("books", result[0]) + + def test_auto_consolidate_works_without_class_level_queryset(self): + """Viewsets using get_queryset() (no class queryset) still resolve PK lazily.""" + Ser = make_serializer( + id=serializers.CharField(), + books=make_nested( + model=Book, + many=True, + id=serializers.CharField(), + title=serializers.CharField(), + ), + ) + + class V(BaseValuesViewset, ListModelMixin): + serializer_class = Ser + + def get_queryset(self): + return Author.objects.none() + + flat_items = [ + {"id": "a1", "books__id": "b1", "books__title": "B1"}, + {"id": "a1", "books__id": "b2", "books__title": "B2"}, + ] + result = _serialize(V(), flat_items) + self.assertEqual(len(result[0]["books"]), 2) diff --git a/kolibri/core/test/test_app/models.py b/kolibri/core/test/test_app/models.py index 3196f473b84..aeb74806a21 100644 --- a/kolibri/core/test/test_app/models.py +++ b/kolibri/core/test/test_app/models.py @@ -1,9 +1,11 @@ import datetime +import uuid from django.db import models from django.utils import timezone from kolibri.core.fields import DateTimeTzField +from kolibri.core.fields import JSONField def aware_datetime(): @@ -40,3 +42,75 @@ class Membership(models.Model): class DateTimeTzModel(models.Model): timestamp = DateTimeTzField(null=True) default_timestamp = DateTimeTzField(default=aware_datetime) + + +# Synthetic relation zoo for test_api.py. +# +# Author is the primary outer model (UUID pk + scalar fields covering the +# types exercised by type-inference tests). The surrounding models provide +# every relation shape the introspection code distinguishes: +# +# - Publisher: FK target (nullable, for flat FK-traversal + null FK tests) +# - Profile: OneToOne to Author (single-nested + reverse 1:1) +# - Book: reverse FK many (via Author.books) + direct M2M (Book.tags) +# - Tag: M2M target (reverse M2M via Tag.books) +# - Enrollment: through-model for Author↔Classroom M2M (Author.classrooms) + + +class Publisher(models.Model): + name = models.CharField(max_length=128, default="") + + +class Author(models.Model): + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + name = models.CharField(max_length=128, default="") + email = models.CharField(max_length=128, default="") + is_active = models.BooleanField(default=True) + status = models.CharField( + max_length=16, + choices=[("active", "Active"), ("retired", "Retired"), ("", "")], + default="active", + ) + publisher = models.ForeignKey( + Publisher, + related_name="authors", + null=True, + blank=True, + on_delete=models.SET_NULL, + ) + classrooms = models.ManyToManyField( + Classroom, + through="Enrollment", + related_name="enrolled_authors", + ) + metadata = JSONField(null=True, blank=True, default=dict) + + +class Profile(models.Model): + author = models.OneToOneField( + Author, + related_name="profile", + on_delete=models.CASCADE, + ) + bio = models.CharField(max_length=255, default="", blank=True) + is_verified = models.BooleanField(default=False) + + +class Tag(models.Model): + name = models.CharField(max_length=64, default="") + + +class Book(models.Model): + author = models.ForeignKey(Author, related_name="books", on_delete=models.CASCADE) + title = models.CharField(max_length=128, default="") + description = models.CharField(max_length=255, null=True, blank=True) + tags = models.ManyToManyField(Tag, related_name="books") + + +class Enrollment(models.Model): + author = models.ForeignKey( + Author, related_name="enrollments", on_delete=models.CASCADE + ) + classroom = models.ForeignKey( + Classroom, related_name="author_enrollments", on_delete=models.CASCADE + ) diff --git a/kolibri/core/utils/serializer_introspection.py b/kolibri/core/utils/serializer_introspection.py new file mode 100644 index 00000000000..130cb915e47 --- /dev/null +++ b/kolibri/core/utils/serializer_introspection.py @@ -0,0 +1,850 @@ +from abc import ABCMeta +from abc import abstractmethod +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import Union + +from django.core.exceptions import FieldDoesNotExist +from django.db.models import Field +from django.db.models import Model +from django.db.models.fields.related import ForeignObjectRel +from rest_framework import serializers as drf_serializers +from rest_framework.fields import empty +from rest_framework.fields import Field as DrfField +from rest_framework.serializers import ModelSerializer +from rest_framework.serializers import SerializerMethodField +from rest_framework.utils.field_mapping import ClassLookupDict + + +class ValuesMethodField(SerializerMethodField): + """ + ``SerializerMethodField`` variant for ``ValuesViewset``: declares the + row columns the bound method reads. + + ``sources`` paths use DRF dot notation (``"dataset.id"``) and are + translated internally to ORM double-underscore (``"dataset__id"``) for + the ``values()`` query. The bound method receives a proxy over those + declared sources: + + class MySerializer(ModelSerializer): + status = ValuesMethodField(sources=("transfer_status", "last_synced")) + + def get_status(self, obj): + if obj.transfer_status in IN_PROGRESS: + return SYNCING + ... + + Sources that are not also declared as serializer output fields are + fetched but stripped from the final row. + """ + + def __init__(self, *, sources=(), method_name=None, **kwargs): + super().__init__(method_name=method_name, **kwargs) + self.sources = tuple(sources) + + +class _SourcesProxy: + """ + Attribute proxy over a raw ``values()`` row, scoped to the paths + declared in a ``ValuesMethodField``'s ``sources``. + + Supports dotted traversal matching declared paths: for + ``sources=("publisher.name",)``, ``obj.publisher.name`` walks down + and returns ``raw["publisher__name"]``. + + Attribute access outside the declared set raises ``AttributeError`` + with the requested name and the declared sources in the message, so + the boundary is discoverable. + """ + + __slots__ = ("_raw", "_sources", "_prefix") + + def __init__(self, raw, sources, prefix=""): + self._raw = raw + self._sources = sources + self._prefix = prefix + + def __getattr__(self, name): + # Let Python's internal machinery (repr, copy, pickle, etc.) see a + # plain AttributeError for dunder and private attribute lookups + # rather than the ValuesMethodField-framed message below. + if name.startswith("_"): + raise AttributeError(name) + path = "{}__{}".format(self._prefix, name) if self._prefix else name + if path in self._sources: + return self._raw[path] + sep = path + "__" + if any(source.startswith(sep) for source in self._sources): + return _SourcesProxy(self._raw, self._sources, path) + declared = sorted(source.replace("__", ".") for source in self._sources) + raise AttributeError( + "{!r} not declared — ValuesMethodField exposes sources only: " + "{}. Add to sources=, or inline the logic.".format(name, declared) + ) + + +# A row produced by ``queryset.values()`` — dict of flat path → value. +Row = Dict[str, Any] + + +class SourceFieldEntry: + """ + Field map entry for a single-source rename, optionally transformed. + + ``to_repr=None`` means a plain rename. When ``to_repr`` is set and the + raw value is ``None``, ``default`` is substituted — mirrors DRF's + get_attribute fallback for missing relations. + """ + + __slots__ = ("source", "to_repr", "default") + + def __init__( + self, + source: str, + to_repr: Optional[Callable] = None, + default: Any = None, + ): + self.source = source + self.to_repr = to_repr + self.default = default + + def _represent(self, raw: Any) -> Any: + if self.to_repr is None: + return raw + if raw is None: + return self.default + return self.to_repr(raw) + + def extract(self, row: Row) -> Any: + """Produce the output value, reading ``row`` without mutating it.""" + return self._represent(row.get(self.source)) + + def apply(self, key: str, row: Row) -> None: + """Pop the source from ``row`` and write the value under ``key``.""" + row[key] = self._represent(row.pop(self.source)) + + +class CallableFieldEntry: + """ + Field map entry computed by a callable over the whole raw row: nested + extractors, ``ValuesMethodField`` invokers, and legacy user-written + callables. + + ``source`` and ``to_repr`` are ``None`` at the class level: the value + doesn't map to a single DB column, so consumers that introspect sources + (ordering filter, serializer generation) skip these entries. + """ + + __slots__ = ("func",) + + source = None + to_repr = None + + def __init__(self, func: Callable[[Row], Any]): + self.func = func + + def extract(self, row: Row) -> Any: + """Produce the output value, reading ``row`` without mutating it.""" + return self.func(row) + + def apply(self, key: str, row: Row) -> None: + """Write the value under ``key``; legacy callables may mutate ``row`` + themselves (e.g. popping consumed sources).""" + row[key] = self.func(row) + + +# User-written legacy ``field_map = {"x": "source"}`` / ``{"x": callable}`` +# entries are normalized to these classes at ingest by ``normalize_field_map``. +FieldMapEntry = Union[SourceFieldEntry, CallableFieldEntry] +FieldMap = Dict[str, FieldMapEntry] + + +class _BaseFieldMap(dict, metaclass=ABCMeta): + """ + Base field map: a dict of output field name → entry producing its value, + owning row mapping and source introspection over those entries. + """ + + @abstractmethod + def map_row(self, row: Row) -> Row: + """Produce the output row for a raw ``values()`` row.""" + + def source_map(self) -> Dict[str, str]: + """ + ``{target: source}`` pairs for fields backed by a single DB column. + + Callable entries are excluded (``source`` is ``None``) — they may do + arbitrary transformations and don't map cleanly to a single column. + """ + return { + key: entry.source for key, entry in self.items() if entry.source is not None + } + + def plain_renames(self) -> Dict[str, str]: + """ + ``{source: target}`` pairs for plain renames (source set, no + ``to_repr``), so raw values can be exposed under the declared name. + """ + return { + entry.source: key + for key, entry in self.items() + if entry.source is not None and entry.to_repr is None + } + + +class _FieldMap(_BaseFieldMap): + """ + A field map built by serializer introspection, covering every declared + output field (including trivial passthroughs). + + ``map_row`` builds a fresh output dict, reading raw values without + mutating the input row, so the result contains exactly the declared + outputs — method-field sources pulled in for invocation never leak into + output because they're not field_map keys. + """ + + def map_row(self, row: Row) -> Row: + return {key: entry.extract(row) for key, entry in self.items()} + + +class _LegacyFieldMap(_BaseFieldMap): + """ + A field map normalized from a legacy explicit ``values`` / ``field_map`` + viewset definition. + + ``map_row`` mutates the row in place, popping sources and writing + targets. Required for back-compat with viewsets that rely on + ``values()`` keys passing through to the output when not claimed by a + ``field_map`` entry. + """ + + def map_row(self, row: Row) -> Row: + for key, entry in self.items(): + entry.apply(key, row) + return row + + +# A ``joined_many`` entry: (field_name, nested_pk_output_name) per many=True +# nested serializer, used by _auto_consolidate for dedup. nested_pk_name is +# None for scalar-many fields (dedup by value itself). +JoinedMany = Tuple[Tuple[str, Optional[str]], ...] + +# Nested cache entries keyed by dotted path, built during introspection so +# deferred nested serializers can be serialized separately via the +# ``serialize_queryset`` API. +NestedCacheEntry = Tuple[List[str], _FieldMap, JoinedMany] +NestedCache = Dict[str, NestedCacheEntry] + +# Return shape of the top-level introspection call. +IntrospectionResult = Tuple[List[str], _FieldMap, JoinedMany, NestedCache] + + +def _get_source_path(field: DrfField, field_name: str, prefix: str) -> Optional[str]: + """Extract the source path for a serializer field. + + Converts DRF dot-notation sources (e.g. "parent.name") to Django ORM + double-underscore notation (e.g. "parent__name") for use in values() queries. + """ + source = getattr(field, "source", None) + if source == "*" or isinstance(source, (list, tuple)): + return None + source_path = source if source else field_name + # DRF uses dot notation for relationship traversal (e.g. "parent.name"), + # but Django ORM values() requires double-underscore notation ("parent__name"). + source_path = source_path.replace(".", "__") + return f"{prefix}{source_path}" if prefix else source_path + + +def _is_nested_model_serializer(field: DrfField) -> bool: + """Check if a field is a nested ``ModelSerializer`` (or ``ListSerializer`` + wrapping one). + + Plain ``Serializer`` subclasses (e.g. structural wrappers around a + ``JSONField``) are intentionally excluded — they have no ``Meta.model`` + to introspect and are handled by the regular-field path, where + ``field.to_representation`` runs on the raw value. + """ + if isinstance(field, drf_serializers.ListSerializer): + return isinstance(field.child, ModelSerializer) + return isinstance(field, ModelSerializer) + + +def _source_crosses_many_relation( + model: Optional[Type[Model]], source_path: str +) -> bool: + """ + Check whether a source path crosses a one-to-many or many-to-many relation. + + Used to detect fields like ``roles__kind`` where ``roles`` is a reverse FK, + so the values() query produces multiple rows that need list consolidation. + """ + if model is None: + return False + parts = source_path.split("__") + current_model: Type[Model] = model + for part in parts: + try: + field = current_model._meta.get_field(part) # type: ignore[attr-defined] + except FieldDoesNotExist: + return False + if getattr(field, "one_to_many", False) or getattr( + field, "many_to_many", False + ): + return True + related_model = getattr(field, "related_model", None) + if related_model is None: + return False + current_model = related_model + return False + + +def _get_model_field_for_source( + model: Optional[Type[Model]], source_path: str +) -> Optional[Union[Field, ForeignObjectRel]]: + """ + Walk a source path like 'user__profile__name' to get the final model field. + + Returns the final model field, or None if the path is invalid or doesn't + resolve to a concrete model field. ``_meta.get_field`` returns either a + ``models.Field`` subclass or a ``ForeignObjectRel`` for reverse accessors; + callers only need the ``related_model`` / ``choices`` attributes common + to both. + """ + if model is None: + return None + + parts = source_path.split("__") + current_model: Type[Model] = model + + for i, part in enumerate(parts): + try: + field = current_model._meta.get_field(part) # type: ignore[attr-defined] + except FieldDoesNotExist: + return None + + # If this is the last part, return the field + if i == len(parts) - 1: + return field + + # Otherwise, it must be a relation - get the related model + related_model = getattr(field, "related_model", None) + if related_model is None: + # Not a relation, but we expected more path segments + return None + current_model = related_model + return None + + +def _field_matches_inferred_type( + declared_field: DrfField, + source_path: str, + serializer_class: Type[ModelSerializer], + model: Optional[Type[Model]], +) -> bool: + """ + Check if a declared field matches what DRF's ModelSerializer would auto-generate. + + Returns True if the declared field type is exactly what ModelSerializer + would have inferred for the given model field, meaning we can skip calling + to_representation (it would be effectively a no-op for simple types). + """ + model_field = _get_model_field_for_source(model, source_path) + if model_field is None: + return False + + # For relation fields, check against the serializer's related field class + # (typically PrimaryKeyRelatedField). values() returns the raw FK value + # (e.g. a UUID string), and PrimaryKeyRelatedField.to_representation + # expects a model instance — so when they match, we skip to_representation + # and pass the raw value through, which is already the PK. + # No default check needed here: FK columns always have a value in output; + # any default on the serializer field is for input (deserialization) only. + if getattr(model_field, "related_model", None) is not None: + return type(declared_field) is serializer_class.serializer_related_field + + # Fields with an explicit default need the 3-tuple path so that + # None values (e.g. from a LEFT JOIN miss) get the default substituted. + # The simple rename path doesn't handle None → default. + if declared_field.default is not empty: + return False + + # Special case: fields with choices become ChoiceField + if getattr(model_field, "choices", None): + inferred_class = serializer_class.serializer_choice_field + else: + field_mapping = ClassLookupDict(serializer_class.serializer_field_mapping) + try: + inferred_class = field_mapping[model_field] + except KeyError: + return False + + # Exact class match only - subclasses may override to_representation + return type(declared_field) is inferred_class + + +def _resolve_nested_null_check_key( + child: ModelSerializer, source: str +) -> Optional[str]: + """ + Resolve the prefixed PK key for a nested serializer's model. + + Used for LEFT JOIN null detection: if the PK is null, the entire nested + object is null (no related row exists). + + Returns the prefixed PK key (e.g., 'child_set__id') or None if the nested + serializer has no model. + """ + child_model = getattr(getattr(type(child), "Meta", None), "model", None) + if child_model is None: + return None + return "{}__{}".format(source, child_model._meta.pk.name) + + +def _resolve_nested_pk_output_name( + null_check_key: str, prefix: str, child_field_map: _FieldMap +) -> str: + """ + Find the output field name of the nested PK, for deduplication in + _auto_consolidate. The null_check_key is the prefixed source name + (e.g. 'child_set__id'); we need to find what child_field_map renames + it to (e.g. 'id' -> 'identifier'), or fall back to the source name. + """ + pk_source = null_check_key[len(prefix) :] + for out_name, map_val in child_field_map.items(): + if map_val.source == pk_source: + return out_name + return pk_source + + +def _make_nested_extractor( + null_check_key: str, + value_pairs: Tuple[Tuple[str, str], ...], + child_field_map: _FieldMap, +) -> Callable[[Row], Optional[Row]]: + """ + Create a field_map callable that extracts a nested item from a raw row. + + Reads all prefixed keys from the raw row without mutating it, builds a + child dict keyed by unprefixed source names, then applies the child's + field_map for renames and transforms. Returns the nested dict or None + if the FK is null. + """ + + def extract(row: Row) -> Optional[Row]: + if row.get(null_check_key) is None: + return None + + nested_item: Row = {} + for source_name, prefixed_key in value_pairs: + nested_item[source_name] = row.get(prefixed_key) + if child_field_map: + nested_item = child_field_map.map_row(nested_item) + return nested_item + + return extract + + +def _deep_nesting_error( + field_name: str, + child: ModelSerializer, + deferred_in_child: Set[str], +) -> Optional[str]: + """Return a deep-nesting error message if ``child`` has further nested + serializers (excluding any deferred at the child level), else ``None``. + """ + if any( + _is_nested_model_serializer(f) and gn not in deferred_in_child + for gn, f in cast(Dict[str, DrfField], child.fields).items() + ): + return ( + "Nested serializer field '{}' contains further nested " + "serializers. Deep nesting is not supported for " + "auto-consolidation. Use deferred_fields to fetch '{}' " + "separately and implement a custom consolidate() " + "method.".format(field_name, field_name) + ) + return None + + +def _check_serializer_constraints( + serializer: ModelSerializer, deferred_fields: Tuple[str, ...] = () +) -> List[str]: + """ + DEBUG-only preflight: validate serializer structure before introspection. + + Returns a list of error messages (empty if no issues). The caller is + responsible for raising; collecting lets a single run surface every + violation across the (possibly recursive) tree at once. + + Checks at this level (and recursively into deferred nested children): + - No deep nesting (nested serializers within nested serializers, + unless the inner one is itself deferred via a nested-path entry) + - At most one many=True nested serializer (multiple cause cartesian products) + """ + unnested_deferred_fields = {p for p in deferred_fields if "__" not in p} + nested_deferred_by_child: Dict[str, List[str]] = {} + for path in deferred_fields: + if "__" in path: + head, tail = path.split("__", 1) + nested_deferred_by_child.setdefault(head, []).append(tail) + + errors: List[str] = [] + many_fields: List[str] = [] + # cast: ``fields`` is a ``cached_property`` that Pyright can't resolve + # to the underlying BindingDict. + for field_name, field in cast(Dict[str, DrfField], serializer.fields).items(): + if not _is_nested_model_serializer(field): + continue + child = cast( + ModelSerializer, + field.child if isinstance(field, drf_serializers.ListSerializer) else field, + ) + if field_name in unnested_deferred_fields: + # Recurse into deferred nested so deep-level violations surface too. + nested_errors = _check_serializer_constraints( + child, tuple(nested_deferred_by_child.get(field_name, ())) + ) + errors.extend("{}: {}".format(field_name, e) for e in nested_errors) + continue + if getattr(field, "write_only", False): + continue + unnested_deferred_in_child = { + p for p in nested_deferred_by_child.get(field_name, ()) if "__" not in p + } + deep_err = _deep_nesting_error(field_name, child, unnested_deferred_in_child) + if deep_err is not None: + errors.append(deep_err) + if isinstance(field, drf_serializers.ListSerializer): + many_fields.append(field_name) + if len(many_fields) > 1: + field_names = ", ".join(sorted(many_fields)) + errors.append( + "Multiple many=True nested serializers cannot be joined in a " + "single query (cartesian product). Found: {}. Move all but one " + "to deferred_fields and handle them in consolidate().".format(field_names) + ) + return errors + + +def _introspect_regular_field( + field_name: str, + field: DrfField, + declared_fields: Dict[str, DrfField], + serializer_class: Type[ModelSerializer], + model: Optional[Type[Model]], +) -> Tuple[Union[str, Tuple[str, ...], None], Optional[FieldMapEntry]]: + """ + Introspect a regular (non-nested) serializer field. + + Returns ``(source_path, entry)``: + + - ``source_path``: value(s) to fetch via ``values()``. ``None`` to skip + the field entirely (e.g. ``source='*'``). For a ``ValuesMethodField``, + a tuple of the declared source paths (caller extends ``values`` with + all of them). + - ``entry``: the field_map entry producing the output value. A + ``CallableFieldEntry`` wrapping an invoker closure for method fields + (wraps the row dict in a ``_SourcesProxy`` and calls the bound + method); a ``SourceFieldEntry`` otherwise, with ``to_repr`` set when + the field transforms its raw value. + """ + if isinstance(field, ValuesMethodField): + source_paths = tuple(source.replace(".", "__") for source in field.sources) + # Capture the bound ``to_representation`` once at introspection time + # so subclass overrides are honoured (and we skip a per-row attribute + # lookup on the field instance). + to_representation = field.to_representation + + def invoke(row: Row) -> Any: + proxy = _SourcesProxy(row, source_paths) + return to_representation(proxy) + + return source_paths, CallableFieldEntry(invoke) + + if isinstance(field, SerializerMethodField): + raise TypeError( + "{}.{}: ValuesViewset does not support plain " + "SerializerMethodField. Use ValuesMethodField(sources=(...)) " + "to declare which row columns the method reads, or a typed " + "field with source= for simple traversals.".format( + serializer_class.__name__, field_name + ) + ) + + source_path = _get_source_path(field, field_name, "") + if source_path is None: + return None, None + + if field_name in declared_fields and not _field_matches_inferred_type( + field, source_path, serializer_class, model + ): + default = field.default if field.default is not empty else None + return source_path, SourceFieldEntry( + source_path, field.to_representation, default + ) + # Trivial passthrough (source == name, matching type) still emits an + # entry so the field_map is a complete spec of output fields. + return source_path, SourceFieldEntry(source_path) + + +def _introspect_deferred_nested( + field_name: str, + child: ModelSerializer, + nested_deferred: Tuple[str, ...] = (), +) -> NestedCache: + """ + Introspect a deferred nested serializer (one listed in deferred_fields). + + Introspects as top-level so the child's own nested serializers are + processed, but does not add values to the parent query. + + Returns a dict of nested_cache entries keyed by path. + """ + ( + child_values, + child_field_map, + child_joined_many, + child_nested, + ) = _introspect_serializer_fields(child, deferred_fields=nested_deferred) + entries: NestedCache = { + field_name: (child_values, child_field_map, child_joined_many), + } + for sub_path, sub_info in child_nested.items(): + entries[f"{field_name}__{sub_path}"] = sub_info + return entries + + +def _introspect_joined_nested( + field_name: str, + field: DrfField, + child: ModelSerializer, + is_many: bool, + nested_deferred: Tuple[str, ...] = (), +) -> Tuple[List[str], FieldMap, NestedCache, List[Tuple[str, Optional[str]]]]: + """ + Introspect a joined (non-deferred) nested serializer. + + Returns ``(prefixed_values, field_map_updates, nested_entries, + joined_many_entries)``. ``joined_many_entries`` is empty for single FK + fields or a one-item list of ``(field_name, nested_pk_name)`` for + many=True fields. + """ + (child_values, child_field_map, _, child_nested,) = _introspect_serializer_fields( + child, deferred_fields=nested_deferred, _is_nested=True + ) + nested_entries: NestedCache = { + field_name: (child_values, child_field_map, ()), + } + for sub_path, sub_info in child_nested.items(): + nested_entries[f"{field_name}__{sub_path}"] = sub_info + + # Prefix child values for the parent's values() call + source = getattr(field, "source", None) or field_name + prefix = f"{source}__" + prefixed_values = [f"{prefix}{v}" for v in child_values] + + extractor: Optional[Callable[[Row], Optional[Row]]] = None + joined_many_entry: Optional[Tuple[str, Optional[str]]] = None + + if prefixed_values: + value_pairs: Tuple[Tuple[str, str], ...] = tuple( + zip(child_values, prefixed_values) + ) + + null_check_key = _resolve_nested_null_check_key(child, source) + if null_check_key is None: + null_check_key = prefixed_values[0] + + extractor = _make_nested_extractor( + null_check_key, + value_pairs, + child_field_map, + ) + + if is_many: + nested_pk_name = _resolve_nested_pk_output_name( + null_check_key, + prefix, + child_field_map, + ) + joined_many_entry = (field_name, nested_pk_name) + + field_map_updates: FieldMap = {} + if extractor is not None: + field_map_updates[field_name] = CallableFieldEntry(extractor) + joined_many_entries: List[Tuple[str, Optional[str]]] = [] + if joined_many_entry is not None: + joined_many_entries.append(joined_many_entry) + + return prefixed_values, field_map_updates, nested_entries, joined_many_entries + + +def _introspect_serializer_fields( + serializer: ModelSerializer, + deferred_fields: Tuple[str, ...] = (), + _is_nested: bool = False, +) -> IntrospectionResult: # noqa: C901 + """ + Introspect a serializer to derive values tuple and field transformations. + + Args: + serializer: The DRF serializer to introspect + deferred_fields: Field names that should be fetched separately (not joined) + _is_nested: Internal flag; True when recursing into a nested serializer + so that further nested fields are skipped. + + Returns: + - values: Fields to fetch via queryset.values() + - field_map: Transforms for fields to map from values call + - joined_many: many=True nested fields as (field_name, nested_pk_name) for dedup. + nested_pk_name is None for scalar many-relation fields (dedup by value). + - nested_cache: path-keyed dict of (values, field_map, joined_many) for + all nested serializers encountered (deferred and joined alike) + """ + values: List[str] = [] + field_map = _FieldMap() + joined_many_fields: List[Tuple[str, Optional[str]]] = [] + nested_cache: NestedCache = {} + declared_fields: Dict[str, DrfField] = getattr(serializer, "_declared_fields", {}) + + # Get serializer class and model for type inference + serializer_class = type(serializer) + model: Optional[Type[Model]] = getattr( + getattr(serializer_class, "Meta", None), "model", None + ) + + serializer_fields: Dict[str, DrfField] = cast( + Dict[str, DrfField], serializer.fields + ) + for field_name, field in serializer_fields.items(): + if getattr(field, "write_only", False): + continue + + if _is_nested_model_serializer(field): + # --- Nested ModelSerializer fields --- + is_many = isinstance(field, drf_serializers.ListSerializer) + child = cast(ModelSerializer, field.child if is_many else field) + nested_deferred = tuple( + p.split("__", 1)[1] + for p in deferred_fields + if "__" in p and p.split("__", 1)[0] == field_name + ) + + if field_name in deferred_fields: + nested_cache.update( + _introspect_deferred_nested(field_name, child, nested_deferred) + ) + continue + + # Skip further joining when already inside a nested serializer + # (deep JOIN); deferred grand-children are handled above. + if _is_nested: + continue + + # Joined nested field — introspect and build extractor + ( + prefixed, + fm_updates, + nested_entries, + many_entries, + ) = _introspect_joined_nested( + field_name, field, child, is_many, nested_deferred + ) + values.extend(prefixed) + field_map.update(fm_updates) + nested_cache.update(nested_entries) + joined_many_fields.extend(many_entries) + continue + + if field_name in deferred_fields: + continue + + source_path, entry = _introspect_regular_field( + field_name, field, declared_fields, serializer_class, model + ) + if source_path is None: + continue + + # ValuesMethodField returns a tuple of source paths and an + # invoker entry. Extend values() with all of them. + if isinstance(source_path, tuple): + values.extend(source_path) + field_map[field_name] = entry + continue + + # Detect fields whose source crosses a one-to-many relation + # (e.g. roles__kind where roles is a reverse FK). These produce + # multiple rows in values() and need list consolidation — a plain + # rename here, consolidation collects the list. + if _source_crosses_many_relation(model, source_path): + entry = SourceFieldEntry(source_path) + joined_many_fields.append((field_name, None)) + + values.append(source_path) + field_map[field_name] = entry + continue + + # Dedupe values() paths: method-field sources can overlap with declared + # field sources, and the same path only needs to be fetched once. + # Sorted so the column order (and hence the generated SQL) is + # consistent across runs — set iteration order varies per process. + values = sorted(set(values)) + + return values, field_map, tuple(joined_many_fields), nested_cache + + +def derive_values_from_serializer( + serializer: ModelSerializer, + deferred_fields: Tuple[str, ...] = (), + *, + check_constraints: bool = False, +) -> IntrospectionResult: + """ + Public entry point: derive values tuple and field mappings from a DRF serializer. + + Args: + serializer: The DRF serializer to introspect + deferred_fields: Field names that should be fetched separately (not joined) + check_constraints: When True, runs DEBUG preflight checks before introspection + + Returns: + - values: Fields to fetch via queryset.values() + - field_map: Transforms for fields to map from values call + - joined_many: many=True nested fields as (field_name, nested_pk_name) for dedup + - nested_cache: path-keyed dict of (values, field_map, joined_many) for + all nested serializers encountered during introspection + """ + if check_constraints: + errors = _check_serializer_constraints(serializer, deferred_fields) + if errors: + raise TypeError("\n".join(errors)) + return _introspect_serializer_fields(serializer, deferred_fields=deferred_fields) + + +def normalize_field_map(field_map: Dict[str, Any]) -> _LegacyFieldMap: + """ + Normalize a user-written legacy field_map to canonical entry objects. + + Converts str shorthand entries (``{"name": "source"}``) to + ``SourceFieldEntry`` plain renames and bare callables to + ``CallableFieldEntry``. Returns a new ``_LegacyFieldMap`` with + mutate-in-place ``map_row`` semantics; the input is not mutated. + """ + return _LegacyFieldMap( + ( + key, + ( + SourceFieldEntry(value) + if isinstance(value, str) + else CallableFieldEntry(value) + ), + ) + for key, value in field_map.items() + )