From f56af61389d2cc8b516d388a8fbe647fd697019c Mon Sep 17 00:00:00 2001 From: eminano Date: Thu, 6 Jun 2024 14:17:29 +0200 Subject: [PATCH] Use a temporary connection for replication start --- .../postgres/pg_replication_handler.go | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/internal/replication/postgres/pg_replication_handler.go b/internal/replication/postgres/pg_replication_handler.go index 9eb06ac..8839bba 100644 --- a/internal/replication/postgres/pg_replication_handler.go +++ b/internal/replication/postgres/pg_replication_handler.go @@ -17,12 +17,10 @@ import ( type Handler struct { logger loglib.Logger - // Create two connections. One for querying, one for handling replication - // events. - pgConn *pgx.Conn - pgReplicationConn *pgconn.PgConn + pgReplicationConn *pgconn.PgConn pgReplicationSlotName string + pgConnBuilder func() (*pgx.Conn, error) lsnParser replication.LSNParser } @@ -46,12 +44,12 @@ func NewHandler(ctx context.Context, cfg Config, opts ...Option) (*Handler, erro if err != nil { return nil, fmt.Errorf("failed parsing postgres connection string: %w", err) } - pgConn, err := pgx.ConnectConfig(ctx, pgCfg) - if err != nil { - return nil, fmt.Errorf("create postgres client: %w", err) + + connBuilder := func() (*pgx.Conn, error) { + return pgx.ConnectConfig(ctx, pgCfg) } - // open a second Postgres connection, this one dedicated for replication + // open a Postgres connection dedicated for replication copyConfig := pgCfg.Copy() copyConfig.RuntimeParams["replication"] = "database" @@ -62,8 +60,8 @@ func NewHandler(ctx context.Context, cfg Config, opts ...Option) (*Handler, erro h := &Handler{ logger: loglib.NewNoopLogger(), - pgConn: pgConn, pgReplicationConn: pgReplicationConn, + pgConnBuilder: connBuilder, lsnParser: &LSNParser{}, } @@ -98,7 +96,13 @@ func (h *Handler) StartReplication(ctx context.Context) error { logLSNPosition: sysID.XLogPos, }) - startPos, err := h.getLastSyncedLSN(ctx) + conn, err := h.pgConnBuilder() + if err != nil { + return fmt.Errorf("creating pg connection: %w", err) + } + defer conn.Close(ctx) + + startPos, err := h.getLastSyncedLSN(ctx, conn) if err != nil { return fmt.Errorf("read last position: %w", err) } @@ -111,7 +115,7 @@ func (h *Handler) StartReplication(ctx context.Context) error { // todo(deverts): If we don't have a position. Read from as early as possible. // this _could_ be too old. In the future, it would be good to calculate if we're // too far behind, so we can fix it. - startPos, err = h.getRestartLSN(ctx, h.pgReplicationSlotName) + startPos, err = h.getRestartLSN(ctx, conn, h.pgReplicationSlotName) if err != nil { return fmt.Errorf("get restart LSN: %w", err) } @@ -200,19 +204,15 @@ func (h *Handler) GetLSNParser() replication.LSNParser { // Close closes the database connections. func (h *Handler) Close() error { - err := h.pgReplicationConn.Close(context.Background()) - if err != nil { - return err - } - return h.pgConn.Close(context.Background()) + return h.pgReplicationConn.Close(context.Background()) } // getRestartLSN returns the absolute earliest possible LSN we can support. If // the consumer's LSN is earlier than this, we cannot (easily) catch the // consumer back up. -func (h *Handler) getRestartLSN(ctx context.Context, slotName string) (replication.LSN, error) { +func (h *Handler) getRestartLSN(ctx context.Context, conn *pgx.Conn, slotName string) (replication.LSN, error) { var restartLSN string - err := h.pgConn.QueryRow( + err := conn.QueryRow( ctx, `select restart_lsn from pg_replication_slots where slot_name=$1`, slotName, @@ -226,9 +226,9 @@ func (h *Handler) getRestartLSN(ctx context.Context, slotName string) (replicati // getLastSyncedLSN gets the `confirmed_flush_lsn` from PG. This is the last LSN // that the consumer confirmed it had completed. -func (h *Handler) getLastSyncedLSN(ctx context.Context) (replication.LSN, error) { +func (h *Handler) getLastSyncedLSN(ctx context.Context, conn *pgx.Conn) (replication.LSN, error) { var confirmedFlushLSN string - err := h.pgConn.QueryRow(ctx, `select confirmed_flush_lsn from pg_replication_slots where slot_name=$1`, h.pgReplicationSlotName).Scan(&confirmedFlushLSN) + err := conn.QueryRow(ctx, `select confirmed_flush_lsn from pg_replication_slots where slot_name=$1`, h.pgReplicationSlotName).Scan(&confirmedFlushLSN) if err != nil { return 0, err }