From 1a023289239c8b8625e27c0121f5cfc80875d642 Mon Sep 17 00:00:00 2001 From: Anirudh Haritas Murali <49116134+anihm136@users.noreply.github.com> Date: Tue, 21 Nov 2023 22:48:18 +0530 Subject: [PATCH] fix(firestore): Add ID to all documents in Firestore provider (#94) Noticed a bug in the Firestore provider while taking this on a test run - many of the methods with the Firestore provider do not add the ID to the document and therefore fail model validation. Additionally, some places had incorrect conversion of documents into dict (using `dict(doc)` instead of `doc.to_dict()`). This PR resolves both of these issues. --------- Co-authored-by: Anirudh Murali --- .../datastore/providers/firestore.py | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/extension_service/datastore/providers/firestore.py b/extension_service/datastore/providers/firestore.py index 036e485d..a873ccc6 100644 --- a/extension_service/datastore/providers/firestore.py +++ b/extension_service/datastore/providers/firestore.py @@ -161,13 +161,17 @@ async def get_airport_by_id(self, id: int) -> Optional[models.Airport]: query = self.__client.collection("airports").where( filter=FieldFilter("id", "==", id) ) - return models.Airport.model_validate(await query.get().to_dict()) + airport_doc = await query.get() + airport_dict = airport_doc.to_dict() | {"id": airport_doc.id} + return models.Airport.model_validate(airport_dict) async def get_airport_by_iata(self, iata: str) -> Optional[models.Airport]: query = self.__client.collection("airports").where( filter=FieldFilter("iata", "==", iata) ) - return models.Airport.model_validate(await query.get().to_dict()) + airport_doc = await query.get() + airport_dict = airport_doc.to_dict() | {"id": airport_doc.id} + return models.Airport.model_validate(airport_dict) async def search_airports( self, @@ -187,14 +191,19 @@ async def search_airports( query = query.where("name", ">=", name).where("name", "<=", name + "\uf8ff") docs = query.stream() - airports = [models.Airport.model_validate(dict(doc)) async for doc in docs] + airports = [] + async for doc in docs: + airport_dict = doc.to_dict() | {"id": doc.id} + airports.append(models.Airport.model_validate(airport_dict)) return airports async def get_amenity(self, id: int) -> Optional[models.Amenity]: query = self.__client.collection("amenities").where( filter=FieldFilter("id", "==", id) ) - return models.Amenity.model_validate(await query.get().to_dict()) + amenity_doc = await query.get() + amenity_dict = amenity_doc.to_dict() | {"id": amenity_doc.id} + return models.Amenity.model_validate(amenity_dict) async def amenities_search( self, query_embedding: list[float], similarity_threshold: float, top_k: int @@ -203,9 +212,11 @@ async def amenities_search( async def get_flight(self, flight_id: int) -> Optional[models.Flight]: query = self.__client.collection("flights").where( - filter=FieldFilter("id", "==", id) + filter=FieldFilter("id", "==", flight_id) ) - return models.Flight.model_validate(await query.get().to_dict()) + flight_doc = await query.get() + flight_dict = flight_doc.to_dict() | {"id": flight_doc.id} + return models.Flight.model_validate(flight_dict) async def search_flights_by_number( self, @@ -218,9 +229,11 @@ async def search_flights_by_number( .where(filter=FieldFilter("flight_number", "==", number)) ) - flights = [ - models.Flight.model_validate(dict(doc)) async for doc in query.stream() - ] + docs = query.stream() + flights = [] + async for doc in docs: + flight_dict = doc.to_dict() | {"id": doc.id} + flights.append(models.Flight.model_validate(flight_dict)) return flights async def search_flights_by_airports( @@ -243,7 +256,10 @@ async def search_flights_by_airports( query = query.where("arrival_airport", "==", arrival_airport) docs = query.stream() - flights = [models.Flight.model_validate(dict(doc)) async for doc in docs] + flights = [] + async for doc in docs: + flight_dict = doc.to_dict() | {"id": doc.id} + flights.append(models.Flight.model_validate(flight_dict)) return flights async def close(self):