diff --git a/extension_service/datastore/providers/firestore.py b/extension_service/datastore/providers/firestore.py index 036e485d..a873ccc6 100644 --- a/extension_service/datastore/providers/firestore.py +++ b/extension_service/datastore/providers/firestore.py @@ -161,13 +161,17 @@ async def get_airport_by_id(self, id: int) -> Optional[models.Airport]: query = self.__client.collection("airports").where( filter=FieldFilter("id", "==", id) ) - return models.Airport.model_validate(await query.get().to_dict()) + airport_doc = await query.get() + airport_dict = airport_doc.to_dict() | {"id": airport_doc.id} + return models.Airport.model_validate(airport_dict) async def get_airport_by_iata(self, iata: str) -> Optional[models.Airport]: query = self.__client.collection("airports").where( filter=FieldFilter("iata", "==", iata) ) - return models.Airport.model_validate(await query.get().to_dict()) + airport_doc = await query.get() + airport_dict = airport_doc.to_dict() | {"id": airport_doc.id} + return models.Airport.model_validate(airport_dict) async def search_airports( self, @@ -187,14 +191,19 @@ async def search_airports( query = query.where("name", ">=", name).where("name", "<=", name + "\uf8ff") docs = query.stream() - airports = [models.Airport.model_validate(dict(doc)) async for doc in docs] + airports = [] + async for doc in docs: + airport_dict = doc.to_dict() | {"id": doc.id} + airports.append(models.Airport.model_validate(airport_dict)) return airports async def get_amenity(self, id: int) -> Optional[models.Amenity]: query = self.__client.collection("amenities").where( filter=FieldFilter("id", "==", id) ) - return models.Amenity.model_validate(await query.get().to_dict()) + amenity_doc = await query.get() + amenity_dict = amenity_doc.to_dict() | {"id": amenity_doc.id} + return models.Amenity.model_validate(amenity_dict) async def amenities_search( self, query_embedding: list[float], similarity_threshold: float, top_k: int @@ -203,9 +212,11 @@ async def amenities_search( async def get_flight(self, flight_id: int) -> Optional[models.Flight]: query = self.__client.collection("flights").where( - filter=FieldFilter("id", "==", id) + filter=FieldFilter("id", "==", flight_id) ) - return models.Flight.model_validate(await query.get().to_dict()) + flight_doc = await query.get() + flight_dict = flight_doc.to_dict() | {"id": flight_doc.id} + return models.Flight.model_validate(flight_dict) async def search_flights_by_number( self, @@ -218,9 +229,11 @@ async def search_flights_by_number( .where(filter=FieldFilter("flight_number", "==", number)) ) - flights = [ - models.Flight.model_validate(dict(doc)) async for doc in query.stream() - ] + docs = query.stream() + flights = [] + async for doc in docs: + flight_dict = doc.to_dict() | {"id": doc.id} + flights.append(models.Flight.model_validate(flight_dict)) return flights async def search_flights_by_airports( @@ -243,7 +256,10 @@ async def search_flights_by_airports( query = query.where("arrival_airport", "==", arrival_airport) docs = query.stream() - flights = [models.Flight.model_validate(dict(doc)) async for doc in docs] + flights = [] + async for doc in docs: + flight_dict = doc.to_dict() | {"id": doc.id} + flights.append(models.Flight.model_validate(flight_dict)) return flights async def close(self):