Skip to content

Commit

Permalink
refactoring using get_object_with_direct_siblings
Browse files Browse the repository at this point in the history
  • Loading branch information
haricot authored Apr 8, 2020
1 parent c610950 commit a292281
Showing 1 changed file with 27 additions and 32 deletions.
59 changes: 27 additions & 32 deletions shop/views/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ class ProductRetrieveView(generics.RetrieveAPIView):
:param use_modal_dialog: If ``True`` (default), render a modal dialog to confirm adding the
product to the cart.
:param one_product: If ``False`` (no default), product_prev, product_next are added to the context.
:param with_direct_siblings: If ``True`` (no default), product_prev, product_next are added to the context.
"""

renderer_classes = (ShopTemplateHTMLRenderer, JSONRenderer, BrowsableAPIRenderer)
Expand All @@ -269,7 +269,7 @@ class ProductRetrieveView(generics.RetrieveAPIView):
serializer_class = ProductSerializer
limit_choices_to = models.Q()
use_modal_dialog = True
one_product = True
with_direct_siblings = False

def dispatch(self, request, *args, **kwargs):
"""
Expand Down Expand Up @@ -312,7 +312,7 @@ def get_renderer_context(self):
renderer_context = super(ProductRetrieveView, self).get_renderer_context()
if renderer_context['request'].accepted_renderer.format == 'html':
# add the product as Python object to the context
if self.one_product:
if not self.with_direct_siblings:
product = self.get_object()
renderer_context.update(
app_label=product._meta.app_label.lower(),
Expand All @@ -321,6 +321,7 @@ def get_renderer_context(self):
)
else:
product = self.get_object()
self.product_prev, self.product_next = self.get_object_with_direct_siblings(product)
renderer_context.update(
app_label=product._meta.app_label.lower(),
product_prev=self.product_prev,
Expand All @@ -333,39 +334,33 @@ def get_renderer_context(self):
def get_object(self):
if not hasattr(self, '_product'):
assert self.lookup_url_kwarg in self.kwargs
if self.one_product:
filter_kwargs = {
'active': True,
self.lookup_field: self.kwargs[self.lookup_url_kwarg],
}
if hasattr(self.product_model, 'translations'):
filter_kwargs.update(translations__language_code=get_language_from_request(self.request))
queryset = self.product_model.objects.filter(self.limit_choices_to, **filter_kwargs)
self._product = get_object_or_404(queryset)
else:
filter_kwargs = {
'active': True,
}
slug = self.kwargs[self.lookup_url_kwarg]
if hasattr(self.product_model, 'translations'):
filter_kwargs.update(translations__language_code=get_language_from_request(self.request))
queryset = self.product_model.objects.filter(self.limit_choices_to, **filter_kwargs)
queryset = CMSPagesFilterBackend().filter_queryset(self.request, queryset, self)
self.product_prev , self.product_cur, self.product_next = self.prev_cur_next_products(queryset, slug)
self._product = get_object_or_404(ProductModel, id=self.product_cur.id)
filter_kwargs = {
'active': True,
self.lookup_field: self.kwargs[self.lookup_url_kwarg],
}
if hasattr(self.product_model, 'translations'):
filter_kwargs.update(translations__language_code=get_language_from_request(self.request))
queryset = self.product_model.objects.filter(self.limit_choices_to, **filter_kwargs)
self._product = get_object_or_404(queryset)
return self._product

def prev_cur_next_products(self, objects, slug):
prev = cur = next = None
nb = len(list(objects))
for index, obj in enumerate(objects):
if obj.slug == slug:
cur = objects[index]
def get_object_with_direct_siblings(self, product):
previous = next = None
filter_siblings = {
'active': True,
}
if hasattr(self.product_model, 'translations'):
filter_siblings.update(translations__language_code=get_language_from_request(self.request))
queryset = self.product_model.objects.filter(self.limit_choices_to, **filter_siblings)
queryset = CMSPagesFilterBackend().filter_queryset(self.request, queryset, self)
nb = queryset.count()
for index, obj in enumerate(queryset):
if obj.slug == product.slug:
if index > 0:
prev = objects[index-1]
previous = queryset[index-1]
if index < (nb - 1):
next = objects[index+1]
return prev, cur, next
next = queryset[index+1]
return previous, next

class OnePageResultsSetPagination(pagination.PageNumberPagination):
def __init__(self):
Expand Down

0 comments on commit a292281

Please sign in to comment.