Skip to content

Commit

Permalink
Refs mozilla#265 - Uses customizable id_field to index and get django…
Browse files Browse the repository at this point in the history
… object.
  • Loading branch information
Florent authored and Florent Pigout committed Nov 4, 2014
1 parent bcec897 commit ba5bbb2
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 21 deletions.
6 changes: 5 additions & 1 deletion elasticutils/contrib/django/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ class MappingType(BaseMappingType):
`get_model()`.
"""

id_field = 'pk'

def get_object(self):
"""Returns the database object for this result
Expand All @@ -205,7 +208,8 @@ def get_object(self):
self.get_model().objects.get(pk=self._id)
"""
return self.get_model().objects.get(pk=self._id)
kwargs = {self.id_field: self._id}
return self.get_model().objects.get(**kwargs)

@classmethod
def get_model(cls):
Expand Down
14 changes: 9 additions & 5 deletions elasticutils/contrib/django/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@


@task
def index_objects(mapping_type, ids, chunk_size=100, es=None, index=None):
def index_objects(mapping_type, ids, chunk_size=100, id_field='id', es=None,
index=None):
"""Index documents of a specified mapping type.
This allows for asynchronous indexing.
Expand Down Expand Up @@ -48,21 +49,24 @@ def update_in_index(sender, instance, **kw):

# Get the model this mapping type is based on.
model = mapping_type.get_model()
filter_key = '{0}__in'.format(id_field)

# Retrieve all the objects that we're going to index and do it in
# bulk.
for id_list in chunked(ids, chunk_size):
documents = []

for obj in model.objects.filter(id__in=id_list):
for obj in model.objects.filter(**{filter_key: id_list}):
try:
documents.append(mapping_type.extract_document(obj.id, obj))
except Exception as exc:
_id = str(getattr(obj, id_field))
documents.append(mapping_type.extract_document(_id, obj))
except StandardError as exc:
log.exception('Unable to extract document {0}: {1}'.format(
obj, repr(exc)))

if documents:
mapping_type.bulk_index(documents, id_field='id', es=es, index=index)
mapping_type.bulk_index(documents, id_field=id_field, es=es,
index=index)


@task
Expand Down
38 changes: 28 additions & 10 deletions elasticutils/contrib/django/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# We need to put these in a separate module so they're easy to import
# on a test-by-test basis so that we can skip django-requiring tests
# if django isn't installed.
from uuid import UUID


from elasticutils.contrib.django import MappingType, Indexable
Expand All @@ -22,14 +23,20 @@ class SearchQuerySet(object):
# Yes. This is kind of crazy, but ... whatever.
def __init__(self, model):
self.model = model
self.id_field = model.id_field
self.steps = []

def get(self, pk):
pk = int(pk)
return [m for m in _model_cache if m.id == pk][0]

def filter(self, id__in=None):
self.steps.append(('filter', id__in))
def get(self, pk=None, uuid=None):
if pk:
pk = int(pk)
return [m for m in _model_cache if m.id == pk][0]
if uuid:
uuid = UUID(uuid)
return [m for m in _model_cache if m.uuid == uuid][0]
return []

def filter(self, id__in=None, uuid__in=None):
self.steps.append(('filter', id__in or uuid__in))
return self

def order_by(self, *fields):
Expand All @@ -47,7 +54,8 @@ def __iter__(self):

for mem in self.steps:
if mem[0] == 'filter':
objs = [obj for obj in objs if obj.id in mem[1]]
objs = [obj for obj in objs
if getattr(obj, self.id_field) in mem[1]]
elif mem[0] == 'order_by':
order_by_field = mem[1][0]
elif mem[0] == 'values_list':
Expand All @@ -58,16 +66,16 @@ def __iter__(self):

if values_list:
# Note: Hard-coded to just id and flat
objs = [obj.id for obj in objs]
objs = [getattr(obj, self.id_field) for obj in objs]
return iter(objs)


class Manager(object):
def get_query_set(self):
return SearchQuerySet(self)

def get(self, pk):
return self.get_query_set().get(pk)
def get(self, pk=None, uuid=None):
return self.get_query_set().get(pk=pk, uuid=uuid)

def filter(self, *args, **kwargs):
return self.get_query_set().filter(*args, **kwargs)
Expand All @@ -84,6 +92,7 @@ class FakeModel(object):
objects = Manager()

def __init__(self, **kw):
self.objects.id_field = kw.pop('id_field', 'id')
self._doc = kw
for key in kw:
setattr(self, key, kw[key])
Expand All @@ -102,3 +111,12 @@ def extract_document(cls, obj_id, obj=None):
'what to do with these args.')

return obj._doc

class FakeDjangoWithUuidMappingType(FakeDjangoMappingType):
id_field = 'uuid'

@classmethod
def extract_document(cls, obj_id, obj=None):
doc = super(FakeDjangoWithUuidMappingType, cls)\
.extract_document(obj_id, obj=obj)
return {k:str(v) for k,v in doc.iteritems()}
23 changes: 19 additions & 4 deletions elasticutils/contrib/django/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import uuid

from nose.tools import eq_

from elasticutils.contrib.django import S, get_es
from elasticutils.contrib.django.tests import (
FakeDjangoMappingType, FakeModel, reset_model_cache)
FakeDjangoMappingType, FakeDjangoWithUuidMappingType, FakeModel,
reset_model_cache)
from elasticutils.contrib.django.estestcase import ESTestCase


Expand All @@ -20,12 +23,13 @@ def tearDown(self):
IndexableTest.cleanup_index(FakeDjangoMappingType.get_index())
reset_model_cache()

def persist_data(self, data):
def persist_data(self, data, id_field='id'):
for doc in data:
FakeModel(**doc)
FakeModel(id_field=id_field, **doc)

# Index the document with .index()
FakeDjangoMappingType.index(doc, id_=doc['id'])
FakeDjangoMappingType.index({k:str(v) for k,v in doc.iteritems()},
id_=str(doc[id_field]))

self.refresh(FakeDjangoMappingType.get_index())

Expand All @@ -51,6 +55,17 @@ def test_get_object(self):
obj = s[0]
eq_(obj.object.id, 1)

def test_get_object_with_custom_pk(self):
data = [
{'uuid': uuid.uuid4(), 'name': 'odin skullcrusher'},
{'uuid': uuid.uuid4(), 'name': 'olaf bloodbiter'},
]
self.persist_data(data, id_field='uuid')

s = S(FakeDjangoWithUuidMappingType).query(name__prefix='odin')
obj = s[0]
eq_(obj.object.uuid, data[0]['uuid'])

def test_get_indexable(self):
self.persist_data([
{'id': 1, 'name': 'odin skullcrusher'},
Expand Down
34 changes: 33 additions & 1 deletion elasticutils/contrib/django/tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import uuid

from nose.tools import eq_

from elasticutils.contrib.django import get_es
from elasticutils.contrib.django.tasks import index_objects, unindex_objects
from elasticutils.contrib.django.tests import (
FakeDjangoMappingType, FakeModel, reset_model_cache)
FakeDjangoMappingType, FakeDjangoWithUuidMappingType, FakeModel,
reset_model_cache)
from elasticutils.contrib.django.estestcase import ESTestCase


Expand Down Expand Up @@ -86,3 +89,32 @@ def bulk_index(cls, *args, **kwargs):
index_objects(MockMappingType, [1, 2, 3], es='crazy_es', index='crazy_index')
eq_(MockMappingType.index_kwarg, 'crazy_index')
eq_(MockMappingType.es_kwarg, 'crazy_es')

def test_tasks_with_custom_id_field(self):
docs = [
{'uuid': uuid.uuid4(), 'name': 'odin skullcrusher'},
{'uuid': uuid.uuid4(), 'name': 'heimdall kneebiter'},
{'uuid': uuid.uuid4(), 'name': 'erik rose'}
]

for d in docs:
FakeModel(id_field='uuid', **d)

ids = [d['uuid'] for d in docs]

# Test index_objects task
index_objects(FakeDjangoWithUuidMappingType, ids)
FakeDjangoWithUuidMappingType.refresh_index()
# nothing was indexed because a StandardError was catched silently,
# may be explicit should be better.
eq_(FakeDjangoWithUuidMappingType.search().count(), 0)

# Test everything has been indexed
index_objects(FakeDjangoWithUuidMappingType, ids, id_field='uuid')
FakeDjangoWithUuidMappingType.refresh_index()
eq_(FakeDjangoWithUuidMappingType.search().count(), 3)

# Test unindex_objects task
unindex_objects(FakeDjangoWithUuidMappingType, ids)
FakeDjangoWithUuidMappingType.refresh_index()
eq_(FakeDjangoWithUuidMappingType.search().count(), 0)

0 comments on commit ba5bbb2

Please sign in to comment.