Move make_conn/get_pool into storage layer
erikjohnston committed Dec 16, 2019
1 parent 7872acd commit 439e043
Showing 11 changed files with 126 additions and 109 deletions.
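
At a glance: the connection helpers that used to live on DatabaseConnectionConfig (make_conn and get_pool) become module-level make_conn and make_pool functions in synapse.storage.database, and Database now receives the engine explicitly instead of reading it off the config object. A rough sketch of the resulting call pattern, pieced together from the diffs below (the open_database helper and its arguments are illustrative glue, not code from this commit):

    from synapse.storage.database import Database, make_conn, make_pool
    from synapse.storage.engines import create_engine


    def open_database(hs, reactor, database_config):
        # Previously: database_config.make_conn() / database_config.get_pool(reactor).
        engine = create_engine(database_config.config)

        db_conn = make_conn(database_config, engine)        # plain DB-API connection
        pool = make_pool(reactor, database_config, engine)  # adbapi ConnectionPool

        # Database now takes the engine as an explicit argument.
        database = Database(hs, database_config, engine)
        return database, db_conn, pool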
scripts-dev/update_database (10 changes: 2 additions & 8 deletions)
@@ -26,7 +26,6 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.server import HomeServer
 from synapse.storage import DataStore
-from synapse.storage.prepare_database import prepare_database
 
 logger = logging.getLogger("update_database")
 
@@ -77,13 +76,8 @@ if __name__ == "__main__":
     # Instantiate and initialise the homeserver object.
     hs = MockHomeserver(config)
 
-    # Update the database to the latest schema.
-    database = hs.config.get_single_database()
-    db_conn = database.make_conn()
-    prepare_database(db_conn, database.engine, config=config)
-    db_conn.commit()
-
-    # setup instantiates the store within the homeserver object.
+    # Setup instantiates the store within the homeserver object and updates the
+    # DB.
     hs.setup()
     store = hs.get_datastore()
 
scripts/synapse_port_db (15 changes: 8 additions & 7 deletions)
@@ -56,7 +56,7 @@ from synapse.storage.data_stores.main.stats import StatsStore
 from synapse.storage.data_stores.main.user_directory import (
     UserDirectoryBackgroundUpdateStore,
 )
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_conn
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 from synapse.util import Clock
@@ -440,9 +440,9 @@ class Porter(object):
         else:
             return
 
-    def setup_db(self, db_config: DatabaseConnectionConfig):
-        db_conn = db_config.make_conn()
-        prepare_database(db_conn, db_config.engine, config=None)
+    def setup_db(self, db_config: DatabaseConnectionConfig, engine):
+        db_conn = make_conn(db_config, engine)
+        prepare_database(db_conn, engine, config=None)
 
         db_conn.commit()
 
@@ -460,15 +460,16 @@ class Porter(object):
         """
        self.progress.set_state("Preparing %s" % db_config.config["name"])
 
-        conn = self.setup_db(db_config)
+        engine = create_engine(db_config.config)
+        conn = self.setup_db(db_config, engine)
 
         hs = MockHomeserver(self.hs_config)
 
-        store = Store(Database(hs, db_config), conn, hs)
+        store = Store(Database(hs, db_config, engine), conn, hs)
 
         yield store.db.runInteraction(
             "%s_engine.check_database" % db_config.config["name"],
-            db_config.engine.check_database,
+            engine.check_database,
         )
 
         return store
synapse/config/database.py (34 changes: 0 additions & 34 deletions)
@@ -19,10 +19,7 @@
 
 import yaml
 
-from twisted.enterprise import adbapi
-
 from synapse.config._base import Config, ConfigError
-from synapse.storage.engines import create_engine
 
 logger = logging.getLogger(__name__)
 
@@ -53,37 +50,6 @@ def __init__(self, name: str, db_config: dict, data_stores: List[str]):
         self.config = db_config
         self.data_stores = data_stores
 
-        self.engine = create_engine(db_config)
-        self.config["args"]["cp_openfun"] = self.engine.on_new_connection
-
-        self._pool = None
-
-    def get_pool(self, reactor) -> adbapi.ConnectionPool:
-        """Get the connection pool for the database.
-        """
-
-        if self._pool is None:
-            self._pool = adbapi.ConnectionPool(
-                self.config["name"], cp_reactor=reactor, **self.config.get("args", {})
-            )
-
-        return self._pool
-
-    def make_conn(self):
-        """Make a new connection to the database and return it.
-
-        Returns:
-            Connection
-        """
-
-        db_params = {
-            k: v
-            for k, v in self.config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.engine.module.connect(**db_params)
-        return db_conn
-
 
 class DatabaseConfig(Config):
     section = "database"
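
With the helpers gone, DatabaseConnectionConfig is reduced to a plain holder for per-database settings. Roughly, based on the lines that survive this hunk (the docstring is paraphrased and the name attribute is inferred from how the object is used elsewhere in this commit; validation logic in the real class is omitted):

    from typing import List


    class DatabaseConnectionConfig:
        """Sketch of the slimmed-down config object: settings only, no connections."""

        def __init__(self, name: str, db_config: dict, data_stores: List[str]):
            self.name = name                # label used in log messages
            self.config = db_config         # raw connection config ("name", "args", ...)
            self.data_stores = data_stores  # data stores hosted on this database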
synapse/server.py (3 changes: 3 additions & 0 deletions)
@@ -273,6 +273,9 @@ def get_clock(self):
     def get_datastore(self):
         return self.datastores.main
 
+    def get_datastores(self):
+        return self.datastores
+
     def get_config(self):
         return self.config
 
synapse/storage/data_stores/__init__.py (20 changes: 12 additions & 8 deletions)
@@ -15,7 +15,8 @@
 
 import logging
 
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_conn
+from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 
 logger = logging.getLogger(__name__)
@@ -34,25 +35,28 @@ def __init__(self, main_store_class, hs):
         # Note we pass in the main store class here as workers use a different main
         # store.
 
+        self.databases = []
+
         for database_config in hs.config.database.databases:
             db_name = database_config.name
-            with database_config.make_conn() as db_conn:
+            engine = create_engine(database_config.config)
+
+            with make_conn(database_config, engine) as db_conn:
                 logger.info("Preparing database %r...", db_name)
 
-                database_config.engine.check_database(db_conn.cursor())
+                engine.check_database(db_conn.cursor())
                 prepare_database(
-                    db_conn,
-                    database_config.engine,
-                    hs.config,
-                    data_stores=database_config.data_stores,
+                    db_conn, engine, hs.config, data_stores=database_config.data_stores,
                 )
 
-                database = Database(hs, database_config)
+                database = Database(hs, database_config, engine)
 
                 if "main" in database_config.data_stores:
                     logger.info("Starting 'main' data store")
                     self.main = main_store_class(database, db_conn, hs)
 
                 db_conn.commit()
 
+                self.databases.append(database)
+
                 logger.info("Database %r prepared", db_name)
synapse/storage/database.py (39 changes: 36 additions & 3 deletions)
@@ -24,9 +24,11 @@
 
 from prometheus_client import Histogram
 
+from twisted.enterprise import adbapi
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
+from synapse.config.database import DatabaseConnectionConfig
 from synapse.logging.context import LoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
@@ -74,6 +76,37 @@
 }
 
 
+def make_pool(
+    reactor, db_config: DatabaseConnectionConfig, engine
+) -> adbapi.ConnectionPool:
+    """Get the connection pool for the database.
+    """
+
+    return adbapi.ConnectionPool(
+        db_config.config["name"],
+        cp_reactor=reactor,
+        cp_openfun=engine.on_new_connection,
+        **db_config.config.get("args", {})
+    )
+
+
+def make_conn(db_config: DatabaseConnectionConfig, engine):
+    """Make a new connection to the database and return it.
+
+    Returns:
+        Connection
+    """
+
+    db_params = {
+        k: v
+        for k, v in db_config.config.get("args", {}).items()
+        if not k.startswith("cp_")
+    }
+    db_conn = engine.module.connect(**db_params)
+    engine.on_new_connection(db_conn)
+    return db_conn
+
+
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
@@ -218,11 +251,11 @@ class Database(object):
 
     _TXN_ID = 0
 
-    def __init__(self, hs, database_config):
+    def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
         self.hs = hs
         self._clock = hs.get_clock()
         self._database_config = database_config
-        self._db_pool = database_config.get_pool(hs.get_reactor())
+        self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
 
         self.updates = BackgroundUpdater(hs, self)
 
@@ -235,7 +268,7 @@ def __init__(self, database_config):
 
         # to watch it
         self._txn_perf_counters = PerformanceCounters()
 
-        self.engine = database_config.engine
+        self.engine = engine
 
         # A set of tables that are not safe to use native upserts in.
         self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
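
One detail worth noting in the new make_conn: the cp_* keys under args are adbapi.ConnectionPool options, so they are stripped before the raw DB-API connect() call. A small standalone illustration of that filtering (the args values here are invented):

    # Invented example args: cp_min/cp_max belong to the connection pool and must
    # not reach the DB-API driver; everything else is passed through to connect().
    args = {"user": "synapse", "host": "localhost", "cp_min": 5, "cp_max": 10}

    db_params = {k: v for k, v in args.items() if not k.startswith("cp_")}

    assert db_params == {"user": "synapse", "host": "localhost"}

make_pool, by contrast, forwards the full args dict to adbapi.ConnectionPool (which understands the cp_* options) and adds cp_openfun=engine.on_new_connection, so each pooled connection is initialised the same way make_conn initialises its one-off connection.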
tests/replication/slave/storage/_base.py (5 changes: 3 additions & 2 deletions)
@@ -20,7 +20,7 @@
     ReplicationClientHandler,
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.storage.database import Database
+from synapse.storage.database import make_conn
 
 from tests import unittest
 from tests.server import FakeTransport
@@ -44,8 +44,9 @@ def prepare(self, reactor, clock, hs):
         db_config = hs.config.database.get_single_database()
         self.master_store = self.hs.get_datastore()
         self.storage = hs.get_storage()
+        database = hs.get_datastores().databases[0]
         self.slaved_store = self.STORE_TYPE(
-            Database(hs, db_config), db_config.get_pool(reactor).connect(), self.hs
+            database, make_conn(db_config, database.engine), self.hs
         )
         self.event_id = 0
 
tests/server.py (44 changes: 23 additions & 21 deletions)
@@ -308,33 +308,35 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
 
     # Make the thread pool synchronous.
     clock = server.get_clock()
-    pool = database.get_pool(clock._reactor)
-
-    def runWithConnection(func, *args, **kwargs):
-        return threads.deferToThreadPool(
-            pool._reactor,
-            pool.threadpool,
-            pool._runWithConnection,
-            func,
-            *args,
-            **kwargs
-        )
-
-    def runInteraction(interaction, *args, **kwargs):
-        return threads.deferToThreadPool(
-            pool._reactor,
-            pool.threadpool,
-            pool._runInteraction,
-            interaction,
-            *args,
-            **kwargs
-        )
+
+    for database in server.get_datastores().databases:
+        pool = database._db_pool
+
+        def runWithConnection(func, *args, **kwargs):
+            return threads.deferToThreadPool(
+                pool._reactor,
+                pool.threadpool,
+                pool._runWithConnection,
+                func,
+                *args,
+                **kwargs
+            )
+
+        def runInteraction(interaction, *args, **kwargs):
+            return threads.deferToThreadPool(
+                pool._reactor,
+                pool.threadpool,
+                pool._runInteraction,
+                interaction,
+                *args,
+                **kwargs
+            )
 
     if pool:
         pool.runWithConnection = runWithConnection
         pool.runInteraction = runInteraction
         pool.threadpool = ThreadPool(clock._reactor)
         pool.running = True
 
     return server
 
(The diffs for the remaining changed files are not shown here.)
