Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce queries for supported ManyRelatedField #9211

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions rest_framework/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from urllib import parse

from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
from django.db.models import Manager
from django.db.models import F, Manager
from django.db.models.query import QuerySet
from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve
from django.utils.encoding import smart_str, uri_to_iri
Expand Down Expand Up @@ -249,18 +249,25 @@ def __init__(self, **kwargs):
def use_pk_only_optimization(self):
return True

def to_internal_value(self, data):
def to_many_internal_value(self, data):
if self.pk_field is not None:
data = self.pk_field.to_internal_value(data)
data = [self.pk_field.to_internal_value(item) for item in data]
queryset = self.get_queryset()
try:
if isinstance(data, bool):
raise TypeError
return queryset.get(pk=data)
except ObjectDoesNotExist:
self.fail('does_not_exist', pk_value=data)
for item in data:
if isinstance(item, bool):
raise TypeError
result = queryset.filter(pk__in=data).all()
pks = [item.pk for item in result]
for item in data:
if item not in pks:
self.fail('does_not_exist', pk_value=item)
return result
except (TypeError, ValueError):
self.fail('incorrect_type', data_type=type(data).__name__)
self.fail('incorrect_type', data_type=type(data[0]).__name__)

def to_internal_value(self, data):
return self.to_many_internal_value([data])[0]

def to_representation(self, value):
if self.pk_field is not None:
Expand Down Expand Up @@ -451,15 +458,26 @@ def __init__(self, slug_field=None, **kwargs):
self.slug_field = slug_field
super().__init__(**kwargs)

def to_internal_value(self, data):
def to_many_internal_value(self, data):
queryset = self.get_queryset()
try:
return queryset.get(**{self.slug_field: data})
except ObjectDoesNotExist:
self.fail('does_not_exist', slug_name=self.slug_field, value=smart_str(data))
result = (
queryset
.filter(**{self.slug_field + "__in": data})
.annotate(_slug_field_value=F(self.slug_field))
.all()
)
slugs = [item._slug_field_value for item in result]
for item in data:
if item not in slugs:
self.fail('does_not_exist', slug_name=self.slug_field, value=smart_str(item))
return result
except (TypeError, ValueError):
self.fail('invalid')

def to_internal_value(self, data):
return self.to_many_internal_value([data])[0]

def to_representation(self, obj):
slug = self.slug_field
if "__" in slug:
Expand Down Expand Up @@ -524,6 +542,9 @@ def to_internal_value(self, data):
if not self.allow_empty and len(data) == 0:
self.fail('empty')

if hasattr(self.child_relation, "to_many_internal_value"):
return self.child_relation.to_many_internal_value(data)

return [
self.child_relation.to_internal_value(item)
for item in data
Expand Down
7 changes: 7 additions & 0 deletions tests/test_relations_pk.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,13 @@ def test_many_to_many_create(self):
]
assert serializer.data == expected

def test_many_to_many_grouped_queries(self):
data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
serializer = ManyToManySourceSerializer(data=data)
# Only one query should be executed even with several targets
with self.assertNumQueries(1):
assert serializer.is_valid()

def test_many_to_many_unsaved(self):
source = ManyToManySource(name='source-unsaved')

Expand Down
6 changes: 6 additions & 0 deletions tests/test_relations_slug.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ def test_reverse_foreign_key_create(self):
]
assert serializer.data == expected

def test_reverse_foreign_key_create_grouped_queries(self):
data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
serializer = ForeignKeyTargetSerializer(data=data)
with self.assertNumQueries(1):
assert serializer.is_valid()

def test_foreign_key_update_with_invalid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
Expand Down
28 changes: 23 additions & 5 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,31 @@ def __getitem__(self, val):
return self.items[val]

def get(self, **lookup):
for item in self.items:
result = self.filter(**lookup).all()
if len(result) > 0:
return result[0]
raise ObjectDoesNotExist()

def all(self):
return list(self.items)

def filter(self, **lookup):
return MockQueryset([
item
for item in self.items
if all([
attrgetter(key.replace('__', '.'))(item) == value
attrgetter(key.replace("__in", "").replace('__', '.'))(item) in value
if key.endswith("__in")
else attrgetter(key.replace('__', '.'))(item) == value
for key, value in lookup.items()
]):
return item
raise ObjectDoesNotExist()
])
])

def annotate(self, **kwargs):
for key, value in kwargs.items():
for item in self.items:
setattr(item, key, attrgetter(value.name.replace('__', '.'))(item))
return self


class BadType:
Expand Down
Loading