From af2e4889c5b2e8fd925507dab5473c52fd1e576e Mon Sep 17 00:00:00 2001 From: "Edgar R. M" Date: Mon, 7 Aug 2023 14:59:47 -0600 Subject: [PATCH] fix(targets): Correctly serialize `decimal.Decimal` in JSON fields of SQL targets (#1898) --- pyproject.toml | 6 ++--- singer_sdk/connectors/sql.py | 42 +++++++++++++++++++++++++++++++- tests/core/test_connector_sql.py | 25 +++++++++++++++++++ 3 files changed, 69 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bceced1d8..a795820d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,8 +2,8 @@ name = "singer-sdk" version = "0.30.0" description = "A framework for building Singer taps" -authors = ["Meltano Team and Contributors"] -maintainers = ["Meltano Team and Contributors"] +authors = ["Meltano Team and Contributors "] +maintainers = ["Meltano Team and Contributors "] readme = "README.md" homepage = "https://sdk.meltano.com/en/latest/" repository = "https://github.com/meltano/sdk" @@ -144,7 +144,7 @@ name = "cz_version_bump" version = "0.30.0" tag_format = "v$major.$minor.$patch$prerelease" version_files = [ - "docs/conf.py", + "docs/conf.py:^release =", "pyproject.toml:^version =", "cookiecutter/tap-template/{{cookiecutter.tap_id}}/pyproject.toml:singer-sdk", "cookiecutter/target-template/{{cookiecutter.target_id}}/pyproject.toml:singer-sdk", diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index e9a65cf80..e05e359da 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -2,6 +2,8 @@ from __future__ import annotations +import decimal +import json import logging import typing as t import warnings @@ -9,6 +11,7 @@ from datetime import datetime from functools import lru_cache +import simplejson import sqlalchemy from sqlalchemy.engine import Engine @@ -316,7 +319,12 @@ def create_engine(self) -> Engine: Returns: A new SQLAlchemy Engine. """ - return sqlalchemy.create_engine(self.sqlalchemy_url, echo=False) + return sqlalchemy.create_engine( + self.sqlalchemy_url, + echo=False, + json_serializer=self.serialize_json, + json_deserializer=self.deserialize_json, + ) def quote(self, name: str) -> str: """Quote a name if it needs quoting, using '.' as a name-part delimiter. @@ -1124,3 +1132,35 @@ def _adapt_column_type( ) with self._connect() as conn: conn.execute(alter_column_ddl) + + def serialize_json(self, obj: object) -> str: + """Serialize an object to a JSON string. + + Target connectors may override this method to provide custom serialization logic + for JSON types. + + Args: + obj: The object to serialize. + + Returns: + The JSON string. + + .. versionadded:: 0.31.0 + """ + return simplejson.dumps(obj, use_decimal=True) + + def deserialize_json(self, json_str: str) -> object: + """Deserialize a JSON string to an object. + + Tap connectors may override this method to provide custom deserialization + logic for JSON types. + + Args: + json_str: The JSON string to deserialize. + + Returns: + The deserialized object. + + .. versionadded:: 0.31.0 + """ + return json.loads(json_str, parse_float=decimal.Decimal) diff --git a/tests/core/test_connector_sql.py b/tests/core/test_connector_sql.py index 1c04dbcdd..58ba59ec7 100644 --- a/tests/core/test_connector_sql.py +++ b/tests/core/test_connector_sql.py @@ -1,5 +1,6 @@ from __future__ import annotations +from decimal import Decimal from unittest import mock import pytest @@ -258,3 +259,27 @@ def test_merge_generic_sql_types( ): merged_type = connector.merge_sql_types(types) assert isinstance(merged_type, expected_type) + + def test_engine_json_serialization(self, connector: SQLConnector): + engine = connector._engine + meta = sqlalchemy.MetaData() + table = sqlalchemy.Table( + "test_table", + meta, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("attrs", sqlalchemy.JSON), + ) + meta.create_all(engine) + with engine.connect() as conn: + conn.execute( + table.insert(), + [ + {"attrs": {"x": Decimal("1.0")}}, + {"attrs": {"x": Decimal("2.0"), "y": [1, 2, 3]}}, + ], + ) + result = conn.execute(table.select()) + assert result.fetchall() == [ + (1, {"x": Decimal("1.0")}), + (2, {"x": Decimal("2.0"), "y": [1, 2, 3]}), + ]