Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add projection argument to Reference.fetch() #380

Merged
merged 13 commits into from
Sep 21, 2022
11 changes: 9 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from functools import namedtuple

from umongo import Document, fields
from umongo import Document, EmbeddedDocument, fields
from umongo.instance import Instance


Expand All @@ -17,16 +17,23 @@ def classroom_model(instance):
@instance.register
class Teacher(Document):
name = fields.StrField(required=True)
has_apple = fields.BooleanField(required=False, attribute='_has_apple')

@instance.register
class Room(EmbeddedDocument):
seats = fields.IntField(required=True, attribute='_seats')

@instance.register
class Course(Document):
name = fields.StrField(required=True)
teacher = fields.ReferenceField(Teacher, required=True, allow_none=True)
room = fields.EmbeddedField(Room, required=False, allow_none=True)

@instance.register
class Student(Document):
name = fields.StrField(required=True)
birthday = fields.DateTimeField()
courses = fields.ListField(fields.ReferenceField(Course))

return namedtuple('Mapping', ('Teacher', 'Course', 'Student'))(Teacher, Course, Student)
Mapping = namedtuple('Mapping', ('Teacher', 'Course', 'Student', 'Room'))
return Mapping(Teacher, Course, Student, Room)
4 changes: 4 additions & 0 deletions tests/frameworks/test_motor_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ async def do_test():
assert teacher_fetched.name == 'Dr. Brown'
teacher_fetched = await course.teacher.fetch(force_reload=True)
assert teacher_fetched.name == 'M. Strickland'
# Test fetch with projection
teacher_fetched = await course.teacher.fetch(projection={'has_apple': 0},
force_reload=True)
assert teacher_fetched.has_apple is None
# Test bad ref as well
course.teacher = Reference(classroom_model.Teacher, ObjectId())
with pytest.raises(ma.ValidationError) as exc:
Expand Down
5 changes: 4 additions & 1 deletion tests/frameworks/test_pymongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ class Dummy(Document):
assert exc.value.messages == {'required_name': ['Missing data for required field.']}

def test_reference(self, classroom_model):
teacher = classroom_model.Teacher(name='M. Strickland')
teacher = classroom_model.Teacher(name='M. Strickland', has_apple=True)
teacher.commit()
course = classroom_model.Course(name='Hoverboard 101', teacher=teacher)
course.commit()
Expand All @@ -240,6 +240,9 @@ def test_reference(self, classroom_model):
teacher.commit()
assert course.teacher.fetch().name == 'Dr. Brown'
assert course.teacher.fetch(force_reload=True).name == 'M. Strickland'
# Test fetch with projection
assert course.teacher.fetch(projection={'has_apple': 0},
force_reload=True).has_apple is None
# Test bad ref as well
course.teacher = Reference(classroom_model.Teacher, ObjectId())
with pytest.raises(ma.ValidationError) as exc:
Expand Down
44 changes: 44 additions & 0 deletions tests/frameworks/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest

from pymongo import MongoClient

from umongo.frameworks import pymongo as framework_pymongo # noqa
from umongo.frameworks.tools import cook_find_projection

from ..common import TEST_DB


# All dependencies here are mandatory
dep_error = None


def make_db():
return MongoClient()[TEST_DB]


@pytest.fixture
def db():
return make_db()


def test_cook_find_projection(classroom_model):
projection = {'has_apple': 0}
cooked = cook_find_projection(classroom_model.Teacher, projection=projection)
assert cooked == {'_has_apple': 0}

projection = ['has_apple']
cooked = cook_find_projection(classroom_model.Teacher, projection=projection)
assert cooked == {'_has_apple': 1}

projection = ['name', 'has_apple']
cooked = cook_find_projection(classroom_model.Teacher, projection=projection)
assert cooked == {'name': 1, '_has_apple': 1}

# projection into a nested document's field which has a specified `attribute`
projection = ['room.seats']
cooked = cook_find_projection(classroom_model.Course, projection=projection)
assert cooked == {'room._seats': 1}

projection = {'room.seats': 0}
cooked = cook_find_projection(classroom_model.Course, projection=projection)
assert cooked == {'room._seats': 0}
3 changes: 3 additions & 0 deletions tests/frameworks/test_txmongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ def test_reference(self, classroom_model):
assert teacher_fetched.name == 'Dr. Brown'
teacher_fetched = yield course.teacher.fetch(force_reload=True)
assert teacher_fetched.name == 'M. Strickland'
# Test fetch with projection
assert course.teacher.fetch(projection={'has_apple': 0},
force_reload=True).has_apple is None
# Test bad ref as well
course.teacher = Reference(classroom_model.Teacher, ObjectId())
with pytest.raises(ma.ValidationError) as exc:
Expand Down
16 changes: 12 additions & 4 deletions umongo/data_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,26 @@ def __init__(self, document_cls, pk):
self.pk = pk
self._document = None

def fetch(self, no_data=False, force_reload=False):
def fetch(self, no_data=False, force_reload=False, projection=None):
"""
Retrieve from the database the referenced document

:param no_data: if True, the caller is only interested in whether or
not the document is present in database. This means the
:param no_data: if True, the caller is only interested in whether
the document is present in database. This means the
implementation may not retrieve document's data to save bandwidth.
:param force_reload: if True, ignore any cached data and reload referenced
document from database.
:param projection: if supplied, this is a dictionary or list describing
a projection which limits the data returned from database.
"""
raise NotImplementedError

@property
def exists(self):
"""
Check if the reference document exists in the database.
"""
raise NotImplementedError
# TODO replace no_data by `exists` function

def __repr__(self):
return '<object %s.%s(document=%s, pk=%r)>' % (
Expand Down
23 changes: 17 additions & 6 deletions umongo/frameworks/motor_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ..fields import ReferenceField, ListField, DictField, EmbeddedField
from ..query_mapper import map_query

from .tools import cook_find_filter, remove_cls_field_from_embedded_docs
from .tools import cook_find_filter, cook_find_projection, remove_cls_field_from_embedded_docs


SESSION = ContextVar("session", default=None)
Expand Down Expand Up @@ -254,12 +254,15 @@ async def io_validate(self, validate_all=False):
self.schema, self._data, partial=self._data.get_modified_fields())

@classmethod
async def find_one(cls, filter=None, *args, **kwargs):
async def find_one(cls, filter=None, projection=None, *args, **kwargs):
"""
Find a single document in database.
"""
filter = cook_find_filter(cls, filter)
ret = await cls.collection.find_one(filter, session=SESSION.get(), *args, **kwargs)
if projection:
projection = cook_find_projection(cls, projection)
ret = await cls.collection.find_one(filter, projection=projection,
session=SESSION.get(), *args, **kwargs)
if ret is not None:
ret = cls.build_from_mongo(ret, use_cls=True)
return ret
Expand Down Expand Up @@ -341,7 +344,10 @@ async def _io_validate_data_proxy(schema, data_proxy, partial=None):
async def _reference_io_validate(field, value):
if value is None:
return
await value.fetch(no_data=True)
exists = await value.exists
if not exists:
raise ma.ValidationError(value.error_messages['not_found'].format(
document=value.document_cls.__name__))


async def _list_io_validate(field, value):
Expand Down Expand Up @@ -394,16 +400,21 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._document = None

async def fetch(self, no_data=False, force_reload=False):
async def fetch(self, no_data=False, force_reload=False, projection=None):
if not self._document or force_reload:
if self.pk is None:
raise NoneReferenceError('Cannot retrieve a None Reference')
self._document = await self.document_cls.find_one(self.pk)
self._document = await self.document_cls.find_one(self.pk, projection=projection)
if not self._document:
raise ma.ValidationError(self.error_messages['not_found'].format(
document=self.document_cls.__name__))
return self._document

@property
async def exists(self):
return await self.document_cls.collection.find_one(self.pk,
projection={'_id': True}) is not None


class MotorAsyncIOBuilder(BaseBuilder):

Expand Down
21 changes: 15 additions & 6 deletions umongo/frameworks/pymongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..fields import ReferenceField, ListField, DictField, EmbeddedField
from ..query_mapper import map_query

from .tools import cook_find_filter, remove_cls_field_from_embedded_docs
from .tools import cook_find_filter, cook_find_projection, remove_cls_field_from_embedded_docs


SESSION = ContextVar("session", default=None)
Expand Down Expand Up @@ -197,12 +197,15 @@ def io_validate(self, validate_all=False):
self.schema, self._data, partial=self._data.get_modified_fields())

@classmethod
def find_one(cls, filter=None, *args, **kwargs):
def find_one(cls, filter=None, projection=None, *args, **kwargs):
"""
Find a single document in database.
"""
filter = cook_find_filter(cls, filter)
ret = cls.collection.find_one(filter, session=SESSION.get(), *args, **kwargs)
if projection:
projection = cook_find_projection(cls, projection)
ret = cls.collection.find_one(filter, projection=projection,
session=SESSION.get(), *args, **kwargs)
if ret is not None:
ret = cls.build_from_mongo(ret, use_cls=True)
return ret
Expand Down Expand Up @@ -275,7 +278,9 @@ def _io_validate_data_proxy(schema, data_proxy, partial=None):
def _reference_io_validate(field, value):
if value is None:
return
value.fetch(no_data=True)
if not value.exists:
raise ma.ValidationError(value.error_messages['not_found'].format(
document=value.document_cls.__name__))


def _list_io_validate(field, value):
Expand Down Expand Up @@ -322,16 +327,20 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._document = None

def fetch(self, no_data=False, force_reload=False):
def fetch(self, no_data=False, force_reload=False, projection=None):
if not self._document or force_reload:
if self.pk is None:
raise NoneReferenceError('Cannot retrieve a None Reference')
self._document = self.document_cls.find_one(self.pk)
self._document = self.document_cls.find_one(self.pk, projection=projection)
if not self._document:
raise ma.ValidationError(self.error_messages['not_found'].format(
document=self.document_cls.__name__))
return self._document

@property
def exists(self):
return self.document_cls.collection.find_one(self.pk, projection={'_id': True}) is not None


class PyMongoBuilder(BaseBuilder):

Expand Down
15 changes: 15 additions & 0 deletions umongo/frameworks/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@ def cook_find_filter(doc_cls, filter):
return filter


def cook_find_projection(doc_cls, projection):
"""
Replace field names in a projection by their database names.
"""
# a projection may be either:
# - a list of field names to return, or
# - a dict of field names and values to either return (value of 1) or not return (value of 0)
# in order to reuse as much of the `cook_find_filter` logic as possible,
# convert a list projection to a dict which produces the same result
if isinstance(projection, list):
projection = {field: 1 for field in projection}
projection = map_query(projection, doc_cls.schema.fields)
return projection


def remove_cls_field_from_embedded_docs(dict_in, embedded_docs):
"""Recursively remove _cls field from nested embedded documents

Expand Down
12 changes: 7 additions & 5 deletions umongo/frameworks/txmongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..fields import ReferenceField, ListField, DictField, EmbeddedField
from ..query_mapper import map_query

from .tools import cook_find_filter, remove_cls_field_from_embedded_docs
from .tools import cook_find_filter, cook_find_projection, remove_cls_field_from_embedded_docs


class TxMongoDocument(DocumentImplementation):
Expand Down Expand Up @@ -154,12 +154,14 @@ def io_validate(self, validate_all=False):

@classmethod
@inlineCallbacks
def find_one(cls, filter=None, *args, **kwargs):
def find_one(cls, filter=None, projection=None, *args, **kwargs):
"""
Find a single document in database.
"""
filter = cook_find_filter(cls, filter)
ret = yield cls.collection.find_one(filter, *args, **kwargs)
if projection:
projection = cook_find_projection(cls, projection)
ret = yield cls.collection.find_one(filter, projection=projection, *args, **kwargs)
if ret is not None:
ret = cls.build_from_mongo(ret, use_cls=True)
return ret
Expand Down Expand Up @@ -334,11 +336,11 @@ def __init__(self, *args, **kwargs):
self._document = None

@inlineCallbacks
def fetch(self, no_data=False, force_reload=False):
def fetch(self, no_data=False, force_reload=False, projection=None):
if not self._document or force_reload:
if self.pk is None:
raise NoneReferenceError('Cannot retrieve a None Reference')
self._document = yield self.document_cls.find_one(self.pk)
self._document = yield self.document_cls.find_one(self.pk, projection)
if not self._document:
raise ma.ValidationError(self.error_messages['not_found'].format(
document=self.document_cls.__name__))
Expand Down
4 changes: 2 additions & 2 deletions umongo/query_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def map_entry_with_dots(entry, fields):

def map_query(query, fields):
"""
Retrieve given fields whithin the query and replace there name with
the one they should have within the database.
Retrieve given fields within the query and replace their names with
the names they should have within the database.
"""
if isinstance(query, dict):
mapped_query = {}
Expand Down