diff --git a/databases/backends/common/records.py b/databases/backends/common/records.py index 1d8a2fd4..e963af50 100644 --- a/databases/backends/common/records.py +++ b/databases/backends/common/records.py @@ -1,11 +1,12 @@ -import json +import enum import typing -from datetime import date, datetime +from datetime import date, datetime, time from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.engine.row import Row as SQLRow from sqlalchemy.sql.compiler import _CompileLabel from sqlalchemy.sql.schema import Column +from sqlalchemy.sql.sqltypes import JSON from sqlalchemy.types import TypeEngine from databases.interfaces import Record as RecordInterface @@ -62,12 +63,10 @@ def __getitem__(self, key: typing.Any) -> typing.Any: raw = self._row[idx] processor = datatype._cached_result_processor(self._dialect, None) - if self._dialect.name not in DIALECT_EXCLUDE: - if isinstance(raw, dict): - raw = json.dumps(raw) + if self._dialect.name in DIALECT_EXCLUDE: + if processor is not None and isinstance(raw, (int, str, float)): + return processor(raw) - if processor is not None and (not isinstance(raw, (datetime, date))): - return processor(raw) return raw def __iter__(self) -> typing.Iterator: diff --git a/tests/test_databases.py b/tests/test_databases.py index cd907fd1..d9d9e4d6 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -1,6 +1,7 @@ import asyncio import datetime import decimal +import enum import functools import gc import itertools @@ -55,6 +56,47 @@ def process_result_value(self, value, dialect): sqlalchemy.Column("published", sqlalchemy.DateTime), ) +# Used to test Date +events = sqlalchemy.Table( + "events", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("date", sqlalchemy.Date), +) + + +# Used to test Time +daily_schedule = sqlalchemy.Table( + "daily_schedule", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("time", sqlalchemy.Time), +) + + +class TshirtSize(enum.Enum): + SMALL = "SMALL" + MEDIUM = "MEDIUM" + LARGE = "LARGE" + XL = "XL" + + +class TshirtColor(enum.Enum): + BLUE = 0 + GREEN = 1 + YELLOW = 2 + RED = 3 + + +# Used to test Enum +tshirt_size = sqlalchemy.Table( + "tshirt_size", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("size", sqlalchemy.Enum(TshirtSize)), + sqlalchemy.Column("color", sqlalchemy.Enum(TshirtColor)), +) + # Used to test JSON session = sqlalchemy.Table( "session", @@ -928,6 +970,52 @@ async def test_datetime_field(database_url): assert results[0]["published"] == now +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_date_field(database_url): + """ + Test Date columns, to ensure records are coerced to/from proper Python types. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + now = datetime.date.today() + + # execute() + query = events.insert() + values = {"date": now} + await database.execute(query, values) + + # fetch_all() + query = events.select() + results = await database.fetch_all(query=query) + assert len(results) == 1 + assert results[0]["date"] == now + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_time_field(database_url): + """ + Test Time columns, to ensure records are coerced to/from proper Python types. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + now = datetime.datetime.now().time().replace(microsecond=0) + + # execute() + query = daily_schedule.insert() + values = {"time": now} + await database.execute(query, values) + + # fetch_all() + query = daily_schedule.select() + results = await database.fetch_all(query=query) + assert len(results) == 1 + assert results[0]["time"] == now + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_decimal_field(database_url): @@ -957,7 +1045,32 @@ async def test_decimal_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_json_field(database_url): +async def test_enum_field(database_url): + """ + Test enum columns, to ensure correct cross-database support. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + # execute() + size = TshirtSize.SMALL + color = TshirtColor.GREEN + values = {"size": size, "color": color} + query = tshirt_size.insert() + await database.execute(query, values) + + # fetch_all() + query = tshirt_size.select() + results = await database.fetch_all(query=query) + + assert len(results) == 1 + assert results[0]["size"] == size + assert results[0]["color"] == color + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_json_dict_field(database_url): """ Test JSON columns, to ensure correct cross-database support. """ @@ -978,6 +1091,29 @@ async def test_json_field(database_url): assert results[0]["data"] == {"text": "hello", "boolean": True, "int": 1} +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_json_list_field(database_url): + """ + Test JSON columns, to ensure correct cross-database support. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + # execute() + data = ["lemon", "raspberry", "lime", "pumice"] + values = {"data": data} + query = session.insert() + await database.execute(query, values) + + # fetch_all() + query = session.select() + results = await database.fetch_all(query=query) + + assert len(results) == 1 + assert results[0]["data"] == ["lemon", "raspberry", "lime", "pumice"] + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_custom_field(database_url):