From d7b770174f0d1a6b1c1b10b22163cb68fe211993 Mon Sep 17 00:00:00 2001 From: Roman Date: Mon, 21 Aug 2023 11:31:44 -0600 Subject: [PATCH 01/24] fix: 576 --- beanie/odm/fields.py | 17 +++++++++-------- tests/odm/test_relations.py | 15 +++++++++------ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/beanie/odm/fields.py b/beanie/odm/fields.py index 4242b996..0026add3 100644 --- a/beanie/odm/fields.py +++ b/beanie/odm/fields.py @@ -296,14 +296,15 @@ async def fetch_list( ) ids_to_fetch.append(link.ref.id) - fetched_models = await document_class.find( # type: ignore - In("_id", ids_to_fetch), - with_children=True, - fetch_links=fetch_links, - ).to_list() - - for model in fetched_models: - data[model.id] = model + if ids_to_fetch: + fetched_models = await document_class.find( # type: ignore + In("_id", ids_to_fetch), + with_children=True, + fetch_links=fetch_links, + ).to_list() + + for model in fetched_models: + data[model.id] = model return list(data.values()) diff --git a/tests/odm/test_relations.py b/tests/odm/test_relations.py index d6599850..2ae7832c 100644 --- a/tests/odm/test_relations.py +++ b/tests/odm/test_relations.py @@ -173,6 +173,9 @@ async def test_multi_insert_links(self): assert isinstance(win, Window) assert win.id + async def test_fetch_after_insert(self, house_not_inserted): + await house_not_inserted.fetch_all_links() + class TestFind: async def test_prefetch_find_many(self, houses): @@ -395,19 +398,19 @@ class TestOther: async def test_query_composition(self): SYS = {"id", "revision_id"} - # Simple fields are initialized using the pydantic __fields__ internal property + # Simple fields are initialized using the pydantic model_fields internal property # such fields are properly isolated when multi inheritance is involved. - assert set(RootDocument.__fields__.keys()) == SYS | { + assert set(RootDocument.model_fields.keys()) == SYS | { "name", "link_root", } - assert set(ADocument.__fields__.keys()) == SYS | { + assert set(ADocument.model_fields.keys()) == SYS | { "name", "link_root", "surname", "link_a", } - assert set(BDocument.__fields__.keys()) == SYS | { + assert set(BDocument.model_fields.keys()) == SYS | { "name", "link_root", "email", @@ -491,7 +494,7 @@ async def test_with_chaining_aggregation(self): async def test_with_extra_allow(self, houses): res = await House.find(fetch_links=True).to_list() - assert res[0].__fields__.keys() == { + assert res[0].model_fields.keys() == { "id", "revision_id", "windows", @@ -503,7 +506,7 @@ async def test_with_extra_allow(self, houses): } res = await House.find_one(fetch_links=True) - assert res.__fields__.keys() == { + assert res.model_fields.keys() == { "id", "revision_id", "windows", From 645b47ee36509caeb5bf1607cfff99bc7db4381a Mon Sep 17 00:00:00 2001 From: Roman Date: Mon, 21 Aug 2023 11:44:12 -0600 Subject: [PATCH 02/24] fix: don't compare types --- tests/odm/test_id.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/odm/test_id.py b/tests/odm/test_id.py index 6c87249b..4e9f7279 100644 --- a/tests/odm/test_id.py +++ b/tests/odm/test_id.py @@ -14,4 +14,4 @@ async def test_integer_id(): doc = DocumentWithCustomIdInt(name="TEST", id=1) await doc.insert() new_doc = await DocumentWithCustomIdInt.get(doc.id) - assert type(new_doc.id) == int + assert isinstance(new_doc.id, int) From 329eadadfd0f08ed9ac99966e12f91e1401aca90 Mon Sep 17 00:00:00 2001 From: Roman Date: Mon, 21 Aug 2023 11:44:26 -0600 Subject: [PATCH 03/24] fix: don't compare types --- tests/odm/test_id.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/odm/test_id.py b/tests/odm/test_id.py index 4e9f7279..839e357c 100644 --- a/tests/odm/test_id.py +++ b/tests/odm/test_id.py @@ -7,7 +7,7 @@ async def test_uuid_id(): doc = DocumentWithCustomIdUUID(name="TEST") await doc.insert() new_doc = await DocumentWithCustomIdUUID.get(doc.id) - assert type(new_doc.id) == UUID + assert isinstance(new_doc.id, UUID) async def test_integer_id(): From 7c2a8795fc66149dc447b9d0616142ab965ce52c Mon Sep 17 00:00:00 2001 From: Roman Date: Mon, 21 Aug 2023 12:24:48 -0600 Subject: [PATCH 04/24] fix: 591 --- beanie/odm/documents.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/beanie/odm/documents.py b/beanie/odm/documents.py index 079a51f6..0d816cf5 100644 --- a/beanie/odm/documents.py +++ b/beanie/odm/documents.py @@ -1,5 +1,5 @@ import asyncio -from typing import ClassVar, AbstractSet +from typing import ClassVar, AbstractSet, Iterable from typing import ( Dict, Optional, @@ -373,7 +373,7 @@ async def insert_one( @classmethod async def insert_many( cls: Type[DocType], - documents: List[DocType], + documents: Iterable[DocType], session: Optional[ClientSession] = None, link_rule: WriteRules = WriteRules.DO_NOTHING, **pymongo_kwargs, From b9abfc4822592a3eb5592f571c9b10ef46f770df Mon Sep 17 00:00:00 2001 From: Roman Date: Mon, 21 Aug 2023 15:40:15 -0600 Subject: [PATCH 05/24] fix: 606, 600 --- beanie/odm/operators/find/logical.py | 17 ++++++++++++++++- beanie/odm/queries/find.py | 7 ++++++- beanie/odm/utils/parsing.py | 13 +++++++++++-- tests/odm/conftest.py | 4 ++++ tests/odm/models.py | 17 +++++++++++++++++ tests/odm/operators/find/test_logical.py | 13 +++++++++++-- tests/odm/test_relations.py | 18 ++++++++++++++++++ 7 files changed, 83 insertions(+), 6 deletions(-) diff --git a/beanie/odm/operators/find/logical.py b/beanie/odm/operators/find/logical.py index 0662adbb..2526d725 100644 --- a/beanie/odm/operators/find/logical.py +++ b/beanie/odm/operators/find/logical.py @@ -148,4 +148,19 @@ def __init__(self, expression: Mapping[str, Any]): @property def query(self): - return {"$not": self.expression} + if len(self.expression) == 1: + expression_key = list(self.expression.keys())[0] + if expression_key.startswith("$"): + raise AttributeError( + "Not operator can not be used with operators" + ) + value = self.expression[expression_key] + if isinstance(value, dict): + internal_key = list(value.keys())[0] + if internal_key.startswith("$"): + return {expression_key: {"$not": value}} + + return {expression_key: {"$not": {"$eq": value}}} + raise AttributeError( + "Not operator can only be used with one expression" + ) diff --git a/beanie/odm/queries/find.py b/beanie/odm/queries/find.py index 9ba8f6d5..dcf6880f 100644 --- a/beanie/odm/queries/find.py +++ b/beanie/odm/queries/find.py @@ -622,8 +622,13 @@ def build_aggregation_pipeline(self): aggregation_pipeline: List[Dict[str, Any]] = construct_lookup_queries( self.document_model ) + filter_query = self.get_filter_query() + if "$text" in filter_query: + text_query = filter_query["$text"] + aggregation_pipeline.insert(0, {"$match": {"$text": text_query}}) + del filter_query["$text"] - aggregation_pipeline.append({"$match": self.get_filter_query()}) + aggregation_pipeline.append({"$match": filter_query}) sort_pipeline = {"$sort": {i[0]: i[1] for i in self.sort_expressions}} if sort_pipeline["$sort"]: diff --git a/beanie/odm/utils/parsing.py b/beanie/odm/utils/parsing.py index 265e6d37..b3d3f2d1 100644 --- a/beanie/odm/utils/parsing.py +++ b/beanie/odm/utils/parsing.py @@ -6,13 +6,19 @@ DocWasNotRegisteredInUnionClass, ) from beanie.odm.interfaces.detector import ModelType -from beanie.odm.utils.pydantic import parse_model +from beanie.odm.utils.pydantic import parse_model, get_config_value if TYPE_CHECKING: from beanie.odm.documents import Document def merge_models(left: BaseModel, right: BaseModel) -> None: + """ + Merge two models + :param left: left model + :param right: right model + :return: None + """ from beanie.odm.fields import Link if hasattr(left, "_previous_revision_id") and hasattr( @@ -24,7 +30,10 @@ def merge_models(left: BaseModel, right: BaseModel) -> None: if isinstance(right_value, BaseModel) and isinstance( left_value, BaseModel ): - merge_models(left_value, right_value) + if get_config_value(left_value, "frozen"): + left.__setattr__(k, right_value) + else: + merge_models(left_value, right_value) continue if isinstance(right_value, list): links_found = False diff --git a/tests/odm/conftest.py b/tests/odm/conftest.py index 9c6aa739..f45b18f7 100644 --- a/tests/odm/conftest.py +++ b/tests/odm/conftest.py @@ -84,6 +84,8 @@ DocumentWithIndexMerging1, DocumentWithIndexMerging2, DocumentWithCustomInit, + DocumentWithTextIndexAndLink, + LinkDocumentForTextSeacrh, ) from tests.odm.views import TestView, TestViewWithLink @@ -253,6 +255,8 @@ async def init(db): DocumentWithIndexMerging1, DocumentWithIndexMerging2, DocumentWithCustomInit, + DocumentWithTextIndexAndLink, + LinkDocumentForTextSeacrh, ] await init_beanie( database=db, diff --git a/tests/odm/models.py b/tests/odm/models.py index 0ec7b0b2..b3bfa19d 100644 --- a/tests/odm/models.py +++ b/tests/odm/models.py @@ -837,3 +837,20 @@ class DocumentWithCustomInit(Document): @classmethod async def custom_init(cls): cls.s = "TEST2" + + +class LinkDocumentForTextSeacrh(Document): + i: int + + +class DocumentWithTextIndexAndLink(Document): + s: str + link: Link[LinkDocumentForTextSeacrh] + + class Settings: + indexes = [ + pymongo.IndexModel( + [("s", pymongo.TEXT)], + name="text_index", + ) + ] diff --git a/tests/odm/operators/find/test_logical.py b/tests/odm/operators/find/test_logical.py index 03d78907..574728f2 100644 --- a/tests/odm/operators/find/test_logical.py +++ b/tests/odm/operators/find/test_logical.py @@ -1,3 +1,5 @@ +import pytest + from beanie.odm.operators.find.logical import And, Not, Nor, Or from tests.odm.models import Sample @@ -10,9 +12,16 @@ async def test_and(): assert q == {"$and": [{"integer": 1}, {"nested.integer": {"$gt": 3}}]} -async def test_not(): +async def test_not(preset_documents): q = Not(Sample.integer == 1) - assert q == {"$not": {"integer": 1}} + assert q == {"integer": {"$not": {"$eq": 1}}} + + docs = await Sample.find(q).to_list() + assert len(docs) == 7 + + with pytest.raises(AttributeError): + q = Not(And(Sample.integer == 1, Sample.nested.integer > 3)) + await Sample.find(q).to_list() async def test_nor(): diff --git a/tests/odm/test_relations.py b/tests/odm/test_relations.py index 2ae7832c..80bded59 100644 --- a/tests/odm/test_relations.py +++ b/tests/odm/test_relations.py @@ -29,6 +29,8 @@ DocumentWithListLink, DocumentWithListOfLinks, DocumentToBeLinked, + DocumentWithTextIndexAndLink, + LinkDocumentForTextSeacrh, ) @@ -335,6 +337,22 @@ async def test_fetch_list_with_some_prefetched(self): for i in range(10): assert doc_with_links.links[i].id == docs[i].id + async def test_text_search(self): + doc = DocumentWithTextIndexAndLink( + s="hello world", link=LinkDocumentForTextSeacrh(i=1) + ) + await doc.insert(link_rule=WriteRules.WRITE) + + doc2 = DocumentWithTextIndexAndLink( + s="hi world", link=LinkDocumentForTextSeacrh(i=2) + ) + await doc2.insert(link_rule=WriteRules.WRITE) + + docs = await DocumentWithTextIndexAndLink.find( + {"$text": {"$search": "hello"}}, fetch_links=True + ).to_list() + assert len(docs) == 1 + class TestReplace: async def test_do_nothing(self, house): From 95184fb6b27a2ca2b90b4ce36b0806acbf38b5c0 Mon Sep 17 00:00:00 2001 From: Roman Date: Mon, 21 Aug 2023 16:41:50 -0600 Subject: [PATCH 06/24] fix: 629 --- beanie/__init__.py | 3 +++ beanie/odm/utils/parsing.py | 1 + docs/tutorial/update.md | 9 +++++++++ tests/odm/conftest.py | 2 ++ tests/odm/documents/test_update.py | 11 +++++++++++ tests/odm/models.py | 4 ++++ tests/odm/query/test_update.py | 17 +++++++++++++++++ tests/odm/test_relations.py | 1 + 8 files changed, 48 insertions(+) diff --git a/beanie/__init__.py b/beanie/__init__.py index e690f12c..24e03fd9 100644 --- a/beanie/__init__.py +++ b/beanie/__init__.py @@ -23,6 +23,7 @@ WriteRules, DeleteRules, ) +from beanie.odm.queries.update import UpdateResponse from beanie.odm.settings.timeseries import TimeSeriesConfig, Granularity from beanie.odm.utils.init import init_beanie from beanie.odm.documents import Document @@ -64,4 +65,6 @@ "DeleteRules", # Custom Types "DecimalAnnotation", + # UpdateResponse + "UpdateResponse", ] diff --git a/beanie/odm/utils/parsing.py b/beanie/odm/utils/parsing.py index b3d3f2d1..d1b03af8 100644 --- a/beanie/odm/utils/parsing.py +++ b/beanie/odm/utils/parsing.py @@ -43,6 +43,7 @@ def merge_models(left: BaseModel, right: BaseModel) -> None: break if links_found: continue + left.__setattr__(k, right_value) elif not isinstance(right_value, Link): left.__setattr__(k, right_value) diff --git a/docs/tutorial/update.md b/docs/tutorial/update.md index d9034bb0..bef86213 100644 --- a/docs/tutorial/update.md +++ b/docs/tutorial/update.md @@ -80,3 +80,12 @@ await Product.find_one(Product.name == "Milka").delete() await Product.find(Product.category.name == "Chocolate").delete() ``` + +## Response Type + +For the object methods `update` and `upsert`, you can use the `response_type` parameter to specify the type of response. + +The options are: +- `UpdateResponse.UPDATE_RESULT` - returns the result of the update operation. +- `UpdateResponse.NEW_DOCUMENT` - returns the newly updated document. +- `UpdateResponse.OLD_DOCUMENT` - returns the document before the update. diff --git a/tests/odm/conftest.py b/tests/odm/conftest.py index f45b18f7..a5360e65 100644 --- a/tests/odm/conftest.py +++ b/tests/odm/conftest.py @@ -86,6 +86,7 @@ DocumentWithCustomInit, DocumentWithTextIndexAndLink, LinkDocumentForTextSeacrh, + TestDocumentWithList, ) from tests.odm.views import TestView, TestViewWithLink @@ -257,6 +258,7 @@ async def init(db): DocumentWithCustomInit, DocumentWithTextIndexAndLink, LinkDocumentForTextSeacrh, + TestDocumentWithList, ] await init_beanie( database=db, diff --git a/tests/odm/documents/test_update.py b/tests/odm/documents/test_update.py index adc941d4..4a2e2072 100644 --- a/tests/odm/documents/test_update.py +++ b/tests/odm/documents/test_update.py @@ -9,6 +9,7 @@ DocumentTestModel, ModelWithOptionalField, DocumentWithKeepNullsFalse, + TestDocumentWithList, ) @@ -278,3 +279,13 @@ async def test_save_changes_keep_nulls_false(): # {"test_str": "smth_else"}, session=session # ).to_list() # assert len(smth_else_documetns) == 17 + + +async def test_update_list(): + test_record = TestDocumentWithList(list_values=["1", "2", "3"]) + test_record = await test_record.insert() + update_data = test_record.dict() + update_data["list_values"] = ["5", "6", "7"] + + updated_test_record = await test_record.update({"$set": update_data}) + assert updated_test_record.list_values == update_data["list_values"] diff --git a/tests/odm/models.py b/tests/odm/models.py index b3bfa19d..320679b0 100644 --- a/tests/odm/models.py +++ b/tests/odm/models.py @@ -854,3 +854,7 @@ class Settings: name="text_index", ) ] + + +class TestDocumentWithList(Document): + list_values: List[str] diff --git a/tests/odm/query/test_update.py b/tests/odm/query/test_update.py index f0025660..977734f0 100644 --- a/tests/odm/query/test_update.py +++ b/tests/odm/query/test_update.py @@ -3,6 +3,7 @@ import pytest from beanie.odm.operators.update.general import Set, Max +from beanie.odm.queries.update import UpdateResponse from tests.odm.models import Sample @@ -194,6 +195,22 @@ async def test_update_one_upsert_without_insert( assert len(new_docs) == 0 +async def test_update_one_upsert_without_insert_return_doc( + preset_documents, sample_doc_not_saved +): + result = await Sample.find_one(Sample.integer > 1).upsert( + Set({Sample.integer: 100}), + on_insert=sample_doc_not_saved, + response_type=UpdateResponse.NEW_DOCUMENT, + ) + assert isinstance(result, Sample) + + new_docs = await Sample.find_many( + Sample.string == sample_doc_not_saved.string + ).to_list() + assert len(new_docs) == 0 + + async def test_update_pymongo_kwargs(preset_documents): with pytest.raises(TypeError): await Sample.find_many(Sample.increment > 4).update( diff --git a/tests/odm/test_relations.py b/tests/odm/test_relations.py index 80bded59..ef8298ea 100644 --- a/tests/odm/test_relations.py +++ b/tests/odm/test_relations.py @@ -172,6 +172,7 @@ async def test_multi_insert_links(self): house.windows.append(new_window) await house.save(link_rule=WriteRules.WRITE) for win in house.windows: + print(type(win), win) assert isinstance(win, Window) assert win.id From cb10d9df1fa011cd974eea7892033d8a2e4b21e3 Mon Sep 17 00:00:00 2001 From: Roman Date: Mon, 21 Aug 2023 18:32:34 -0600 Subject: [PATCH 07/24] fix: 646 --- beanie/odm/queries/update.py | 20 +++++++++++++++++++- tests/odm/documents/test_update.py | 11 +++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/beanie/odm/queries/update.py b/beanie/odm/queries/update.py index 42cd2290..786c66f2 100644 --- a/beanie/odm/queries/update.py +++ b/beanie/odm/queries/update.py @@ -69,16 +69,34 @@ def __init__( @property def update_query(self) -> Dict[str, Any]: - query: Dict[str, Any] = {} + query: Union[Dict[str, Any], List[Dict[str, Any]], None] = None for expression in self.update_expressions: if isinstance(expression, BaseUpdateOperator): + if query is None: + query = {} + if isinstance(query, list): + raise TypeError("Wrong expression type") query.update(expression.query) elif isinstance(expression, dict): + if query is None: + query = {} + if isinstance(query, list): + raise TypeError("Wrong expression type") query.update(expression) elif isinstance(expression, SetRevisionId): + if query is None: + query = {} + if isinstance(query, list): + raise TypeError("Wrong expression type") set_query = query.get("$set", {}) set_query.update(expression.query.get("$set", {})) query["$set"] = set_query + elif isinstance(expression, list): + if query is None: + query = [] + if isinstance(query, dict): + raise TypeError("Wrong expression type") + query.extend(expression) else: raise TypeError("Wrong expression type") return Encoder(custom_encoders=self.encoders).encode(query) diff --git a/tests/odm/documents/test_update.py b/tests/odm/documents/test_update.py index 4a2e2072..8084f7d8 100644 --- a/tests/odm/documents/test_update.py +++ b/tests/odm/documents/test_update.py @@ -10,6 +10,7 @@ ModelWithOptionalField, DocumentWithKeepNullsFalse, TestDocumentWithList, + Sample, ) @@ -289,3 +290,13 @@ async def test_update_list(): updated_test_record = await test_record.update({"$set": update_data}) assert updated_test_record.list_values == update_data["list_values"] + + +async def test_update_using_pipeline(preset_documents): + await Sample.all().update( + [{"$set": {"integer": 10000}}, {"$set": {"string": "TEST3"}}] + ) + all_docs = await Sample.find_many({}).to_list() + for doc in all_docs: + assert doc.integer == 10000 + assert doc.string == "TEST3" From dc4d78be9dcbcac620bdb2f09628259ef79ad272 Mon Sep 17 00:00:00 2001 From: Roman Date: Mon, 21 Aug 2023 19:44:28 -0600 Subject: [PATCH 08/24] fix: 648 --- beanie/odm/documents.py | 125 ++++++++++++++++++++++++------------ tests/fastapi/routes.py | 6 ++ tests/fastapi/test_api.py | 8 +++ tests/odm/test_relations.py | 1 - 4 files changed, 99 insertions(+), 41 deletions(-) diff --git a/beanie/odm/documents.py b/beanie/odm/documents.py index 0d816cf5..f534dada 100644 --- a/beanie/odm/documents.py +++ b/beanie/odm/documents.py @@ -1057,46 +1057,91 @@ def get_hidden_fields(cls): if get_extra_field_info(model_field, "hidden") is True ) - def dict( - self, - *, - include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - by_alias: bool = False, - skip_defaults: bool = False, - exclude_hidden: bool = True, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - ) -> "DictStrAny": - """ - Overriding of the respective method from Pydantic - Hides fields, marked as "hidden - """ - if exclude_hidden: - if isinstance(exclude, AbstractSet): - exclude = {*self._hidden_fields, *exclude} - elif isinstance(exclude, Mapping): - exclude = dict( - {k: True for k in self._hidden_fields}, **exclude - ) # type: ignore - elif exclude is None: - exclude = self._hidden_fields - - kwargs = { - "include": include, - "exclude": exclude, - "by_alias": by_alias, - "exclude_unset": exclude_unset, - "exclude_defaults": exclude_defaults, - "exclude_none": exclude_none, - } - - # TODO: Remove this check when skip_defaults are no longer supported - if skip_defaults: - kwargs["skip_defaults"] = skip_defaults - - return super().dict(**kwargs) + if IS_PYDANTIC_V2: + + def model_dump( + self, + *, + mode="python", + include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + by_alias: bool = False, + exclude_hidden: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + round_trip: bool = False, + warnings: bool = True, + ) -> "DictStrAny": + """ + Overriding of the respective method from Pydantic + Hides fields, marked as "hidden + """ + if exclude_hidden: + if isinstance(exclude, AbstractSet): + exclude = {*self._hidden_fields, *exclude} + elif isinstance(exclude, Mapping): + exclude = dict( + {k: True for k in self._hidden_fields}, **exclude + ) # type: ignore + elif exclude is None: + exclude = self._hidden_fields + + kwargs = { + "include": include, + "exclude": exclude, + "by_alias": by_alias, + "exclude_unset": exclude_unset, + "exclude_defaults": exclude_defaults, + "exclude_none": exclude_none, + "round_trip": round_trip, + "warnings": warnings, + } + + return super().model_dump(**kwargs) + + else: + + def dict( + self, + *, + include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + by_alias: bool = False, + skip_defaults: bool = False, + exclude_hidden: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> "DictStrAny": + """ + Overriding of the respective method from Pydantic + Hides fields, marked as "hidden + """ + if exclude_hidden: + if isinstance(exclude, AbstractSet): + exclude = {*self._hidden_fields, *exclude} + elif isinstance(exclude, Mapping): + exclude = dict( + {k: True for k in self._hidden_fields}, **exclude + ) # type: ignore + elif exclude is None: + exclude = self._hidden_fields + + kwargs = { + "include": include, + "exclude": exclude, + "by_alias": by_alias, + "exclude_unset": exclude_unset, + "exclude_defaults": exclude_defaults, + "exclude_none": exclude_none, + } + + # TODO: Remove this check when skip_defaults are no longer supported + if skip_defaults: + kwargs["skip_defaults"] = skip_defaults + + return super().dict(**kwargs) @wrap_with_actions(event_type=EventTypes.VALIDATE_ON_SAVE) async def validate_self(self, *args, **kwargs): diff --git a/tests/fastapi/routes.py b/tests/fastapi/routes.py index 3926bbcc..fc655339 100644 --- a/tests/fastapi/routes.py +++ b/tests/fastapi/routes.py @@ -23,6 +23,12 @@ async def create_window(window: WindowAPI): return window +@house_router.post("/windows_2/") +async def create_window_2(window: WindowAPI): + await window.create() + return window + + @house_router.post("/houses/", response_model=HouseAPI) async def create_house(window: WindowAPI): house = HouseAPI(name="test_name", windows=[window]) diff --git a/tests/fastapi/test_api.py b/tests/fastapi/test_api.py index 161d6874..320e6881 100644 --- a/tests/fastapi/test_api.py +++ b/tests/fastapi/test_api.py @@ -35,3 +35,11 @@ async def test_create_house_2(api_client): resp = await api_client.post("/v1/houses_2/", json=payload) resp_json = resp.json() assert len(resp_json["windows"]) == 1 + + +async def test_revision_id(api_client): + payload = {"x": 10, "y": 20} + resp = await api_client.post("/v1/windows_2/", json=payload) + resp_json = resp.json() + assert "revision_id" not in resp_json + assert resp_json == {"x": 10, "y": 20, "_id": resp_json["_id"]} diff --git a/tests/odm/test_relations.py b/tests/odm/test_relations.py index ef8298ea..80bded59 100644 --- a/tests/odm/test_relations.py +++ b/tests/odm/test_relations.py @@ -172,7 +172,6 @@ async def test_multi_insert_links(self): house.windows.append(new_window) await house.save(link_rule=WriteRules.WRITE) for win in house.windows: - print(type(win), win) assert isinstance(win, Window) assert win.id From 3c58c17d37c3b931661440fc4607183ea3496ea8 Mon Sep 17 00:00:00 2001 From: Roman Date: Sat, 9 Sep 2023 13:31:14 -0600 Subject: [PATCH 09/24] new type: BsonBinary --- beanie/__init__.py | 2 ++ beanie/odm/custom_types/bson/__init__.py | 0 beanie/odm/custom_types/bson/binary.py | 37 ++++++++++++++++++++++ tests/fastapi/routes.py | 5 ++- tests/odm/conftest.py | 2 ++ tests/odm/custom_types/__init__.py | 0 tests/odm/custom_types/test_bson_binary.py | 19 +++++++++++ tests/odm/models.py | 7 +++- 8 files changed, 68 insertions(+), 4 deletions(-) create mode 100644 beanie/odm/custom_types/bson/__init__.py create mode 100644 beanie/odm/custom_types/bson/binary.py create mode 100644 tests/odm/custom_types/__init__.py create mode 100644 tests/odm/custom_types/test_bson_binary.py diff --git a/beanie/__init__.py b/beanie/__init__.py index 24e03fd9..63817ffe 100644 --- a/beanie/__init__.py +++ b/beanie/__init__.py @@ -15,6 +15,7 @@ ) from beanie.odm.bulk import BulkWriter from beanie.odm.custom_types import DecimalAnnotation +from beanie.odm.custom_types.bson.binary import BsonBinary from beanie.odm.fields import ( PydanticObjectId, Indexed, @@ -65,6 +66,7 @@ "DeleteRules", # Custom Types "DecimalAnnotation", + "BsonBinary", # UpdateResponse "UpdateResponse", ] diff --git a/beanie/odm/custom_types/bson/__init__.py b/beanie/odm/custom_types/bson/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/beanie/odm/custom_types/bson/binary.py b/beanie/odm/custom_types/bson/binary.py new file mode 100644 index 00000000..b6acb7cb --- /dev/null +++ b/beanie/odm/custom_types/bson/binary.py @@ -0,0 +1,37 @@ +from typing import Any, Callable + +import bson +from pydantic import GetJsonSchemaHandler +from pydantic.fields import FieldInfo +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import core_schema + + +class BsonBinary(bson.Binary): + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: Callable[[Any], core_schema.CoreSchema], # type: ignore + ) -> core_schema.CoreSchema: # type: ignore + def validate(value, _: FieldInfo) -> bson.Binary: + if isinstance(value, bson.Binary): + return value + if isinstance(value, bytes): + return bson.Binary(value) + raise ValueError("Value must be bytes or bson.Binary") + + python_schema = core_schema.general_plain_validator_function(validate) # type: ignore + + return core_schema.json_or_python_schema( + json_schema=core_schema.float_schema(), + python_schema=python_schema, + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, + _core_schema: core_schema.CoreSchema, # type: ignore + handler: GetJsonSchemaHandler, + ) -> JsonSchemaValue: + return handler(core_schema.str_schema()) diff --git a/tests/fastapi/routes.py b/tests/fastapi/routes.py index fc655339..581012c2 100644 --- a/tests/fastapi/routes.py +++ b/tests/fastapi/routes.py @@ -24,9 +24,8 @@ async def create_window(window: WindowAPI): @house_router.post("/windows_2/") -async def create_window_2(window: WindowAPI): - await window.create() - return window +async def create_window_2(window: WindowAPI) -> WindowAPI: + return await window.save() @house_router.post("/houses/", response_model=HouseAPI) diff --git a/tests/odm/conftest.py b/tests/odm/conftest.py index a5360e65..4d86c279 100644 --- a/tests/odm/conftest.py +++ b/tests/odm/conftest.py @@ -87,6 +87,7 @@ DocumentWithTextIndexAndLink, LinkDocumentForTextSeacrh, TestDocumentWithList, + TestDocumentWithBsonBinaryField, ) from tests.odm.views import TestView, TestViewWithLink @@ -259,6 +260,7 @@ async def init(db): DocumentWithTextIndexAndLink, LinkDocumentForTextSeacrh, TestDocumentWithList, + TestDocumentWithBsonBinaryField, ] await init_beanie( database=db, diff --git a/tests/odm/custom_types/__init__.py b/tests/odm/custom_types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/odm/custom_types/test_bson_binary.py b/tests/odm/custom_types/test_bson_binary.py new file mode 100644 index 00000000..e2616c1c --- /dev/null +++ b/tests/odm/custom_types/test_bson_binary.py @@ -0,0 +1,19 @@ +import bson + +from tests.odm.models import TestDocumentWithBsonBinaryField + + +async def test_bson_binary(): + doc = TestDocumentWithBsonBinaryField(binary_field=bson.Binary(b"test")) + await doc.insert() + assert doc.binary_field == bson.Binary(b"test") + + new_doc = await TestDocumentWithBsonBinaryField.get(doc.id) + assert new_doc.binary_field == bson.Binary(b"test") + + doc = TestDocumentWithBsonBinaryField(binary_field=b"test") + await doc.insert() + assert doc.binary_field == bson.Binary(b"test") + + new_doc = await TestDocumentWithBsonBinaryField.get(doc.id) + assert new_doc.binary_field == bson.Binary(b"test") diff --git a/tests/odm/models.py b/tests/odm/models.py index 320679b0..3e83648b 100644 --- a/tests/odm/models.py +++ b/tests/odm/models.py @@ -1,5 +1,4 @@ import datetime -from beanie import DecimalAnnotation from ipaddress import ( IPv4Address, IPv4Interface, @@ -25,6 +24,7 @@ from pydantic.color import Color from pymongo import IndexModel +from beanie import DecimalAnnotation from beanie import ( Document, Indexed, @@ -35,6 +35,7 @@ Save, ) from beanie.odm.actions import Delete, after_event, before_event +from beanie.odm.custom_types.bson.binary import BsonBinary from beanie.odm.fields import Link, PydanticObjectId, BackLink from beanie.odm.settings.timeseries import TimeSeriesConfig from beanie.odm.union_doc import UnionDoc @@ -858,3 +859,7 @@ class Settings: class TestDocumentWithList(Document): list_values: List[str] + + +class TestDocumentWithBsonBinaryField(Document): + binary_field: BsonBinary From 68f406d473e932aaffdbf65db9c157d6f0b61fd3 Mon Sep 17 00:00:00 2001 From: Roman Date: Sat, 9 Sep 2023 14:41:04 -0600 Subject: [PATCH 10/24] rollback: response model in the fastapi atest --- tests/fastapi/routes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fastapi/routes.py b/tests/fastapi/routes.py index 581012c2..7d1f918e 100644 --- a/tests/fastapi/routes.py +++ b/tests/fastapi/routes.py @@ -24,7 +24,7 @@ async def create_window(window: WindowAPI): @house_router.post("/windows_2/") -async def create_window_2(window: WindowAPI) -> WindowAPI: +async def create_window_2(window: WindowAPI): return await window.save() From b89f95d70edf799a1519fd4a2bf4918439e03539 Mon Sep 17 00:00:00 2001 From: Roman Date: Sat, 9 Sep 2023 15:05:10 -0600 Subject: [PATCH 11/24] fix: 664 validate on save before save --- beanie/odm/documents.py | 1 + 1 file changed, 1 insertion(+) diff --git a/beanie/odm/documents.py b/beanie/odm/documents.py index f534dada..835d91db 100644 --- a/beanie/odm/documents.py +++ b/beanie/odm/documents.py @@ -483,6 +483,7 @@ async def replace( @wrap_with_actions(EventTypes.SAVE) @save_state_after + @validate_self_before async def save( self: DocType, session: Optional[ClientSession] = None, From dc2c3fe22107036d31635fcc6513210b491eeebf Mon Sep 17 00:00:00 2001 From: Roman Date: Sat, 9 Sep 2023 15:58:56 -0600 Subject: [PATCH 12/24] fix: 668 Support root models --- beanie/odm/utils/encoder.py | 15 ++++++++++++--- beanie/odm/utils/pydantic.py | 24 ++++++++++++++++++++++++ tests/odm/conftest.py | 2 ++ tests/odm/models.py | 8 ++++++++ tests/odm/test_root_models.py | 14 ++++++++++++++ 5 files changed, 60 insertions(+), 3 deletions(-) create mode 100644 tests/odm/test_root_models.py diff --git a/beanie/odm/utils/encoder.py b/beanie/odm/utils/encoder.py index f22fa402..af8e5fa3 100644 --- a/beanie/odm/utils/encoder.py +++ b/beanie/odm/utils/encoder.py @@ -25,12 +25,13 @@ import bson from bson import ObjectId, DBRef, Binary, Decimal128, Regex -from pydantic import BaseModel +from pydantic import BaseModel, RootModel from pydantic import SecretBytes, SecretStr from pydantic.color import Color from beanie.odm.fields import Link, LinkTypes from beanie.odm import documents +from beanie.odm.utils.pydantic import get_iterator ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { Color: str, @@ -110,7 +111,7 @@ def encode_document(self, obj): if obj._inheritance_inited: obj_dict[obj.get_settings().class_id] = obj._class_id - for k, o in obj._iter(to_dict=False, by_alias=self.by_alias): + for k, o in get_iterator(obj, by_alias=self.by_alias): if k not in self.exclude and ( self.keep_nulls is True or o is not None ): @@ -165,7 +166,7 @@ def encode_base_model(self, obj): BaseModel case """ obj_dict = {} - for k, o in obj._iter(to_dict=False, by_alias=self.by_alias): + for k, o in get_iterator(obj, by_alias=self.by_alias): if k not in self.exclude and ( self.keep_nulls is True or o is not None ): @@ -173,6 +174,12 @@ def encode_base_model(self, obj): return obj_dict + def encode_root_model(self, obj): + """ + RootModel case + """ + return self._encode(obj.root) + def encode_dict(self, obj): """ Dictionary case @@ -204,6 +211,8 @@ def _encode( if isinstance(obj, documents.Document): return self.encode_document(obj) + if isinstance(obj, RootModel): + return self.encode_root_model(obj) if isinstance(obj, BaseModel): return self.encode_base_model(obj) if isinstance(obj, dict): diff --git a/beanie/odm/utils/pydantic.py b/beanie/odm/utils/pydantic.py index 1350520e..1d6c7899 100644 --- a/beanie/odm/utils/pydantic.py +++ b/beanie/odm/utils/pydantic.py @@ -60,3 +60,27 @@ def get_model_dump(model): return model.model_dump() else: return model.dict() + + +def get_iterator(model, by_alias=False): + if IS_PYDANTIC_V2: + + def _get_aliases(model): + aliases = {} + for k, v in model.model_fields.items(): + if v.alias is not None: + aliases[k] = v.alias + else: + aliases[k] = k + return aliases + + def _iter(model, by_alias=False): + for k, v in model.__iter__(): + if by_alias: + yield _get_aliases(model)[k], v + else: + yield k, v + + return _iter(model, by_alias=by_alias) + else: + return model._iter(to_dict=False, by_alias=False) diff --git a/tests/odm/conftest.py b/tests/odm/conftest.py index 4d86c279..cf55f06d 100644 --- a/tests/odm/conftest.py +++ b/tests/odm/conftest.py @@ -88,6 +88,7 @@ LinkDocumentForTextSeacrh, TestDocumentWithList, TestDocumentWithBsonBinaryField, + TestDocumentWithRootModelAsAField, ) from tests.odm.views import TestView, TestViewWithLink @@ -261,6 +262,7 @@ async def init(db): LinkDocumentForTextSeacrh, TestDocumentWithList, TestDocumentWithBsonBinaryField, + TestDocumentWithRootModelAsAField, ] await init_beanie( database=db, diff --git a/tests/odm/models.py b/tests/odm/models.py index 3e83648b..4a74f006 100644 --- a/tests/odm/models.py +++ b/tests/odm/models.py @@ -20,6 +20,7 @@ SecretBytes, SecretStr, ConfigDict, + RootModel, ) from pydantic.color import Color from pymongo import IndexModel @@ -863,3 +864,10 @@ class TestDocumentWithList(Document): class TestDocumentWithBsonBinaryField(Document): binary_field: BsonBinary + + +Pets = RootModel[List[str]] + + +class TestDocumentWithRootModelAsAField(Document): + pets: Pets diff --git a/tests/odm/test_root_models.py b/tests/odm/test_root_models.py new file mode 100644 index 00000000..d3e750f0 --- /dev/null +++ b/tests/odm/test_root_models.py @@ -0,0 +1,14 @@ +from tests.odm.models import TestDocumentWithRootModelAsAField + + +class TestRootModels: + async def test_insert(self): + doc = TestDocumentWithRootModelAsAField(pets=["dog", "cat", "fish"]) + await doc.insert() + + new_doc = await TestDocumentWithRootModelAsAField.get(doc.id) + assert new_doc.pets.root == ["dog", "cat", "fish"] + + collection = TestDocumentWithRootModelAsAField.get_motor_collection() + raw_doc = await collection.find_one({"_id": doc.id}) + assert raw_doc["pets"] == ["dog", "cat", "fish"] From e00f1550f89f00675bb599e8214e78cdecd4685e Mon Sep 17 00:00:00 2001 From: Roman Date: Sat, 9 Sep 2023 16:02:08 -0600 Subject: [PATCH 13/24] improvement: typing bug tag for GH actions --- .github/workflows/close_inactive_issues.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/close_inactive_issues.yml b/.github/workflows/close_inactive_issues.yml index f5935c97..d6ac6430 100644 --- a/.github/workflows/close_inactive_issues.yml +++ b/.github/workflows/close_inactive_issues.yml @@ -16,7 +16,7 @@ jobs: stale-pr-message: 'This PR is stale because it has been open 45 days with no activity.' close-issue-message: 'This issue was closed because it has been stalled for 14 days with no activity.' close-pr-message: 'This PR was closed because it has been stalled for 14 days with no activity.' - exempt-issue-labels: 'bug,feature-request' + exempt-issue-labels: 'bug,feature-request,typing bug,feature request' days-before-issue-stale: 30 days-before-pr-stale: 45 days-before-issue-close: 14 From 90ca8252c249d16db7afdc19754c279ff3a70155 Mon Sep 17 00:00:00 2001 From: Roman Date: Sat, 9 Sep 2023 16:15:12 -0600 Subject: [PATCH 14/24] fix: alias for extra field --- beanie/odm/utils/pydantic.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/beanie/odm/utils/pydantic.py b/beanie/odm/utils/pydantic.py index 1d6c7899..1b16e8fc 100644 --- a/beanie/odm/utils/pydantic.py +++ b/beanie/odm/utils/pydantic.py @@ -65,19 +65,17 @@ def get_model_dump(model): def get_iterator(model, by_alias=False): if IS_PYDANTIC_V2: - def _get_aliases(model): - aliases = {} - for k, v in model.model_fields.items(): - if v.alias is not None: - aliases[k] = v.alias - else: - aliases[k] = k - return aliases + def _get_alias(model, k): + v = model.model_fields.get(k) + if v is not None: + return v.alias or k + else: + return k def _iter(model, by_alias=False): for k, v in model.__iter__(): if by_alias: - yield _get_aliases(model)[k], v + yield _get_alias(model, k), v else: yield k, v From 1ed9832a4cd78189c31c217b8ac261fdfa229285 Mon Sep 17 00:00:00 2001 From: Roman Date: Sat, 9 Sep 2023 18:11:05 -0600 Subject: [PATCH 15/24] fix: fix 676 - Pydantic warnings --- beanie/odm/bulk.py | 14 ++- beanie/odm/documents.py | 31 +++-- beanie/odm/settings/base.py | 14 ++- beanie/odm/settings/document.py | 14 ++- beanie/odm/utils/encoder.py | 2 - pyproject.toml | 3 +- tests/odm/conftest.py | 26 ++--- tests/odm/custom_types/test_bson_binary.py | 10 +- tests/odm/documents/test_inheritance.py | 12 +- tests/odm/documents/test_update.py | 20 +++- tests/odm/models.py | 128 +++++++++++++++++---- tests/odm/query/test_find.py | 8 +- tests/odm/test_fields.py | 12 +- tests/odm/test_relations.py | 20 +++- tests/odm/test_root_models.py | 8 +- tests/odm/test_views.py | 8 +- tests/odm/views.py | 4 +- 17 files changed, 244 insertions(+), 90 deletions(-) diff --git a/beanie/odm/bulk.py b/beanie/odm/bulk.py index dde27a58..b7a7d7db 100644 --- a/beanie/odm/bulk.py +++ b/beanie/odm/bulk.py @@ -1,5 +1,6 @@ from typing import Dict, Any, List, Optional, Union, Type, Mapping +from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 from pydantic import BaseModel, Field from pymongo import ( InsertOne, @@ -12,6 +13,9 @@ from pymongo.results import BulkWriteResult from pymongo.client_session import ClientSession +if IS_PYDANTIC_V2: + from pydantic import ConfigDict + class Operation(BaseModel): operation: Union[ @@ -27,8 +31,14 @@ class Operation(BaseModel): pymongo_kwargs: Dict[str, Any] = Field(default_factory=dict) object_class: Type - class Config: - arbitrary_types_allowed = True + if IS_PYDANTIC_V2: + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + else: + + class Config: + arbitrary_types_allowed = True class BulkWriter: diff --git a/beanie/odm/documents.py b/beanie/odm/documents.py index 835d91db..6d2da8c7 100644 --- a/beanie/odm/documents.py +++ b/beanie/odm/documents.py @@ -21,6 +21,7 @@ PrivateAttr, Field, ConfigDict, + model_validator, ) from pydantic.class_validators import root_validator from pydantic.main import BaseModel @@ -141,13 +142,6 @@ class Document( - [UpdateMethods](https://roman-right.github.io/beanie/api/interfaces/#aggregatemethods) """ - # class Config: - # json_encoders = { - # ObjectId: lambda v: str(v), - # } - # allow_population_by_field_name = True - # # fields = {"id": "_id"} - if IS_PYDANTIC_V2: model_config = ConfigDict( json_schema_extra=json_schema_extra, @@ -178,7 +172,12 @@ def schema_extra( ) # State - revision_id: Optional[UUID] = Field(default=None, hidden=True) + if IS_PYDANTIC_V2: + revision_id: Optional[UUID] = Field( + default=None, json_schema_extra={"hidden": True} + ) + else: + revision_id: Optional[UUID] = Field(default=None, hidden=True) # type: ignore _previous_revision_id: Optional[UUID] = PrivateAttr(default=None) _saved_state: Optional[Dict[str, Any]] = PrivateAttr(default=None) _previous_saved_state: Optional[Dict[str, Any]] = PrivateAttr(default=None) @@ -207,8 +206,8 @@ def __init__(self, *args, **kwargs): super(Document, self).__init__(*args, **kwargs) self.get_motor_collection() - @root_validator(pre=True) - def fill_back_refs(cls, values): + @classmethod + def _fill_back_refs(cls, values): if cls._link_fields: for field_name, link_info in cls._link_fields.items(): if ( @@ -231,6 +230,18 @@ def fill_back_refs(cls, values): ] return values + if IS_PYDANTIC_V2: + + @model_validator(mode="before") + def fill_back_refs(cls, values): + return cls._fill_back_refs(values) + + else: + + @root_validator(pre=True) + def fill_back_refs(cls, values): + return cls._fill_back_refs(values) + @classmethod async def get( cls: Type["DocType"], diff --git a/beanie/odm/settings/base.py b/beanie/odm/settings/base.py index b6bcb620..b96afd22 100644 --- a/beanie/odm/settings/base.py +++ b/beanie/odm/settings/base.py @@ -1,9 +1,13 @@ from datetime import timedelta from typing import Optional, Dict, Any, Type +from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 from motor.motor_asyncio import AsyncIOMotorDatabase, AsyncIOMotorCollection from pydantic import BaseModel, Field +if IS_PYDANTIC_V2: + from pydantic import ConfigDict + class ItemSettings(BaseModel): name: Optional[str] = None @@ -23,5 +27,11 @@ class ItemSettings(BaseModel): is_root: bool = False - class Config: - arbitrary_types_allowed = True + if IS_PYDANTIC_V2: + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + else: + + class Config: + arbitrary_types_allowed = True diff --git a/beanie/odm/settings/document.py b/beanie/odm/settings/document.py index 41497c37..8c0e18c6 100644 --- a/beanie/odm/settings/document.py +++ b/beanie/odm/settings/document.py @@ -1,11 +1,15 @@ from typing import Optional, List +from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 from pydantic import Field from beanie.odm.fields import IndexModelField from beanie.odm.settings.base import ItemSettings from beanie.odm.settings.timeseries import TimeSeriesConfig +if IS_PYDANTIC_V2: + from pydantic import ConfigDict + class DocumentSettings(ItemSettings): use_state_management: bool = False @@ -23,5 +27,11 @@ class DocumentSettings(ItemSettings): keep_nulls: bool = True - class Config: - arbitrary_types_allowed = True + if IS_PYDANTIC_V2: + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + else: + + class Config: + arbitrary_types_allowed = True diff --git a/beanie/odm/utils/encoder.py b/beanie/odm/utils/encoder.py index af8e5fa3..6f6ba4ac 100644 --- a/beanie/odm/utils/encoder.py +++ b/beanie/odm/utils/encoder.py @@ -27,14 +27,12 @@ from bson import ObjectId, DBRef, Binary, Decimal128, Regex from pydantic import BaseModel, RootModel from pydantic import SecretBytes, SecretStr -from pydantic.color import Color from beanie.odm.fields import Link, LinkTypes from beanie.odm import documents from beanie.odm.utils.pydantic import get_iterator ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { - Color: str, timedelta: lambda td: td.total_seconds(), Decimal: Decimal128, deque: list, diff --git a/pyproject.toml b/pyproject.toml index 5a4d7f0b..d298cfd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,8 @@ test = [ "asgi-lifespan>=1.0.1", "httpx>=0.23.0", "fastapi>=0.100", - "pydantic-settings>=2.0", + "pydantic-settings>=2", + "pydantic-extra-types>=2" ] doc = [ "Pygments>=2.8.0", diff --git a/tests/odm/conftest.py b/tests/odm/conftest.py index cf55f06d..49525cb0 100644 --- a/tests/odm/conftest.py +++ b/tests/odm/conftest.py @@ -54,8 +54,8 @@ SampleLazyParsing, SampleWithMutableObjects, SubDocument, - Test2NonRoot, - TestNonRoot, + Doc2NonRoot, + DocNonRoot, Vehicle, Window, WindowWithRevision, @@ -86,11 +86,11 @@ DocumentWithCustomInit, DocumentWithTextIndexAndLink, LinkDocumentForTextSeacrh, - TestDocumentWithList, - TestDocumentWithBsonBinaryField, - TestDocumentWithRootModelAsAField, + DocumentWithList, + DocumentWithBsonBinaryField, + DocumentWithRootModelAsAField, ) -from tests.odm.views import TestView, TestViewWithLink +from tests.odm.views import ViewForTest, ViewForTestWithLink @pytest.fixture @@ -215,8 +215,8 @@ async def init(db): DocumentForEncodingTest, DocumentForEncodingTestDate, DocumentWithStringField, - TestView, - TestViewWithLink, + ViewForTest, + ViewForTestWithLink, DocumentMultiModelOne, DocumentMultiModelTwo, DocumentUnion, @@ -232,8 +232,8 @@ async def init(db): Bus, Owner, SampleWithMutableObjects, - TestNonRoot, - Test2NonRoot, + DocNonRoot, + Doc2NonRoot, SampleLazyParsing, RootDocument, ADocument, @@ -260,9 +260,9 @@ async def init(db): DocumentWithCustomInit, DocumentWithTextIndexAndLink, LinkDocumentForTextSeacrh, - TestDocumentWithList, - TestDocumentWithBsonBinaryField, - TestDocumentWithRootModelAsAField, + DocumentWithList, + DocumentWithBsonBinaryField, + DocumentWithRootModelAsAField, ] await init_beanie( database=db, diff --git a/tests/odm/custom_types/test_bson_binary.py b/tests/odm/custom_types/test_bson_binary.py index e2616c1c..92468d41 100644 --- a/tests/odm/custom_types/test_bson_binary.py +++ b/tests/odm/custom_types/test_bson_binary.py @@ -1,19 +1,19 @@ import bson -from tests.odm.models import TestDocumentWithBsonBinaryField +from tests.odm.models import DocumentWithBsonBinaryField async def test_bson_binary(): - doc = TestDocumentWithBsonBinaryField(binary_field=bson.Binary(b"test")) + doc = DocumentWithBsonBinaryField(binary_field=bson.Binary(b"test")) await doc.insert() assert doc.binary_field == bson.Binary(b"test") - new_doc = await TestDocumentWithBsonBinaryField.get(doc.id) + new_doc = await DocumentWithBsonBinaryField.get(doc.id) assert new_doc.binary_field == bson.Binary(b"test") - doc = TestDocumentWithBsonBinaryField(binary_field=b"test") + doc = DocumentWithBsonBinaryField(binary_field=b"test") await doc.insert() assert doc.binary_field == bson.Binary(b"test") - new_doc = await TestDocumentWithBsonBinaryField.get(doc.id) + new_doc = await DocumentWithBsonBinaryField.get(doc.id) assert new_doc.binary_field == bson.Binary(b"test") diff --git a/tests/odm/documents/test_inheritance.py b/tests/odm/documents/test_inheritance.py index 28157210..2340279f 100644 --- a/tests/odm/documents/test_inheritance.py +++ b/tests/odm/documents/test_inheritance.py @@ -6,8 +6,8 @@ Car, Bus, Owner, - TestNonRoot, - Test2NonRoot, + DocNonRoot, + Doc2NonRoot, ) @@ -104,11 +104,11 @@ async def test_links(self, db): await e.delete() def test_non_root_inheritance(self): - assert TestNonRoot._class_id is None - assert Test2NonRoot._class_id is None + assert DocNonRoot._class_id is None + assert Doc2NonRoot._class_id is None - assert TestNonRoot.get_collection_name() == "TestNonRoot" - assert Test2NonRoot.get_collection_name() == "Test2NonRoot" + assert DocNonRoot.get_collection_name() == "DocNonRoot" + assert Doc2NonRoot.get_collection_name() == "Doc2NonRoot" def test_class_ids(self): assert Vehicle._class_id == "Vehicle" diff --git a/tests/odm/documents/test_update.py b/tests/odm/documents/test_update.py index 8084f7d8..fe5d5310 100644 --- a/tests/odm/documents/test_update.py +++ b/tests/odm/documents/test_update.py @@ -1,4 +1,5 @@ import pytest +from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 from beanie.exceptions import ( DocumentNotFound, @@ -9,7 +10,7 @@ DocumentTestModel, ModelWithOptionalField, DocumentWithKeepNullsFalse, - TestDocumentWithList, + DocumentWithList, Sample, ) @@ -59,7 +60,10 @@ async def test_replace_many_not_all_the_docs_found(documents): async def test_replace(document): update_data = {"test_str": "REPLACED_VALUE"} - new_doc = document.copy(update=update_data) + if IS_PYDANTIC_V2: + new_doc = document.model_copy(update=update_data) + else: + new_doc = document.copy(update=update_data) # document.test_str = "REPLACED_VALUE" await new_doc.replace() new_document = await DocumentTestModel.get(document.id) @@ -80,7 +84,10 @@ async def test_replace_not_found(document_not_inserted): # SAVE async def test_save(document): update_data = {"test_str": "REPLACED_VALUE"} - new_doc = document.copy(update=update_data) + if IS_PYDANTIC_V2: + new_doc = document.model_copy(update=update_data) + else: + new_doc = document.copy(update=update_data) # document.test_str = "REPLACED_VALUE" await new_doc.save() new_document = await DocumentTestModel.get(document.id) @@ -283,9 +290,12 @@ async def test_save_changes_keep_nulls_false(): async def test_update_list(): - test_record = TestDocumentWithList(list_values=["1", "2", "3"]) + test_record = DocumentWithList(list_values=["1", "2", "3"]) test_record = await test_record.insert() - update_data = test_record.dict() + if IS_PYDANTIC_V2: + update_data = test_record.model_dump() + else: + update_data = test_record.dict() update_data["list_values"] = ["5", "6", "7"] updated_test_record = await test_record.update({"$set": update_data}) diff --git a/tests/odm/models.py b/tests/odm/models.py index 4a74f006..c80b3c8f 100644 --- a/tests/odm/models.py +++ b/tests/odm/models.py @@ -1,4 +1,5 @@ import datetime +from enum import Enum from ipaddress import ( IPv4Address, IPv4Interface, @@ -8,13 +9,22 @@ IPv6Network, ) from pathlib import Path -from typing import Dict, List, Optional, Set, Tuple, Union, ClassVar +from typing import ( + Dict, + List, + Optional, + Set, + Tuple, + Union, + ClassVar, + Any, + Callable, +) from uuid import UUID, uuid4 import pymongo from pydantic import ( BaseModel, - Extra, Field, PrivateAttr, SecretBytes, @@ -22,7 +32,9 @@ ConfigDict, RootModel, ) -from pydantic.color import Color +from pydantic.fields import FieldInfo +from pydantic_core import core_schema + from pymongo import IndexModel from beanie import DecimalAnnotation @@ -43,6 +55,41 @@ from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 +class Color: + def __init__(self, value): + self.value = value + + def as_rgb(self): + return self.value + + def as_hex(self): + return self.value + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: Callable[[Any], core_schema.CoreSchema], # type: ignore + ) -> core_schema.CoreSchema: # type: ignore + def validate(value, _: FieldInfo) -> Color: + if isinstance(value, Color): + return value + if isinstance(value, dict): + return Color(value["value"]) + return Color(value) + + python_schema = core_schema.general_plain_validator_function(validate) + + return core_schema.json_or_python_schema( + json_schema=core_schema.str_schema(), + python_schema=python_schema, + ) + + +class Extra(str, Enum): + allow = "allow" + + class Option2(BaseModel): f: float @@ -83,10 +130,16 @@ class SubDocument(BaseModel): class DocumentTestModel(Document): test_int: int - test_list: List[SubDocument] = Field(hidden=True) test_doc: SubDocument test_str: str + if IS_PYDANTIC_V2: + test_list: List[SubDocument] = Field( + json_schema_extra={"hidden": True} + ) + else: + test_list: List[SubDocument] = Field(hidden=True) + class Settings: use_cache = True cache_expiration_time = datetime.timedelta(seconds=10) @@ -405,7 +458,7 @@ class Config: num_1: int -class DocumentWithExtrasKw(Document, extra=Extra.allow): +class DocumentWithExtrasKw(Document, extra="allow"): num_1: int @@ -439,11 +492,20 @@ class House(Document): door: Link[Door] roof: Optional[Link[Roof]] = None yards: Optional[List[Link[Yard]]] = None - name: Indexed(str) = Field(hidden=True) height: Indexed(int) = 2 + if IS_PYDANTIC_V2: + name: Indexed(str) = Field(json_schema_extra={"hidden": True}) + else: + name: Indexed(str) = Field(hidden=True) + + if IS_PYDANTIC_V2: + model_config = ConfigDict( + extra="allow", + ) + else: - class Config: - extra = Extra.allow + class Config: + extra = Extra.allow class DocumentForEncodingTest(Document): @@ -600,11 +662,11 @@ class Settings: use_state_management = True -class TestNonRoot(MixinNonRoot, MyDocNonRoot): +class DocNonRoot(MixinNonRoot, MyDocNonRoot): name: str -class Test2NonRoot(MyDocNonRoot): +class Doc2NonRoot(MyDocNonRoot): name: str @@ -624,13 +686,19 @@ class SampleLazyParsing(Document): [], ) + if IS_PYDANTIC_V2: + model_config = ConfigDict( + validate_assignment=True, + ) + else: + + class Config: + validate_assignment = True + class Settings: lazy_parsing = True use_state_management = True - class Config: - validate_assignment = True - class RootDocument(Document): name: str @@ -722,8 +790,14 @@ class DocumentWithDecimalField(Document): decimal_places=1, multiple_of=0.5, default=0 ) - class Config: - validate_assignment = True + if IS_PYDANTIC_V2: + model_config = ConfigDict( + validate_assignment=True, + ) + else: + + class Config: + validate_assignment = True class Settings: name = "amounts" @@ -770,7 +844,12 @@ class DocumentWithLink(Document): class DocumentWithBackLink(Document): - back_link: BackLink[DocumentWithLink] = Field(original_field="link") + if IS_PYDANTIC_V2: + back_link: BackLink[DocumentWithLink] = Field( + json_schema_extra={"original_field": "link"}, + ) + else: + back_link: BackLink[DocumentWithLink] = Field(original_field="link") i: int = 1 @@ -780,9 +859,14 @@ class DocumentWithListLink(Document): class DocumentWithListBackLink(Document): - back_link: List[BackLink[DocumentWithListLink]] = Field( - original_field="link" - ) + if IS_PYDANTIC_V2: + back_link: List[BackLink[DocumentWithListLink]] = Field( + json_schema_extra={"original_field": "link"}, + ) + else: + back_link: List[BackLink[DocumentWithListLink]] = Field( + original_field="link" + ) i: int = 1 @@ -858,16 +942,16 @@ class Settings: ] -class TestDocumentWithList(Document): +class DocumentWithList(Document): list_values: List[str] -class TestDocumentWithBsonBinaryField(Document): +class DocumentWithBsonBinaryField(Document): binary_field: BsonBinary Pets = RootModel[List[str]] -class TestDocumentWithRootModelAsAField(Document): +class DocumentWithRootModelAsAField(Document): pets: Pets diff --git a/tests/odm/query/test_find.py b/tests/odm/query/test_find.py index 63100fc3..085b8617 100644 --- a/tests/odm/query/test_find.py +++ b/tests/odm/query/test_find.py @@ -2,10 +2,14 @@ import pytest from pydantic import BaseModel -from pydantic.color import Color from beanie.odm.enums import SortDirection -from tests.odm.models import Sample, DocumentWithBsonEncodersFiledsTypes, House +from tests.odm.models import ( + Sample, + DocumentWithBsonEncodersFiledsTypes, + House, + Color, +) async def test_find_query(): diff --git a/tests/odm/test_fields.py b/tests/odm/test_fields.py index 731b9e95..6e15a773 100644 --- a/tests/odm/test_fields.py +++ b/tests/odm/test_fields.py @@ -108,8 +108,10 @@ async def test_custom_filed_types(): async def test_hidden(document): document = await DocumentTestModel.find_one() - - assert "test_list" not in document.dict() + if IS_PYDANTIC_V2: + assert "test_list" not in document.model_dump() + else: + assert "test_list" not in document.dict() def test_revision_id_not_in_schema(): @@ -135,8 +137,10 @@ class Foo(Document): @pytest.mark.parametrize("exclude", [{"test_int"}, {"test_doc": {"test_int"}}]) async def test_param_exclude(document, exclude): document = await DocumentTestModel.find_one() - - doc_dict = document.dict(exclude=exclude) + if IS_PYDANTIC_V2: + doc_dict = document.model_dump(exclude=exclude) + else: + doc_dict = document.dict(exclude=exclude) if isinstance(exclude, AbstractSet): for k in exclude: assert k not in doc_dict diff --git a/tests/odm/test_relations.py b/tests/odm/test_relations.py index 80bded59..1cb633ba 100644 --- a/tests/odm/test_relations.py +++ b/tests/odm/test_relations.py @@ -685,14 +685,26 @@ class HouseForReversedOrderInit(Document): class DoorForReversedOrderInit(Document): height: int = 2 width: int = 1 - house: BackLink[HouseForReversedOrderInit] = Field(original_field="door") + if IS_PYDANTIC_V2: + house: BackLink[HouseForReversedOrderInit] = Field( + json_schema_extra={"original_field": "door"} + ) + else: + house: BackLink[HouseForReversedOrderInit] = Field( + original_field="door" + ) class PersonForReversedOrderInit(Document): name: str - house: List[BackLink[HouseForReversedOrderInit]] = Field( - original_field="owners" - ) + if IS_PYDANTIC_V2: + house: List[BackLink[HouseForReversedOrderInit]] = Field( + json_schema_extra={"original_field": "owners"} + ) + else: + house: List[BackLink[HouseForReversedOrderInit]] = Field( + original_field="owners" + ) class TestDeleteBackLinks: diff --git a/tests/odm/test_root_models.py b/tests/odm/test_root_models.py index d3e750f0..50f33f7f 100644 --- a/tests/odm/test_root_models.py +++ b/tests/odm/test_root_models.py @@ -1,14 +1,14 @@ -from tests.odm.models import TestDocumentWithRootModelAsAField +from tests.odm.models import DocumentWithRootModelAsAField class TestRootModels: async def test_insert(self): - doc = TestDocumentWithRootModelAsAField(pets=["dog", "cat", "fish"]) + doc = DocumentWithRootModelAsAField(pets=["dog", "cat", "fish"]) await doc.insert() - new_doc = await TestDocumentWithRootModelAsAField.get(doc.id) + new_doc = await DocumentWithRootModelAsAField.get(doc.id) assert new_doc.pets.root == ["dog", "cat", "fish"] - collection = TestDocumentWithRootModelAsAField.get_motor_collection() + collection = DocumentWithRootModelAsAField.get_motor_collection() raw_doc = await collection.find_one({"_id": doc.id}) assert raw_doc["pets"] == ["dog", "cat", "fish"] diff --git a/tests/odm/test_views.py b/tests/odm/test_views.py index d94ad4b4..bd265478 100644 --- a/tests/odm/test_views.py +++ b/tests/odm/test_views.py @@ -1,15 +1,15 @@ -from tests.odm.views import TestView, TestViewWithLink +from tests.odm.views import ViewForTest, ViewForTestWithLink class TestViews: async def test_simple(self, documents): await documents(number=15) - results = await TestView.all().to_list() + results = await ViewForTest.all().to_list() assert len(results) == 6 async def test_aggregate(self, documents): await documents(number=15) - results = await TestView.aggregate( + results = await ViewForTest.aggregate( [ {"$set": {"test_field": 1}}, {"$match": {"$expr": {"$lt": ["$number", 12]}}}, @@ -20,7 +20,7 @@ async def test_aggregate(self, documents): async def test_link(self, documents_with_links): await documents_with_links() - results = await TestViewWithLink.all().to_list() + results = await ViewForTestWithLink.all().to_list() for document in results: await document.fetch_all_links() diff --git a/tests/odm/views.py b/tests/odm/views.py index 28f5a233..0e86357f 100644 --- a/tests/odm/views.py +++ b/tests/odm/views.py @@ -3,7 +3,7 @@ from tests.odm.models import DocumentTestModel, DocumentTestModelWithLink -class TestView(View): +class ViewForTest(View): number: int string: str @@ -16,7 +16,7 @@ class Settings: ] -class TestViewWithLink(View): +class ViewForTestWithLink(View): link: Link[DocumentTestModel] class Settings: From 833a6c51f9c63de5189bbc92c3d6f3ffbac915b1 Mon Sep 17 00:00:00 2001 From: Roman Date: Sat, 9 Sep 2023 20:30:59 -0600 Subject: [PATCH 16/24] fix: fix 695 - validate_call wrapper --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d298cfd9..7c363bc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "motor>=2.5.0,<4.0.0", "click>=7", "toml", - "lazy-model==0.1.0b0", + "lazy-model==0.2.0", "typing-extensions>=4.7" ] From 45728a1385ac484a9285d9337a3e8ef7079e662a Mon Sep 17 00:00:00 2001 From: Roman Date: Sat, 9 Sep 2023 20:38:45 -0600 Subject: [PATCH 17/24] call wrapper test model --- tests/odm/conftest.py | 2 ++ tests/odm/models.py | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/tests/odm/conftest.py b/tests/odm/conftest.py index 49525cb0..7e354e07 100644 --- a/tests/odm/conftest.py +++ b/tests/odm/conftest.py @@ -89,6 +89,7 @@ DocumentWithList, DocumentWithBsonBinaryField, DocumentWithRootModelAsAField, + DocWithCallWrapper, ) from tests.odm.views import ViewForTest, ViewForTestWithLink @@ -263,6 +264,7 @@ async def init(db): DocumentWithList, DocumentWithBsonBinaryField, DocumentWithRootModelAsAField, + DocWithCallWrapper, ] await init_beanie( database=db, diff --git a/tests/odm/models.py b/tests/odm/models.py index c80b3c8f..5ee73974 100644 --- a/tests/odm/models.py +++ b/tests/odm/models.py @@ -31,6 +31,7 @@ SecretStr, ConfigDict, RootModel, + validate_call, ) from pydantic.fields import FieldInfo from pydantic_core import core_schema @@ -955,3 +956,11 @@ class DocumentWithBsonBinaryField(Document): class DocumentWithRootModelAsAField(Document): pets: Pets + + +class DocWithCallWrapper(Document): + name: str + + @validate_call + def foo(self, bar: str) -> None: + print(f"foo {bar}") From 356803ec85362c4407c730fad4417f0a9405ea2e Mon Sep 17 00:00:00 2001 From: Roman Date: Mon, 11 Sep 2023 20:48:29 -0600 Subject: [PATCH 18/24] test against pydantic 1.10 only --- .github/workflows/github-actions-tests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/github-actions-tests.yml b/.github/workflows/github-actions-tests.yml index 63c08426..04c29ac1 100644 --- a/.github/workflows/github-actions-tests.yml +++ b/.github/workflows/github-actions-tests.yml @@ -6,9 +6,9 @@ jobs: strategy: fail-fast: false matrix: - python-version: [ 3.7, 3.8, 3.9, 3.10.6, 3.11 ] - mongodb-version: [ 4.4, 5.0 ] - pydantic-version: [ 1.10.11, 2.0 ] + python-version: [ 3.10.6 ] + mongodb-version: [ 5.0 ] + pydantic-version: [ 1.10.11 ] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 From ca600bdf954a69825b979103d412c31d8b666c5b Mon Sep 17 00:00:00 2001 From: Roman Date: Mon, 11 Sep 2023 20:50:30 -0600 Subject: [PATCH 19/24] test against pydantic 1.10 only --- .github/workflows/github-actions-tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/github-actions-tests.yml b/.github/workflows/github-actions-tests.yml index 04c29ac1..dd0b113a 100644 --- a/.github/workflows/github-actions-tests.yml +++ b/.github/workflows/github-actions-tests.yml @@ -20,9 +20,9 @@ jobs: with: mongodb-version: ${{ matrix.mongodb-version }} mongodb-replica-set: test-rs - - name: install pydantic - run: pip install pydantic==${{ matrix.pydantic-version }} - name: install dependencies run: pip install .[test] + - name: install pydantic + run: pip install pydantic==${{ matrix.pydantic-version }} - name: run tests run: pytest From f2ead1b060f0400209235d081776b6046c1b1e3c Mon Sep 17 00:00:00 2001 From: Roman Date: Mon, 11 Sep 2023 21:18:55 -0600 Subject: [PATCH 20/24] fix: pydantic 1.10.x --- beanie/odm/custom_types/bson/binary.py | 85 ++++++++++++++-------- beanie/odm/documents.py | 4 +- beanie/odm/utils/encoder.py | 9 ++- beanie/odm/utils/pydantic.py | 2 +- tests/odm/custom_types/test_bson_binary.py | 9 ++- tests/odm/models.py | 48 ++++++++++-- tests/odm/test_relations.py | 16 ++-- tests/odm/test_root_models.py | 22 +++--- 8 files changed, 136 insertions(+), 59 deletions(-) diff --git a/beanie/odm/custom_types/bson/binary.py b/beanie/odm/custom_types/bson/binary.py index b6acb7cb..3c9722c0 100644 --- a/beanie/odm/custom_types/bson/binary.py +++ b/beanie/odm/custom_types/bson/binary.py @@ -1,37 +1,64 @@ from typing import Any, Callable import bson -from pydantic import GetJsonSchemaHandler -from pydantic.fields import FieldInfo -from pydantic.json_schema import JsonSchemaValue -from pydantic_core import core_schema +from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 + +if IS_PYDANTIC_V2: + from pydantic import GetJsonSchemaHandler + from pydantic.fields import FieldInfo + from pydantic.json_schema import JsonSchemaValue + from pydantic_core import core_schema class BsonBinary(bson.Binary): - @classmethod - def __get_pydantic_core_schema__( - cls, - _source_type: Any, - _handler: Callable[[Any], core_schema.CoreSchema], # type: ignore - ) -> core_schema.CoreSchema: # type: ignore - def validate(value, _: FieldInfo) -> bson.Binary: - if isinstance(value, bson.Binary): + if IS_PYDANTIC_V2: + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: Callable[[Any], core_schema.CoreSchema], # type: ignore + ) -> core_schema.CoreSchema: # type: ignore + def validate(value, _: FieldInfo) -> bson.Binary: + if isinstance(value, BsonBinary): + return value + if isinstance(value, bson.Binary): + return BsonBinary(value) + if isinstance(value, bytes): + return BsonBinary(value) + raise ValueError( + "Value must be bytes or bson.Binary or BsonBinary" + ) + + python_schema = core_schema.general_plain_validator_function(validate) # type: ignore + + return core_schema.json_or_python_schema( + json_schema=core_schema.float_schema(), + python_schema=python_schema, + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, + _core_schema: core_schema.CoreSchema, # type: ignore + handler: GetJsonSchemaHandler, + ) -> JsonSchemaValue: + return handler(core_schema.str_schema()) + + else: + + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, value): + if isinstance(value, BsonBinary): return value + if isinstance(value, bson.Binary): + return BsonBinary(value) if isinstance(value, bytes): - return bson.Binary(value) - raise ValueError("Value must be bytes or bson.Binary") - - python_schema = core_schema.general_plain_validator_function(validate) # type: ignore - - return core_schema.json_or_python_schema( - json_schema=core_schema.float_schema(), - python_schema=python_schema, - ) - - @classmethod - def __get_pydantic_json_schema__( - cls, - _core_schema: core_schema.CoreSchema, # type: ignore - handler: GetJsonSchemaHandler, - ) -> JsonSchemaValue: - return handler(core_schema.str_schema()) + return BsonBinary(value) + raise ValueError( + "Value must be bytes or bson.Binary or BsonBinary" + ) diff --git a/beanie/odm/documents.py b/beanie/odm/documents.py index 6d2da8c7..fd221e51 100644 --- a/beanie/odm/documents.py +++ b/beanie/odm/documents.py @@ -21,7 +21,6 @@ PrivateAttr, Field, ConfigDict, - model_validator, ) from pydantic.class_validators import root_validator from pydantic.main import BaseModel @@ -99,6 +98,9 @@ ) from beanie.odm.utils.typing import extract_id_class +if IS_PYDANTIC_V2: + from pydantic import model_validator + if TYPE_CHECKING: from pydantic.typing import AbstractSetIntStr, MappingIntStrAny, DictStrAny diff --git a/beanie/odm/utils/encoder.py b/beanie/odm/utils/encoder.py index 6f6ba4ac..a86a5f84 100644 --- a/beanie/odm/utils/encoder.py +++ b/beanie/odm/utils/encoder.py @@ -25,12 +25,15 @@ import bson from bson import ObjectId, DBRef, Binary, Decimal128, Regex -from pydantic import BaseModel, RootModel +from pydantic import BaseModel from pydantic import SecretBytes, SecretStr from beanie.odm.fields import Link, LinkTypes from beanie.odm import documents -from beanie.odm.utils.pydantic import get_iterator +from beanie.odm.utils.pydantic import get_iterator, IS_PYDANTIC_V2 + +if IS_PYDANTIC_V2: + from pydantic import RootModel ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { timedelta: lambda td: td.total_seconds(), @@ -209,7 +212,7 @@ def _encode( if isinstance(obj, documents.Document): return self.encode_document(obj) - if isinstance(obj, RootModel): + if IS_PYDANTIC_V2 and isinstance(obj, RootModel): return self.encode_root_model(obj) if isinstance(obj, BaseModel): return self.encode_base_model(obj) diff --git a/beanie/odm/utils/pydantic.py b/beanie/odm/utils/pydantic.py index 1b16e8fc..b33b2509 100644 --- a/beanie/odm/utils/pydantic.py +++ b/beanie/odm/utils/pydantic.py @@ -81,4 +81,4 @@ def _iter(model, by_alias=False): return _iter(model, by_alias=by_alias) else: - return model._iter(to_dict=False, by_alias=False) + return model._iter(to_dict=False, by_alias=by_alias) diff --git a/tests/odm/custom_types/test_bson_binary.py b/tests/odm/custom_types/test_bson_binary.py index 92468d41..866fe7d2 100644 --- a/tests/odm/custom_types/test_bson_binary.py +++ b/tests/odm/custom_types/test_bson_binary.py @@ -1,19 +1,20 @@ import bson +from beanie import BsonBinary from tests.odm.models import DocumentWithBsonBinaryField async def test_bson_binary(): doc = DocumentWithBsonBinaryField(binary_field=bson.Binary(b"test")) await doc.insert() - assert doc.binary_field == bson.Binary(b"test") + assert doc.binary_field == BsonBinary(b"test") new_doc = await DocumentWithBsonBinaryField.get(doc.id) - assert new_doc.binary_field == bson.Binary(b"test") + assert new_doc.binary_field == BsonBinary(b"test") doc = DocumentWithBsonBinaryField(binary_field=b"test") await doc.insert() - assert doc.binary_field == bson.Binary(b"test") + assert doc.binary_field == BsonBinary(b"test") new_doc = await DocumentWithBsonBinaryField.get(doc.id) - assert new_doc.binary_field == bson.Binary(b"test") + assert new_doc.binary_field == BsonBinary(b"test") diff --git a/tests/odm/models.py b/tests/odm/models.py index 5ee73974..eac42199 100644 --- a/tests/odm/models.py +++ b/tests/odm/models.py @@ -30,8 +30,6 @@ SecretBytes, SecretStr, ConfigDict, - RootModel, - validate_call, ) from pydantic.fields import FieldInfo from pydantic_core import core_schema @@ -55,6 +53,9 @@ from beanie.odm.union_doc import UnionDoc from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 +if IS_PYDANTIC_V2: + from pydantic import RootModel, validate_call + class Color: def __init__(self, value): @@ -66,6 +67,18 @@ def as_rgb(self): def as_hex(self): return self.value + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, value): + if isinstance(value, Color): + return value + if isinstance(value, dict): + return Color(value["value"]) + return Color(value) + @classmethod def __get_pydantic_core_schema__( cls, @@ -255,6 +268,15 @@ class DocumentWithCustomFiledsTypes(Document): tuple_type: Tuple[int, str] path: Path + if IS_PYDANTIC_V2: + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + else: + + class Config: + arbitrary_types_allowed = True + class DocumentWithBsonEncodersFiledsTypes(Document): color: Color @@ -266,6 +288,15 @@ class Settings: datetime.datetime: lambda o: o.isoformat(timespec="microseconds"), } + if IS_PYDANTIC_V2: + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + else: + + class Config: + arbitrary_types_allowed = True + class DocumentWithActions(Document): name: str @@ -951,7 +982,10 @@ class DocumentWithBsonBinaryField(Document): binary_field: BsonBinary -Pets = RootModel[List[str]] +if IS_PYDANTIC_V2: + Pets = RootModel[List[str]] +else: + Pets = List[str] class DocumentWithRootModelAsAField(Document): @@ -961,6 +995,8 @@ class DocumentWithRootModelAsAField(Document): class DocWithCallWrapper(Document): name: str - @validate_call - def foo(self, bar: str) -> None: - print(f"foo {bar}") + if IS_PYDANTIC_V2: + + @validate_call + def foo(self, bar: str) -> None: + print(f"foo {bar}") diff --git a/tests/odm/test_relations.py b/tests/odm/test_relations.py index 1cb633ba..16694683 100644 --- a/tests/odm/test_relations.py +++ b/tests/odm/test_relations.py @@ -6,7 +6,11 @@ from beanie import init_beanie, Document from beanie.exceptions import DocumentWasNotSaved from beanie.odm.fields import DeleteRules, Link, WriteRules, BackLink -from beanie.odm.utils.pydantic import parse_model, IS_PYDANTIC_V2 +from beanie.odm.utils.pydantic import ( + parse_model, + IS_PYDANTIC_V2, + get_model_fields, +) from tests.odm.models import ( Door, House, @@ -418,17 +422,17 @@ async def test_query_composition(self): # Simple fields are initialized using the pydantic model_fields internal property # such fields are properly isolated when multi inheritance is involved. - assert set(RootDocument.model_fields.keys()) == SYS | { + assert set(get_model_fields(RootDocument).keys()) == SYS | { "name", "link_root", } - assert set(ADocument.model_fields.keys()) == SYS | { + assert set(get_model_fields(ADocument).keys()) == SYS | { "name", "link_root", "surname", "link_a", } - assert set(BDocument.model_fields.keys()) == SYS | { + assert set(get_model_fields(BDocument).keys()) == SYS | { "name", "link_root", "email", @@ -512,7 +516,7 @@ async def test_with_chaining_aggregation(self): async def test_with_extra_allow(self, houses): res = await House.find(fetch_links=True).to_list() - assert res[0].model_fields.keys() == { + assert get_model_fields(res[0]).keys() == { "id", "revision_id", "windows", @@ -524,7 +528,7 @@ async def test_with_extra_allow(self, houses): } res = await House.find_one(fetch_links=True) - assert res.model_fields.keys() == { + assert get_model_fields(res).keys() == { "id", "revision_id", "windows", diff --git a/tests/odm/test_root_models.py b/tests/odm/test_root_models.py index 50f33f7f..bfe8e4d6 100644 --- a/tests/odm/test_root_models.py +++ b/tests/odm/test_root_models.py @@ -1,14 +1,18 @@ +from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 + from tests.odm.models import DocumentWithRootModelAsAField -class TestRootModels: - async def test_insert(self): - doc = DocumentWithRootModelAsAField(pets=["dog", "cat", "fish"]) - await doc.insert() +if IS_PYDANTIC_V2: + + class TestRootModels: + async def test_insert(self): + doc = DocumentWithRootModelAsAField(pets=["dog", "cat", "fish"]) + await doc.insert() - new_doc = await DocumentWithRootModelAsAField.get(doc.id) - assert new_doc.pets.root == ["dog", "cat", "fish"] + new_doc = await DocumentWithRootModelAsAField.get(doc.id) + assert new_doc.pets.root == ["dog", "cat", "fish"] - collection = DocumentWithRootModelAsAField.get_motor_collection() - raw_doc = await collection.find_one({"_id": doc.id}) - assert raw_doc["pets"] == ["dog", "cat", "fish"] + collection = DocumentWithRootModelAsAField.get_motor_collection() + raw_doc = await collection.find_one({"_id": doc.id}) + assert raw_doc["pets"] == ["dog", "cat", "fish"] From f76ca6c3b0ec58718130d7d9f5c93fe92a82ec4f Mon Sep 17 00:00:00 2001 From: Roman Date: Mon, 11 Sep 2023 21:30:48 -0600 Subject: [PATCH 21/24] roll-back: test matrix --- .github/workflows/github-actions-tests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/github-actions-tests.yml b/.github/workflows/github-actions-tests.yml index dd0b113a..558b27f8 100644 --- a/.github/workflows/github-actions-tests.yml +++ b/.github/workflows/github-actions-tests.yml @@ -6,9 +6,9 @@ jobs: strategy: fail-fast: false matrix: - python-version: [ 3.10.6 ] - mongodb-version: [ 5.0 ] - pydantic-version: [ 1.10.11 ] + python-version: [ 3.7, 3.8, 3.9, 3.10.6, 3.11 ] + mongodb-version: [ 4.4, 5.0 ] + pydantic-version: [ 1.10.11, 2.0 ] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 From 8a10b494b20ebfc819a9a5268e603a1105382ae5 Mon Sep 17 00:00:00 2001 From: Roman Date: Mon, 11 Sep 2023 21:37:29 -0600 Subject: [PATCH 22/24] lower coverage limit --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7c363bc8..4dcf44f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ exclude = ''' [tool.pytest.ini_options] minversion = "6.0" -addopts = "--cov-report term-missing --cov=beanie --cov-branch --cov-fail-under=85" +addopts = "--cov-report term-missing --cov=beanie --cov-branch --cov-fail-under=80" testpaths = [ "tests", ] From 5adce0be3abf7c470c1538c8c133706d6becc0c9 Mon Sep 17 00:00:00 2001 From: Roman Date: Mon, 11 Sep 2023 21:45:22 -0600 Subject: [PATCH 23/24] pydantic 2.3 in the versions matrix --- .github/workflows/github-actions-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/github-actions-tests.yml b/.github/workflows/github-actions-tests.yml index 558b27f8..b2e283b9 100644 --- a/.github/workflows/github-actions-tests.yml +++ b/.github/workflows/github-actions-tests.yml @@ -8,7 +8,7 @@ jobs: matrix: python-version: [ 3.7, 3.8, 3.9, 3.10.6, 3.11 ] mongodb-version: [ 4.4, 5.0 ] - pydantic-version: [ 1.10.11, 2.0 ] + pydantic-version: [ 1.10.11, 2.3 ] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 From a77455532114e39276b513eebd846699d11c978a Mon Sep 17 00:00:00 2001 From: Roman Date: Wed, 13 Sep 2023 12:16:23 -0600 Subject: [PATCH 24/24] version: 1.22.0 --- .github/workflows/close_inactive_issues.yml | 2 +- beanie/__init__.py | 2 +- docs/changelog.md | 24 +++++++++++++++++++++ pyproject.toml | 2 +- tests/test_beanie.py | 2 +- 5 files changed, 28 insertions(+), 4 deletions(-) diff --git a/.github/workflows/close_inactive_issues.yml b/.github/workflows/close_inactive_issues.yml index d6ac6430..4a33c5fb 100644 --- a/.github/workflows/close_inactive_issues.yml +++ b/.github/workflows/close_inactive_issues.yml @@ -16,7 +16,7 @@ jobs: stale-pr-message: 'This PR is stale because it has been open 45 days with no activity.' close-issue-message: 'This issue was closed because it has been stalled for 14 days with no activity.' close-pr-message: 'This PR was closed because it has been stalled for 14 days with no activity.' - exempt-issue-labels: 'bug,feature-request,typing bug,feature request' + exempt-issue-labels: 'bug,feature-request,typing bug,feature request,doc,documentation' days-before-issue-stale: 30 days-before-pr-stale: 45 days-before-issue-close: 14 diff --git a/beanie/__init__.py b/beanie/__init__.py index 63817ffe..3f7831e0 100644 --- a/beanie/__init__.py +++ b/beanie/__init__.py @@ -31,7 +31,7 @@ from beanie.odm.views import View from beanie.odm.union_doc import UnionDoc -__version__ = "1.21.0" +__version__ = "1.22.0" __all__ = [ # ODM "Document", diff --git a/docs/changelog.md b/docs/changelog.md index e111c8a1..17847080 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,30 @@ Beanie project +## [1.22.0] - 2023-09-13 + +### Fix | August 2023 +- Author - [Roman Right](https://github.com/roman-right) +- PR + +- Issues: + + - [[BUG] Issue with `List[Link[Type]]` when `fetch_all_links` is called](https://github.com/roman-right/beanie/issues/576) + - [Loosen type requirement for `insert_many()`?](https://github.com/roman-right/beanie/issues/591) + - [[BUG] Updating documents with a frozen BaseModel as field raises TypeError](https://github.com/roman-right/beanie/issues/599) + - [[BUG] Not operator cant be on top level](https://github.com/roman-right/beanie/issues/600) + - [[BUG] `Text` query doesn't work with `fetch_links=True`](https://github.com/roman-right/beanie/issues/606) + - [[BUG] List type fields in updated model record do not get update.](https://github.com/roman-right/beanie/issues/629) + - [[BUG] Undefined behavior when chaining update methods](https://github.com/roman-right/beanie/issues/646) + - [[BUG] Revision Id is in Responsemodel](https://github.com/roman-right/beanie/issues/648) + - [[BUG] Custom types like bson.Binary require `__get_pydantic_core_schema__`](https://github.com/roman-right/beanie/issues/651) + - [[BUG] `validate_on_save` doesn't work with `Document.save()`](https://github.com/roman-right/beanie/issues/664) + - [[BUG] Beanie persists `root` field](https://github.com/roman-right/beanie/issues/668) + - [Beanie 1.21 still triggers many deprecation warnings with Pydantic v2](https://github.com/roman-right/beanie/issues/676) + - [[BUG] TypeError: expected 1 argument, got 0 when beanie.Document has method wrapped in pydantic.validate_call](https://github.com/roman-right/beanie/issues/695) + +[1.22.0]: https://pypi.org/project/beanie/1.22.0 + ## [1.21.0] - 2023-08-03 ### Pydantic bump | final diff --git a/pyproject.toml b/pyproject.toml index 4dcf44f7..17bd2703 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "beanie" -version = "1.21.0" +version = "1.22.0" description = "Asynchronous Python ODM for MongoDB" readme = "README.md" requires-python = ">=3.7,<4.0" diff --git a/tests/test_beanie.py b/tests/test_beanie.py index 79e7b5f9..42693642 100644 --- a/tests/test_beanie.py +++ b/tests/test_beanie.py @@ -2,4 +2,4 @@ def test_version(): - assert __version__ == "1.21.0" + assert __version__ == "1.22.0"