Skip to content

Commit

Permalink
Merge pull request #36 from xataio/replication-handler-use-temporary-…
Browse files Browse the repository at this point in the history
…connection

Use a temporary connection for replication start
  • Loading branch information
eminano authored Jun 6, 2024
2 parents 888e76b + f56af61 commit 3180976
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions internal/replication/postgres/pg_replication_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ import (

type Handler struct {
logger loglib.Logger
// Create two connections. One for querying, one for handling replication
// events.
pgConn *pgx.Conn
pgReplicationConn *pgconn.PgConn

pgReplicationConn *pgconn.PgConn
pgReplicationSlotName string
pgConnBuilder func() (*pgx.Conn, error)

lsnParser replication.LSNParser
}
Expand All @@ -46,12 +44,12 @@ func NewHandler(ctx context.Context, cfg Config, opts ...Option) (*Handler, erro
if err != nil {
return nil, fmt.Errorf("failed parsing postgres connection string: %w", err)
}
pgConn, err := pgx.ConnectConfig(ctx, pgCfg)
if err != nil {
return nil, fmt.Errorf("create postgres client: %w", err)

connBuilder := func() (*pgx.Conn, error) {
return pgx.ConnectConfig(ctx, pgCfg)
}

// open a second Postgres connection, this one dedicated for replication
// open a Postgres connection dedicated for replication
copyConfig := pgCfg.Copy()
copyConfig.RuntimeParams["replication"] = "database"

Expand All @@ -62,8 +60,8 @@ func NewHandler(ctx context.Context, cfg Config, opts ...Option) (*Handler, erro

h := &Handler{
logger: loglib.NewNoopLogger(),
pgConn: pgConn,
pgReplicationConn: pgReplicationConn,
pgConnBuilder: connBuilder,
lsnParser: &LSNParser{},
}

Expand Down Expand Up @@ -98,7 +96,13 @@ func (h *Handler) StartReplication(ctx context.Context) error {
logLSNPosition: sysID.XLogPos,
})

startPos, err := h.getLastSyncedLSN(ctx)
conn, err := h.pgConnBuilder()
if err != nil {
return fmt.Errorf("creating pg connection: %w", err)
}
defer conn.Close(ctx)

startPos, err := h.getLastSyncedLSN(ctx, conn)
if err != nil {
return fmt.Errorf("read last position: %w", err)
}
Expand All @@ -111,7 +115,7 @@ func (h *Handler) StartReplication(ctx context.Context) error {
// todo(deverts): If we don't have a position. Read from as early as possible.
// this _could_ be too old. In the future, it would be good to calculate if we're
// too far behind, so we can fix it.
startPos, err = h.getRestartLSN(ctx, h.pgReplicationSlotName)
startPos, err = h.getRestartLSN(ctx, conn, h.pgReplicationSlotName)
if err != nil {
return fmt.Errorf("get restart LSN: %w", err)
}
Expand Down Expand Up @@ -200,19 +204,15 @@ func (h *Handler) GetLSNParser() replication.LSNParser {

// Close closes the database connections.
func (h *Handler) Close() error {
err := h.pgReplicationConn.Close(context.Background())
if err != nil {
return err
}
return h.pgConn.Close(context.Background())
return h.pgReplicationConn.Close(context.Background())
}

// getRestartLSN returns the absolute earliest possible LSN we can support. If
// the consumer's LSN is earlier than this, we cannot (easily) catch the
// consumer back up.
func (h *Handler) getRestartLSN(ctx context.Context, slotName string) (replication.LSN, error) {
func (h *Handler) getRestartLSN(ctx context.Context, conn *pgx.Conn, slotName string) (replication.LSN, error) {
var restartLSN string
err := h.pgConn.QueryRow(
err := conn.QueryRow(
ctx,
`select restart_lsn from pg_replication_slots where slot_name=$1`,
slotName,
Expand All @@ -226,9 +226,9 @@ func (h *Handler) getRestartLSN(ctx context.Context, slotName string) (replicati

// getLastSyncedLSN gets the `confirmed_flush_lsn` from PG. This is the last LSN
// that the consumer confirmed it had completed.
func (h *Handler) getLastSyncedLSN(ctx context.Context) (replication.LSN, error) {
func (h *Handler) getLastSyncedLSN(ctx context.Context, conn *pgx.Conn) (replication.LSN, error) {
var confirmedFlushLSN string
err := h.pgConn.QueryRow(ctx, `select confirmed_flush_lsn from pg_replication_slots where slot_name=$1`, h.pgReplicationSlotName).Scan(&confirmedFlushLSN)
err := conn.QueryRow(ctx, `select confirmed_flush_lsn from pg_replication_slots where slot_name=$1`, h.pgReplicationSlotName).Scan(&confirmedFlushLSN)
if err != nil {
return 0, err
}
Expand Down

0 comments on commit 3180976

Please sign in to comment.