Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 committed Oct 23, 2023
1 parent 2953f8e commit b31d68a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 23 deletions.
8 changes: 4 additions & 4 deletions extension_service/datastore/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ async def export_data(
) -> tuple[list[models.Airport], list[models.Amenity], List[models.Flight]]:
pass

@abstractmethod
async def get_airport(self, id: int) -> models.Airport | None:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def get_amenity(self, id: int) -> list[Dict[str, Any]]:
raise NotImplementedError("Subclass should implement this!")
Expand All @@ -69,10 +73,6 @@ async def amenities_search(
) -> list[Dict[str, Any]]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def get_airport(self, id: int) -> list[models.Airport]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def close(self):
pass
Expand Down
25 changes: 14 additions & 11 deletions extension_service/datastore/providers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,20 @@ async def export_data(
flights = [models.Flight.model_validate(dict(f)) for f in await flights_task]
return airports, amenities, flights

async def get_airport(self, id: int) -> models.Airport | None:
result = await self.__pool.fetchrow(
"""
SELECT id, iata, name, city, country FROM airports WHERE id=$1
""",
id,
)

if result is None:
return None

result = models.Airport.model_validate(dict(result))
return result

async def get_amenity(self, id: int) -> list[Dict[str, Any]]:
results = await self.__pool.fetch(
"""
Expand Down Expand Up @@ -219,16 +233,5 @@ async def amenities_search(
results = [dict(r) for r in results]
return results

async def get_airport(self, id: int) -> list[models.Airport]:
results = await self.__pool.fetch(
"""
SELECT id, iata, name, city, country FROM airports WHERE id=$1
""",
id,
)

airports = [models.Airport.model_validate(dict(r)) for r in results]
return airports

async def close(self):
await self.__pool.close()
14 changes: 6 additions & 8 deletions extension_service/datastore/providers/postgres_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,12 @@ async def test_get_airport():
mockCl = await mock_postgres_provider(mocks)
res = await mockCl.get_airport(1)
expected_res = [
models.Airport.model_validate(
{
"id": 1,
"iata": "FOO",
"name": "Foo Bar",
"city": "baz",
"country": "bundy",
}
models.Airport(
id=1,
iata="FOO",
name="Foo Bar",
city="baz",
country="bundy",
)
]
assert res == expected_res

0 comments on commit b31d68a

Please sign in to comment.