Skip to content

Commit

Permalink
Merge pull request #151 from TheoPascoli/Feature-add-Postgresql-support
Browse files Browse the repository at this point in the history
Feature add postgresql support
  • Loading branch information
alicecaron authored Sep 17, 2024
2 parents 05b3b0f + bd8a2ef commit 24f703e
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 46 deletions.
3 changes: 3 additions & 0 deletions openstef_dbc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def __call__(cls, *args, **kwargs):
)
return cls._instances[cls]

def clear(cls):
cls._instances = {}

@classmethod
def get_instance(cls, instance_cls):
return cls._instances[instance_cls]
112 changes: 84 additions & 28 deletions openstef_dbc/data_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from openstef_dbc import Singleton
from openstef_dbc.ktp_api import KtpApi
from openstef_dbc.log import logging
from enum import Enum

# Define abstract interface

Expand All @@ -21,7 +22,7 @@ class _DataInterface(metaclass=Singleton):
def __init__(self, config):
"""Generic data interface.
All connections and queries to the InfluxDB database, MySQL databases and
All connections and queries to the InfluxDB database, SQL databases and
influx API are governed by this class.
Args:
Expand All @@ -35,17 +36,44 @@ def __init__(self, config):
influxdb_host (str): InfluxDB host.
influxdb_port (int): InfluxDB port.
influx_organization (str): InfluxDB organization.
mysql_username (str): MySQL username.
mysql_password (str): MySQL password.
mysql_host (str): MySQL host.
mysql_port (int): MYSQL port.
mysql_database_name (str): MySQL database name.
sql_db_username (str): SQL database username.
sql_db_password (str): SQL database password.
sql_db_host (str): SQL database host.
sql_db_port (int): SQL database port.
sql_db_database_name (str): SQL database name.
proxies Union[dict[str, str], None]: Proxies.
sql_db_type (str, optional): SQL Database type engine to use('mysql' or 'postgresql'), if not defined mysql is used by default.
"""

self.logger = logging.get_logger(self.__class__.__name__)
self.influx_organization = config.influx_organization

# Get db type from config, set 'mysql' if the variable does not exist
self.sql_db_type = getattr(config, "sql_db_type", "MYSQL")

if self.sql_db_type not in SupportedSqlTypes.__members__.keys():
raise ValueError(
f"Unsupported database sql type '{self.sql_db_type}'. Please use one of the following {SupportedSqlTypes.__members__.keys()}."
)

# Set SQL engine according to given sql_db_type
if self.sql_db_type == SupportedSqlTypes.POSTGRESQL.name:
self.sql_engine = self._create_postgresql_engine(
username=config.sql_db_username,
password=config.sql_db_password,
host=config.sql_db_host,
port=config.sql_db_port,
db=config.sql_db_database_name,
)
else:
self.sql_engine = self._create_mysql_engine(
username=config.sql_db_username,
password=config.sql_db_password,
host=config.sql_db_host,
port=config.sql_db_port,
db=config.sql_db_database_name,
)

self.ktp_api = KtpApi(
username=config.api_username,
password=config.api_password,
Expand All @@ -65,14 +93,6 @@ def __init__(self, config):
self.influx_query_api = self.influx_client.query_api()
self.influx_write_api = self.influx_client.write_api(write_options=SYNCHRONOUS)

self.mysql_engine = self._create_mysql_engine(
username=config.mysql_username,
password=config.mysql_password,
host=config.mysql_host,
port=config.mysql_port,
db=config.mysql_database_name,
)

# Set geopy proxies
# https://geopy.readthedocs.io/en/stable/#geopy.geocoders.options
# https://docs.python.org/3/library/urllib.request.html#urllib.request.ProxyHandler
Expand All @@ -96,6 +116,9 @@ def get_instance():
"Please call _DataInterface(config) first."
) from exc

def get_sql_db_type(self):
return self.sql_db_type

def _create_influx_client(
self, token: str, host: str, port: int, organization: str
) -> None:
Expand Down Expand Up @@ -130,6 +153,23 @@ def _create_mysql_engine(
self.logger.error("Could not connect to MySQL database", exc_info=exc)
raise

def _create_postgresql_engine(
self, username: str, password: str, host: str, port: int, db: str
):
"""Create PostgreSQL engine.
Differs from sql_connection in the sense that this write_engine
*can* write pandas dataframe directly.
"""
connector = "postgresql+psycopg2"
database_url = f"{connector}://{username}:{password}@{host}:{port}/{db}"
try:
return sqlalchemy.create_engine(database_url)
except Exception as exc:
self.logger.error("Could not connect to PostgreSQL database", exc_info=exc)
raise

def exec_influx_query(self, query: str, bind_params: dict = {}) -> dict:
"""Execute an InfluxDB query.
Expand Down Expand Up @@ -223,33 +263,35 @@ def check_influx_available(self):

def exec_sql_query(self, query: str, params: dict = None):
try:
with self.mysql_engine.connect() as connection:
with self.sql_engine.connect() as connection:
if params is None:
params = {}
cursor = connection.execute(query, **params)
if cursor.cursor is not None:
return pd.DataFrame(cursor.fetchall())
except sqlalchemy.exc.OperationalError as e:
self.logger.error("Lost connection to MySQL database", exc_info=e)
self.logger.error(
"Lost connection to {} database".format(self.sql_db_type), exc_info=e
)
raise
except sqlalchemy.exc.ProgrammingError as e:
self.logger.error(
"Error occured during executing query", query=query, exc_info=e
)
raise
except sqlalchemy.exc.DatabaseError as e:
self.logger.error("Can't connect to MySQL database", exc_info=e)
self.logger.error(
"Can't connect to {} database".format(self.sql_db_type), exc_info=e
)
raise

def exec_sql_write(self, statement: str, params: dict = None) -> None:
try:
with self.mysql_engine.connect() as connection:
with self.sql_engine.connect() as connection:
response = connection.execute(statement, params=params)

self.logger.info(
"Added {} new systems to the systems table in the MySQL database".format(
response.rowcount
)
f"Added {response.rowcount} new systems to the systems table in the {self.sql_db_type} database"
)
except Exception as e:
self.logger.error(
Expand All @@ -260,13 +302,27 @@ def exec_sql_write(self, statement: str, params: dict = None) -> None:
def exec_sql_dataframe_write(
self, dataframe: pd.DataFrame, table: str, **kwargs
) -> None:
dataframe.to_sql(table, self.mysql_engine, **kwargs)
dataframe.to_sql(table, self.sql_engine, **kwargs)

def check_mysql_available(self):
"""Check if a basic mysql query gives a valid response"""
query = "SHOW DATABASES"
response = self.exec_sql_query(query)
def check_sql_available(self):
"""Check if a basic SQL query gives a valid response."""
query = "SELECT 1"

available = len(list(response["Database"])) > 0
try:
response = self.exec_sql_query(query)
available = response is not None and len(response) > 0

return available
if available:
return True
else:
print("The SQL query was executed, but no data was returned.")
return False

except Exception as e:
print(f"Error while checking {self.sql_db_type} availability: {e}")
return False


class SupportedSqlTypes(Enum):
MYSQL = "mysql"
POSTGRESQL = "postgresql"
13 changes: 7 additions & 6 deletions openstef_dbc/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,14 @@ def __init__(self, config):
influxdb_token (str): Token to authenticate to InfluxDB.
influxdb_host (str): InfluxDB host.
influxdb_port (int): InfluxDB port.
mysql_username (str): MySQL username.
mysql_password (str): MySQL password.
mysql_host (str): MySQL host.
mysql_port (int): MYSQL port.
mysql_database_name (str): MySQL database name.
influx_organization (str): InfluxDB organization.
sql_db_username (str): SQL database username.
sql_db_password (str): SQL database password.
sql_db_host (str): SQL database host.
sql_db_port (int): SQL database port.
sql_db_database_name (str): SQL database name.
proxies Union[dict[str, str], None]: Proxies.
sql_db_type (str, optional): SQL Database type ('mysql' or 'postgresql').
"""

self._datainterface = _DataInterface(config)
Expand Down
16 changes: 15 additions & 1 deletion openstef_dbc/services/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,22 @@ def write_realised_pvdata(self, df: pd.DataFrame, region: str) -> None:
# Get rid of last comma
values = values[0:-1]

db_type = _DataInterface.get_instance().get_sql_db_type()

# Compose query for writing new systems
query = "INSERT IGNORE INTO `systems` (sid, region) VALUES " + values
if db_type == "mysql":
# Compose query for writing new systems in MySQL
query = "INSERT IGNORE INTO `systems` (sid, region) VALUES " + values
elif db_type == "postgresql":
# Compose query for writing new systems in PostgreSQL
query = (
"INSERT INTO systems (sid, region) VALUES "
+ values
+ " ON CONFLICT (sid) DO NOTHING"
)
else:
self.logger.error("Unsupported database type: {}".format(db_type))
return

# Execute query
_DataInterface.get_instance().exec_sql_write(query)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ openstef~=3.4.4
pydantic-settings>=2.1.0,<3.0.0
influxdb-client~=1.36.1
mysql-connector-python~=8.3.0
psycopg2-binary~=2.9.6
PyMySQL~=1.0.2
PyYAML~=6.0
requests~=2.28.1
Expand Down
10 changes: 5 additions & 5 deletions tests/integration/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ class Settings(BaseSettings):
influxdb_token: str = "tokenonlyfortesting"
influxdb_host: str = "http://localhost"
influxdb_port: str = "8086"
mysql_username: str = "test"
mysql_password: str = "test"
mysql_host: str = "localhost"
mysql_port: int = 1234
mysql_database_name: str = "test"
sql_db_username: str = "test"
sql_db_password: str = "test"
sql_db_host: str = "localhost"
sql_db_port: int = 1234
sql_db_database_name: str = "test"
proxies: Union[dict[str, str], None] = None
29 changes: 24 additions & 5 deletions tests/unit/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,28 @@ class Settings(BaseSettings):
influxdb_token: str = "token"
influxdb_host: str = "host"
influxdb_port: str = "123"
mysql_username: str = "test"
mysql_password: str = "test"
mysql_host: str = "host"
mysql_port: int = 123
mysql_database_name: str = "database_name"
sql_db_username: str = "test"
sql_db_password: str = "test"
sql_db_host: str = "host"
sql_db_port: int = 123
sql_db_database_name: str = "database_name"
proxies: Union[dict[str, str], None] = None
sql_db_type: str = "MYSQL"


class SettingsWithoutOptional(BaseSettings):
api_username: str = "test"
api_password: str = "demo"
api_admin_username: str = "test"
api_admin_password: str = "demo"
api_url: str = "localhost"
influx_organization: str = "myorg"
influxdb_token: str = "token"
influxdb_host: str = "host"
influxdb_port: str = "123"
sql_db_username: str = "test"
sql_db_password: str = "test"
sql_db_host: str = "host"
sql_db_port: int = 123
sql_db_database_name: str = "database_name"
proxies: Union[dict[str, str], None] = None
27 changes: 26 additions & 1 deletion tests/unit/test_data_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
# SPDX-License-Identifier: MPL-2.0

import unittest
from copy import deepcopy
from unittest.mock import MagicMock, patch

import pandas as pd
from openstef_dbc.data_interface import _DataInterface
from tests.unit.settings import Settings
from tests.unit.settings import Settings, SettingsWithoutOptional


@patch("openstef_dbc.data_interface.KtpApi", MagicMock())
Expand Down Expand Up @@ -46,6 +47,30 @@ def test_get_instance(self):
# should be the same instance
self.assertIs(data_interface_1, data_interface_2)

def test_get_sql_db_type_for_mysql(self):
_DataInterface.clear()
config = Settings()
config.sql_db_type = "MYSQL"
self.assertEqual("MYSQL", _DataInterface(config).get_sql_db_type())

def test_get_sql_db_type_for_postgresql(self):
_DataInterface.clear()
config = Settings()
config.sql_db_type = "POSTGRESQL"
self.assertEqual("POSTGRESQL", _DataInterface(config).get_sql_db_type())

def test_get_sql_db_type_when_not_defined_in_settings(self):
_DataInterface.clear()
config = SettingsWithoutOptional()
self.assertEqual("MYSQL", _DataInterface(config).get_sql_db_type())

def test_get_sql_db_type_for_not_implemented_type(self):
_DataInterface.clear()
config = Settings()
config.sql_db_type = "oracle"
with self.assertRaises(ValueError):
_DataInterface(config)

@patch("openstef_dbc.Singleton.get_instance", side_effect=KeyError)
def test_get_instance_error(self, get_instance_mock):
with self.assertRaises(RuntimeError):
Expand Down

0 comments on commit 24f703e

Please sign in to comment.