From 6b0387d80835e49f836c2eb0900eedcc3c729eeb Mon Sep 17 00:00:00 2001 From: Peter DeVita Date: Wed, 21 Feb 2024 15:59:14 -0500 Subject: [PATCH] Fix JSON and enum type columns --- databases/backends/common/records.py | 9 ++-- tests/test_databases.py | 75 +++++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 5 deletions(-) diff --git a/databases/backends/common/records.py b/databases/backends/common/records.py index 1d8a2fd4..f3f81cc2 100644 --- a/databases/backends/common/records.py +++ b/databases/backends/common/records.py @@ -1,4 +1,4 @@ -import json +import enum import typing from datetime import date, datetime @@ -6,6 +6,7 @@ from sqlalchemy.engine.row import Row as SQLRow from sqlalchemy.sql.compiler import _CompileLabel from sqlalchemy.sql.schema import Column +from sqlalchemy.sql.sqltypes import JSON from sqlalchemy.types import TypeEngine from databases.interfaces import Record as RecordInterface @@ -63,10 +64,10 @@ def __getitem__(self, key: typing.Any) -> typing.Any: processor = datatype._cached_result_processor(self._dialect, None) if self._dialect.name not in DIALECT_EXCLUDE: - if isinstance(raw, dict): - raw = json.dumps(raw) + if isinstance(datatype, JSON): + return raw - if processor is not None and (not isinstance(raw, (datetime, date))): + if processor is not None and not isinstance(raw, (datetime, date, enum.Enum)): return processor(raw) return raw diff --git a/tests/test_databases.py b/tests/test_databases.py index cd907fd1..6bfc742a 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -1,6 +1,7 @@ import asyncio import datetime import decimal +import enum import functools import gc import itertools @@ -55,6 +56,30 @@ def process_result_value(self, value, dialect): sqlalchemy.Column("published", sqlalchemy.DateTime), ) + +class TshirtSize(enum.Enum): + SMALL = "SMALL" + MEDIUM = "MEDIUM" + LARGE = "LARGE" + XL = "XL" + + +class TshirtColor(enum.Enum): + BLUE = 0 + GREEN = 1 + YELLOW = 2 + RED = 3 + + +# Used to test Enum +tshirt_size = sqlalchemy.Table( + "tshirt_size", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("size", sqlalchemy.Enum(TshirtSize)), + sqlalchemy.Column("color", sqlalchemy.Enum(TshirtColor)) +) + # Used to test JSON session = sqlalchemy.Table( "session", @@ -957,7 +982,32 @@ async def test_decimal_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_json_field(database_url): +async def test_enum_field(database_url): + """ + Test JSON columns, to ensure correct cross-database support. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + # execute() + size = TshirtSize.SMALL + color = TshirtColor.GREEN + values = {"size": size, "color": color} + query = tshirt_size.insert() + await database.execute(query, values) + + # fetch_all() + query = tshirt_size.select() + results = await database.fetch_all(query=query) + + assert len(results) == 1 + assert results[0]["size"] == size + assert results[0]["color"] == color + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_json_dict_field(database_url): """ Test JSON columns, to ensure correct cross-database support. """ @@ -978,6 +1028,29 @@ async def test_json_field(database_url): assert results[0]["data"] == {"text": "hello", "boolean": True, "int": 1} +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_json_list_field(database_url): + """ + Test JSON columns, to ensure correct cross-database support. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + # execute() + data = ['lemon', 'raspberry', 'lime', 'pumice'] + values = {"data": data} + query = session.insert() + await database.execute(query, values) + + # fetch_all() + query = session.select() + results = await database.fetch_all(query=query) + + assert len(results) == 1 + assert results[0]["data"] == ['lemon', 'raspberry', 'lime', 'pumice'] + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_custom_field(database_url):