From d6ca95f4c7ae1d1bdbd0a72b529da5e0ffd1e9f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Escolano?= Date: Fri, 5 Jan 2024 12:52:31 +0100 Subject: [PATCH 1/2] Group queries for PrimaryKeyRelatedField many serializers --- rest_framework/relations.py | 26 ++++++++++++++++++-------- tests/test_relations_pk.py | 7 +++++++ tests/utils.py | 22 +++++++++++++++++----- 3 files changed, 42 insertions(+), 13 deletions(-) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 4409bce77c..fe0acfc09f 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -249,18 +249,25 @@ def __init__(self, **kwargs): def use_pk_only_optimization(self): return True - def to_internal_value(self, data): + def to_many_internal_value(self, data): if self.pk_field is not None: - data = self.pk_field.to_internal_value(data) + data = [self.pk_field.to_internal_value(item) for item in data] queryset = self.get_queryset() try: - if isinstance(data, bool): - raise TypeError - return queryset.get(pk=data) - except ObjectDoesNotExist: - self.fail('does_not_exist', pk_value=data) + for item in data: + if isinstance(item, bool): + raise TypeError + result = queryset.filter(pk__in=data).all() + pks = [item.pk for item in result] + for item in data: + if item not in pks: + self.fail('does_not_exist', pk_value=item) + return result except (TypeError, ValueError): - self.fail('incorrect_type', data_type=type(data).__name__) + self.fail('incorrect_type', data_type=type(data[0]).__name__) + + def to_internal_value(self, data): + return self.to_many_internal_value([data])[0] def to_representation(self, value): if self.pk_field is not None: @@ -524,6 +531,9 @@ def to_internal_value(self, data): if not self.allow_empty and len(data) == 0: self.fail('empty') + if hasattr(self.child_relation, "to_many_internal_value"): + return self.child_relation.to_many_internal_value(data) + return [ self.child_relation.to_internal_value(item) for item in data diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py index 7a4878a2bf..260229a91a 100644 --- a/tests/test_relations_pk.py +++ b/tests/test_relations_pk.py @@ -189,6 +189,13 @@ def test_many_to_many_create(self): ] assert serializer.data == expected + def test_many_to_many_grouped_queries(self): + data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]} + serializer = ManyToManySourceSerializer(data=data) + # Only one query should be executed even with several targets + with self.assertNumQueries(1): + assert serializer.is_valid() + def test_many_to_many_unsaved(self): source = ManyToManySource(name='source-unsaved') diff --git a/tests/utils.py b/tests/utils.py index 4ceb353099..799d98ec7d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -26,13 +26,25 @@ def __getitem__(self, val): return self.items[val] def get(self, **lookup): - for item in self.items: + result = self.filter(**lookup).all() + if len(result) > 0: + return result[0] + raise ObjectDoesNotExist() + + def all(self): + return list(self.items) + + def filter(self, **lookup): + return MockQueryset( + item + for item in self.items if all([ - attrgetter(key.replace('__', '.'))(item) == value + attrgetter(key.replace("__in", "").replace('__', '.'))(item) in value + if key.endswith("__in") + else attrgetter(key.replace('__', '.'))(item) == value for key, value in lookup.items() - ]): - return item - raise ObjectDoesNotExist() + ]) + ) class BadType: From b72027fcdbd3c2e7c32ade4abd85fb53512c18d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Escolano?= Date: Fri, 5 Jan 2024 13:19:31 +0100 Subject: [PATCH 2/2] Group queries for SlugRelatedField many serializers --- rest_framework/relations.py | 21 ++++++++++++++++----- tests/test_relations_slug.py | 6 ++++++ tests/utils.py | 10 ++++++++-- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index fe0acfc09f..d4b59c46d3 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -4,7 +4,7 @@ from urllib import parse from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist -from django.db.models import Manager +from django.db.models import F, Manager from django.db.models.query import QuerySet from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve from django.utils.encoding import smart_str, uri_to_iri @@ -458,15 +458,26 @@ def __init__(self, slug_field=None, **kwargs): self.slug_field = slug_field super().__init__(**kwargs) - def to_internal_value(self, data): + def to_many_internal_value(self, data): queryset = self.get_queryset() try: - return queryset.get(**{self.slug_field: data}) - except ObjectDoesNotExist: - self.fail('does_not_exist', slug_name=self.slug_field, value=smart_str(data)) + result = ( + queryset + .filter(**{self.slug_field + "__in": data}) + .annotate(_slug_field_value=F(self.slug_field)) + .all() + ) + slugs = [item._slug_field_value for item in result] + for item in data: + if item not in slugs: + self.fail('does_not_exist', slug_name=self.slug_field, value=smart_str(item)) + return result except (TypeError, ValueError): self.fail('invalid') + def to_internal_value(self, data): + return self.to_many_internal_value([data])[0] + def to_representation(self, obj): slug = self.slug_field if "__" in slug: diff --git a/tests/test_relations_slug.py b/tests/test_relations_slug.py index 0b9ca79d3d..c0343cb994 100644 --- a/tests/test_relations_slug.py +++ b/tests/test_relations_slug.py @@ -174,6 +174,12 @@ def test_reverse_foreign_key_create(self): ] assert serializer.data == expected + def test_reverse_foreign_key_create_grouped_queries(self): + data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']} + serializer = ForeignKeyTargetSerializer(data=data) + with self.assertNumQueries(1): + assert serializer.is_valid() + def test_foreign_key_update_with_invalid_null(self): data = {'id': 1, 'name': 'source-1', 'target': None} instance = ForeignKeySource.objects.get(pk=1) diff --git a/tests/utils.py b/tests/utils.py index 799d98ec7d..b7c2813a6c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -35,7 +35,7 @@ def all(self): return list(self.items) def filter(self, **lookup): - return MockQueryset( + return MockQueryset([ item for item in self.items if all([ @@ -44,7 +44,13 @@ def filter(self, **lookup): else attrgetter(key.replace('__', '.'))(item) == value for key, value in lookup.items() ]) - ) + ]) + + def annotate(self, **kwargs): + for key, value in kwargs.items(): + for item in self.items: + setattr(item, key, attrgetter(value.name.replace('__', '.'))(item)) + return self class BadType: