From aaa5dba16a4ed728d7c55c624926ecb021654ff5 Mon Sep 17 00:00:00 2001 From: eminano Date: Mon, 1 Jul 2024 17:22:06 +0200 Subject: [PATCH 1/3] Add postgres internal library --- internal/postgres/errors.go | 19 +++ internal/postgres/mocks/mock_pg_conn.go | 24 +++ .../mocks/mock_pg_replication_conn.go | 37 +++++ internal/postgres/pg_conn.go | 40 +++++ internal/postgres/pg_replication_conn.go | 138 ++++++++++++++++++ 5 files changed, 258 insertions(+) create mode 100644 internal/postgres/errors.go create mode 100644 internal/postgres/mocks/mock_pg_conn.go create mode 100644 internal/postgres/mocks/mock_pg_replication_conn.go create mode 100644 internal/postgres/pg_conn.go create mode 100644 internal/postgres/pg_replication_conn.go diff --git a/internal/postgres/errors.go b/internal/postgres/errors.go new file mode 100644 index 0000000..c0cd5a2 --- /dev/null +++ b/internal/postgres/errors.go @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "errors" + + "github.com/jackc/pgx/v5/pgconn" +) + +var ErrConnTimeout = errors.New("connection timeout") + +func mapError(err error) error { + if pgconn.Timeout(err) { + return ErrConnTimeout + } + + return err +} diff --git a/internal/postgres/mocks/mock_pg_conn.go b/internal/postgres/mocks/mock_pg_conn.go new file mode 100644 index 0000000..c8eb128 --- /dev/null +++ b/internal/postgres/mocks/mock_pg_conn.go @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + "context" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/xataio/pgstream/internal/postgres" +) + +type Conn struct { + QueryRowFn func(ctx context.Context, query string, args ...any) postgres.Row + ExecFn func(context.Context, string, ...any) (pgconn.CommandTag, error) + CloseFn func(context.Context) error +} + +func (m *Conn) QueryRow(ctx context.Context, query string, args ...any) postgres.Row { + return m.QueryRowFn(ctx, query, args...) +} + +func (m *Conn) Close(ctx context.Context) error { + return m.CloseFn(ctx) +} diff --git a/internal/postgres/mocks/mock_pg_replication_conn.go b/internal/postgres/mocks/mock_pg_replication_conn.go new file mode 100644 index 0000000..f2a6fba --- /dev/null +++ b/internal/postgres/mocks/mock_pg_replication_conn.go @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + "context" + + "github.com/xataio/pgstream/internal/postgres" +) + +type ReplicationConn struct { + IdentifySystemFn func(ctx context.Context) (postgres.IdentifySystemResult, error) + StartReplicationFn func(ctx context.Context, cfg postgres.ReplicationConfig) error + SendStandbyStatusUpdateFn func(ctx context.Context, lsn uint64) error + ReceiveMessageFn func(ctx context.Context) (*postgres.ReplicationMessage, error) + CloseFn func(ctx context.Context) error +} + +func (m *ReplicationConn) IdentifySystem(ctx context.Context) (postgres.IdentifySystemResult, error) { + return m.IdentifySystemFn(ctx) +} + +func (m *ReplicationConn) StartReplication(ctx context.Context, cfg postgres.ReplicationConfig) error { + return m.StartReplicationFn(ctx, cfg) +} + +func (m *ReplicationConn) SendStandbyStatusUpdate(ctx context.Context, lsn uint64) error { + return m.SendStandbyStatusUpdateFn(ctx, lsn) +} + +func (m *ReplicationConn) ReceiveMessage(ctx context.Context) (*postgres.ReplicationMessage, error) { + return m.ReceiveMessageFn(ctx) +} + +func (m *ReplicationConn) Close(ctx context.Context) error { + return m.CloseFn(ctx) +} diff --git a/internal/postgres/pg_conn.go b/internal/postgres/pg_conn.go new file mode 100644 index 0000000..1cc334b --- /dev/null +++ b/internal/postgres/pg_conn.go @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5" +) + +type Conn struct { + conn *pgx.Conn +} + +type Row interface { + pgx.Row +} + +func NewConn(ctx context.Context, url string) (*Conn, error) { + pgCfg, err := pgx.ParseConfig(url) + if err != nil { + return nil, fmt.Errorf("failed parsing postgres connection string: %w", mapError(err)) + } + + conn, err := pgx.ConnectConfig(ctx, pgCfg) + if err != nil { + return nil, fmt.Errorf("failed to connect to postgres: %w", mapError(err)) + } + + return &Conn{conn: conn}, nil +} + +func (c *Conn) QueryRow(ctx context.Context, query string, args ...any) Row { + return c.conn.QueryRow(ctx, query, args...) +} + +func (c *Conn) Close(ctx context.Context) error { + return mapError(c.conn.Close(ctx)) +} diff --git a/internal/postgres/pg_replication_conn.go b/internal/postgres/pg_replication_conn.go new file mode 100644 index 0000000..8cb922e --- /dev/null +++ b/internal/postgres/pg_replication_conn.go @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/jackc/pglogrepl" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" +) + +type ReplicationConn struct { + conn *pgconn.PgConn +} + +type ReplicationConfig struct { + SlotName string + StartPos uint64 + PluginArguments []string +} + +type ReplicationMessage struct { + LSN uint64 + ServerTime time.Time + WALData []byte + ReplyRequested bool +} + +type IdentifySystemResult pglogrepl.IdentifySystemResult + +var ErrUnsupportedCopyDataMessage = errors.New("unsupported copy data message") + +func NewReplicationConn(ctx context.Context, url string) (*ReplicationConn, error) { + pgCfg, err := pgx.ParseConfig(url) + if err != nil { + return nil, fmt.Errorf("failed parsing postgres connection string: %w", err) + } + + pgCfg.RuntimeParams["replication"] = "database" + + conn, err := pgconn.ConnectConfig(context.Background(), &pgCfg.Config) + if err != nil { + return nil, fmt.Errorf("create postgres replication client: %w", mapError(err)) + } + + return &ReplicationConn{ + conn: conn, + }, nil +} + +func (c *ReplicationConn) IdentifySystem(ctx context.Context) (IdentifySystemResult, error) { + res, err := pglogrepl.IdentifySystem(ctx, c.conn) + return IdentifySystemResult(res), mapError(err) +} + +func (c *ReplicationConn) StartReplication(ctx context.Context, cfg ReplicationConfig) error { + return mapError(pglogrepl.StartReplication( + ctx, + c.conn, + cfg.SlotName, + pglogrepl.LSN(cfg.StartPos), + pglogrepl.StartReplicationOptions{PluginArgs: cfg.PluginArguments})) +} + +func (c *ReplicationConn) SendStandbyStatusUpdate(ctx context.Context, lsn uint64) error { + return mapError(pglogrepl.SendStandbyStatusUpdate( + ctx, + c.conn, + pglogrepl.StandbyStatusUpdate{WALWritePosition: pglogrepl.LSN(lsn)}, + )) +} + +func (c *ReplicationConn) ReceiveMessage(ctx context.Context) (*ReplicationMessage, error) { + msg, err := c.conn.ReceiveMessage(ctx) + if err != nil { + return nil, mapError(err) + } + + switch msg := msg.(type) { + case *pgproto3.CopyData: + switch msg.Data[0] { + case pglogrepl.PrimaryKeepaliveMessageByteID: + pka, err := pglogrepl.ParsePrimaryKeepaliveMessage(msg.Data[1:]) + if err != nil { + return nil, fmt.Errorf("parse keep alive: %w", err) + } + return &ReplicationMessage{ + LSN: uint64(pka.ServerWALEnd), + ServerTime: pka.ServerTime, + ReplyRequested: pka.ReplyRequested, + }, nil + case pglogrepl.XLogDataByteID: + xld, err := pglogrepl.ParseXLogData(msg.Data[1:]) + if err != nil { + return nil, fmt.Errorf("parse xlog data: %w", err) + } + + return &ReplicationMessage{ + LSN: uint64(xld.WALStart) + uint64(len(xld.WALData)), + ServerTime: xld.ServerTime, + WALData: xld.WALData, + }, nil + default: + return nil, fmt.Errorf("%v: %w", msg.Data[0], ErrUnsupportedCopyDataMessage) + } + case *pgproto3.NoticeResponse: + return nil, parseErrNoticeResponse(msg) + default: + // unexpected message (WAL error?) + return nil, fmt.Errorf("unexpected message: %#v", msg) + } +} + +func (c *ReplicationConn) Close(ctx context.Context) error { + return mapError(c.conn.Close(ctx)) +} + +type Error struct { + Severity string + Msg string +} + +func (e *Error) Error() string { + return fmt.Sprintf("replication error: %s", e.Msg) +} + +func parseErrNoticeResponse(errMsg *pgproto3.NoticeResponse) error { + return &Error{ + Severity: errMsg.Severity, + Msg: fmt.Sprintf("replication notice response: severity: %s, code: %s, message: %s, detail: %s, schemaName: %s, tableName: %s, columnName: %s", + errMsg.Severity, errMsg.Code, errMsg.Message, errMsg.Detail, errMsg.SchemaName, errMsg.TableName, errMsg.ColumnName), + } +} From 3116a1e5770d3cb7b206d403f37dcc4480c95e96 Mon Sep 17 00:00:00 2001 From: eminano Date: Mon, 1 Jul 2024 17:23:50 +0200 Subject: [PATCH 2/3] Use postgres lib in replication handler --- pkg/wal/listener/postgres/helper_test.go | 30 ++-- pkg/wal/listener/postgres/wal_pg_listener.go | 26 ++- .../listener/postgres/wal_pg_listener_test.go | 20 +-- .../instrumented_replication_handler.go | 2 +- .../mocks/mock_replication_handler.go | 4 +- .../mocks/mock_replication_message.go | 15 -- .../postgres/pg_replication_errors.go | 30 ---- .../postgres/pg_replication_handler.go | 154 +++++++++--------- .../postgres/pg_replication_messages.go | 35 ---- pkg/wal/replication/replication_handler.go | 20 +-- 10 files changed, 114 insertions(+), 222 deletions(-) delete mode 100644 pkg/wal/replication/mocks/mock_replication_message.go delete mode 100644 pkg/wal/replication/postgres/pg_replication_errors.go delete mode 100644 pkg/wal/replication/postgres/pg_replication_messages.go diff --git a/pkg/wal/listener/postgres/helper_test.go b/pkg/wal/listener/postgres/helper_test.go index 14d891d..2b1a51e 100644 --- a/pkg/wal/listener/postgres/helper_test.go +++ b/pkg/wal/listener/postgres/helper_test.go @@ -20,33 +20,25 @@ func newMockReplicationHandler() *replicationmocks.Handler { StartReplicationFn: func(context.Context) error { return nil }, GetLSNParserFn: func() replication.LSNParser { return newMockLSNParser() }, SyncLSNFn: func(ctx context.Context, lsn replication.LSN) error { return nil }, - ReceiveMessageFn: func(ctx context.Context, i uint64) (replication.Message, error) { + ReceiveMessageFn: func(ctx context.Context, i uint64) (*replication.Message, error) { return newMockMessage(), nil }, } } -func newMockMessage() *replicationmocks.Message { - return &replicationmocks.Message{ - GetDataFn: func() *replication.MessageData { - return &replication.MessageData{ - LSN: testLSN, - Data: []byte("test-data"), - ReplyRequested: false, - ServerTime: time.Now(), - } - }, +func newMockMessage() *replication.Message { + return &replication.Message{ + LSN: testLSN, + Data: []byte("test-data"), + ReplyRequested: false, + ServerTime: time.Now(), } } -func newMockKeepAliveMessage(replyRequested bool) *replicationmocks.Message { - return &replicationmocks.Message{ - GetDataFn: func() *replication.MessageData { - return &replication.MessageData{ - LSN: testLSN, - ReplyRequested: replyRequested, - } - }, +func newMockKeepAliveMessage(replyRequested bool) *replication.Message { + return &replication.Message{ + LSN: testLSN, + ReplyRequested: replyRequested, } } diff --git a/pkg/wal/listener/postgres/wal_pg_listener.go b/pkg/wal/listener/postgres/wal_pg_listener.go index d828f5b..23ae364 100644 --- a/pkg/wal/listener/postgres/wal_pg_listener.go +++ b/pkg/wal/listener/postgres/wal_pg_listener.go @@ -29,7 +29,7 @@ type Listener struct { type replicationHandler interface { StartReplication(ctx context.Context) error - ReceiveMessage(ctx context.Context) (replication.Message, error) + ReceiveMessage(ctx context.Context) (*replication.Message, error) GetLSNParser() replication.LSNParser Close() error } @@ -89,46 +89,44 @@ func (l *Listener) listen(ctx context.Context) error { default: msg, err := l.replicationHandler.ReceiveMessage(ctx) if err != nil { - replErr := &replication.Error{} - if errors.Is(err, replication.ErrConnTimeout) || (errors.As(err, &replErr) && replErr.Severity == "WARNING") { + if errors.Is(err, replication.ErrConnTimeout) { continue } return fmt.Errorf("receiving message: %w", err) } - msgData := msg.GetData() - if msgData == nil { + if msg == nil { continue } l.logger.Trace("", loglib.Fields{ - "wal_end": l.lsnParser.ToString(msgData.LSN), - "server_time": msgData.ServerTime, - "wal_data": msgData.Data, + "wal_end": l.lsnParser.ToString(msg.LSN), + "server_time": msg.ServerTime, + "wal_data": msg.Data, }) - if err := l.processWALEvent(ctx, msgData); err != nil { + if err := l.processWALEvent(ctx, msg); err != nil { return err } } } } -func (l *Listener) processWALEvent(ctx context.Context, msgData *replication.MessageData) error { +func (l *Listener) processWALEvent(ctx context.Context, msg *replication.Message) error { // if there's no data, it's a keep alive. If a reply is not requested, // no need to process this message. - if msgData.Data == nil && !msgData.ReplyRequested { + if msg.Data == nil && !msg.ReplyRequested { return nil } event := &wal.Event{} - if msgData.Data != nil { + if msg.Data != nil { event.Data = &wal.Data{} - if err := l.walDataDeserialiser(msgData.Data, event.Data); err != nil { + if err := l.walDataDeserialiser(msg.Data, event.Data); err != nil { return fmt.Errorf("error unmarshaling wal data: %w", err) } } - event.CommitPosition = wal.CommitPosition(l.lsnParser.ToString(msgData.LSN)) + event.CommitPosition = wal.CommitPosition(l.lsnParser.ToString(msg.LSN)) return l.processEvent(ctx, event) } diff --git a/pkg/wal/listener/postgres/wal_pg_listener_test.go b/pkg/wal/listener/postgres/wal_pg_listener_test.go index 991327b..9b292d6 100644 --- a/pkg/wal/listener/postgres/wal_pg_listener_test.go +++ b/pkg/wal/listener/postgres/wal_pg_listener_test.go @@ -20,10 +20,8 @@ import ( func TestListener_Listen(t *testing.T) { t.Parallel() - emptyMessage := &replicationmocks.Message{ - GetDataFn: func() *replication.MessageData { - return nil - }, + emptyMessage := &replication.Message{ + Data: nil, } testDeserialiser := func(_ []byte, out any) error { @@ -60,7 +58,7 @@ func TestListener_Listen(t *testing.T) { name: "ok - message received", replicationHandler: func(doneChan chan struct{}) *replicationmocks.Handler { h := newMockReplicationHandler() - h.ReceiveMessageFn = func(ctx context.Context, i uint64) (replication.Message, error) { + h.ReceiveMessageFn = func(ctx context.Context, i uint64) (*replication.Message, error) { defer func() { if i == 1 { doneChan <- struct{}{} @@ -83,7 +81,7 @@ func TestListener_Listen(t *testing.T) { name: "ok - timeout on receive message, retried", replicationHandler: func(doneChan chan struct{}) *replicationmocks.Handler { h := newMockReplicationHandler() - h.ReceiveMessageFn = func(ctx context.Context, i uint64) (replication.Message, error) { + h.ReceiveMessageFn = func(ctx context.Context, i uint64) (*replication.Message, error) { defer func() { if i == 2 { doneChan <- struct{}{} @@ -108,7 +106,7 @@ func TestListener_Listen(t *testing.T) { name: "ok - nil msg data", replicationHandler: func(doneChan chan struct{}) *replicationmocks.Handler { h := newMockReplicationHandler() - h.ReceiveMessageFn = func(ctx context.Context, i uint64) (replication.Message, error) { + h.ReceiveMessageFn = func(ctx context.Context, i uint64) (*replication.Message, error) { defer func() { if i == 1 { doneChan <- struct{}{} @@ -126,7 +124,7 @@ func TestListener_Listen(t *testing.T) { name: "ok - keep alive", replicationHandler: func(doneChan chan struct{}) *replicationmocks.Handler { h := newMockReplicationHandler() - h.ReceiveMessageFn = func(ctx context.Context, i uint64) (replication.Message, error) { + h.ReceiveMessageFn = func(ctx context.Context, i uint64) (*replication.Message, error) { defer func() { if i == 1 { doneChan <- struct{}{} @@ -154,7 +152,7 @@ func TestListener_Listen(t *testing.T) { name: "error - receiving message", replicationHandler: func(doneChan chan struct{}) *replicationmocks.Handler { h := newMockReplicationHandler() - h.ReceiveMessageFn = func(ctx context.Context, i uint64) (replication.Message, error) { + h.ReceiveMessageFn = func(ctx context.Context, i uint64) (*replication.Message, error) { defer func() { if i == 1 { doneChan <- struct{}{} @@ -172,7 +170,7 @@ func TestListener_Listen(t *testing.T) { name: "error - processing wal event", replicationHandler: func(doneChan chan struct{}) *replicationmocks.Handler { h := newMockReplicationHandler() - h.ReceiveMessageFn = func(ctx context.Context, i uint64) (replication.Message, error) { + h.ReceiveMessageFn = func(ctx context.Context, i uint64) (*replication.Message, error) { defer func() { if i == 1 { doneChan <- struct{}{} @@ -190,7 +188,7 @@ func TestListener_Listen(t *testing.T) { name: "error - deserialising wal event", replicationHandler: func(doneChan chan struct{}) *replicationmocks.Handler { h := newMockReplicationHandler() - h.ReceiveMessageFn = func(ctx context.Context, i uint64) (replication.Message, error) { + h.ReceiveMessageFn = func(ctx context.Context, i uint64) (*replication.Message, error) { defer func() { if i == 1 { doneChan <- struct{}{} diff --git a/pkg/wal/replication/instrumentation/instrumented_replication_handler.go b/pkg/wal/replication/instrumentation/instrumented_replication_handler.go index f5e7c48..920ce88 100644 --- a/pkg/wal/replication/instrumentation/instrumented_replication_handler.go +++ b/pkg/wal/replication/instrumentation/instrumented_replication_handler.go @@ -39,7 +39,7 @@ func (h *Handler) StartReplication(ctx context.Context) error { return h.inner.StartReplication(ctx) } -func (h *Handler) ReceiveMessage(ctx context.Context) (msg replication.Message, err error) { +func (h *Handler) ReceiveMessage(ctx context.Context) (*replication.Message, error) { return h.inner.ReceiveMessage(ctx) } diff --git a/pkg/wal/replication/mocks/mock_replication_handler.go b/pkg/wal/replication/mocks/mock_replication_handler.go index 3a78336..8db039d 100644 --- a/pkg/wal/replication/mocks/mock_replication_handler.go +++ b/pkg/wal/replication/mocks/mock_replication_handler.go @@ -11,7 +11,7 @@ import ( type Handler struct { StartReplicationFn func(context.Context) error - ReceiveMessageFn func(context.Context, uint64) (replication.Message, error) + ReceiveMessageFn func(context.Context, uint64) (*replication.Message, error) SyncLSNFn func(context.Context, replication.LSN) error DropReplicationSlotFn func(ctx context.Context) error GetLSNParserFn func() replication.LSNParser @@ -24,7 +24,7 @@ func (m *Handler) StartReplication(ctx context.Context) error { return m.StartReplicationFn(ctx) } -func (m *Handler) ReceiveMessage(ctx context.Context) (replication.Message, error) { +func (m *Handler) ReceiveMessage(ctx context.Context) (*replication.Message, error) { atomic.AddUint64(&m.ReceiveMessageCalls, 1) return m.ReceiveMessageFn(ctx, m.GetReceiveMessageCalls()) } diff --git a/pkg/wal/replication/mocks/mock_replication_message.go b/pkg/wal/replication/mocks/mock_replication_message.go deleted file mode 100644 index 12e82ad..0000000 --- a/pkg/wal/replication/mocks/mock_replication_message.go +++ /dev/null @@ -1,15 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package mocks - -import ( - "github.com/xataio/pgstream/pkg/wal/replication" -) - -type Message struct { - GetDataFn func() *replication.MessageData -} - -func (m *Message) GetData() *replication.MessageData { - return m.GetDataFn() -} diff --git a/pkg/wal/replication/postgres/pg_replication_errors.go b/pkg/wal/replication/postgres/pg_replication_errors.go deleted file mode 100644 index 65cab58..0000000 --- a/pkg/wal/replication/postgres/pg_replication_errors.go +++ /dev/null @@ -1,30 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package postgres - -import ( - "errors" - "fmt" - - "github.com/xataio/pgstream/pkg/wal/replication" - - "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgproto3" -) - -var ErrUnsupportedCopyDataMessage = errors.New("unsupported copy data message") - -func parseErrNoticeResponse(errMsg *pgproto3.NoticeResponse) error { - return &replication.Error{ - Severity: errMsg.Severity, - Msg: fmt.Sprintf("replication notice response: severity: %s, code: %s, message: %s, detail: %s, schemaName: %s, tableName: %s, columnName: %s", - errMsg.Severity, errMsg.Code, errMsg.Message, errMsg.Detail, errMsg.SchemaName, errMsg.TableName, errMsg.ColumnName), - } -} - -func mapPostgresError(err error) error { - if pgconn.Timeout(err) { - return replication.ErrConnTimeout - } - return err -} diff --git a/pkg/wal/replication/postgres/pg_replication_handler.go b/pkg/wal/replication/postgres/pg_replication_handler.go index e636b24..4c8c348 100644 --- a/pkg/wal/replication/postgres/pg_replication_handler.go +++ b/pkg/wal/replication/postgres/pg_replication_handler.go @@ -4,13 +4,10 @@ package postgres import ( "context" + "errors" "fmt" - "github.com/jackc/pglogrepl" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgproto3" - + pglib "github.com/xataio/pgstream/internal/postgres" loglib "github.com/xataio/pgstream/pkg/log" "github.com/xataio/pgstream/pkg/wal/replication" ) @@ -19,15 +16,31 @@ import ( type Handler struct { logger loglib.Logger - pgReplicationConn *pgconn.PgConn + pgReplicationConn pgReplicationConn pgReplicationSlotName string - pgConnBuilder func() (*pgx.Conn, error) + pgConnBuilder func() (pgConn, error) lsnParser replication.LSNParser } +type pgConn interface { + QueryRow(ctx context.Context, query string, args ...any) pglib.Row + Close(context.Context) error +} + +type pgReplicationConn interface { + IdentifySystem(ctx context.Context) (pglib.IdentifySystemResult, error) + StartReplication(ctx context.Context, cfg pglib.ReplicationConfig) error + SendStandbyStatusUpdate(ctx context.Context, lsn uint64) error + ReceiveMessage(ctx context.Context) (*pglib.ReplicationMessage, error) + Close(ctx context.Context) error +} + type Config struct { PostgresURL string + // Name of the replication slot to listen on. If not provided, it defaults + // to "pgstream__slot". + ReplicationSlotName string } type Option func(h *Handler) @@ -40,31 +53,31 @@ const ( logSystemID = "system_id" ) +var pluginArguments = []string{ + `"include-timestamp" '1'`, + `"format-version" '2'`, + `"write-in-chunks" '1'`, + `"include-lsn" '1'`, + `"include-transaction" '0'`, +} + // NewHandler returns a new postgres replication handler for the database on input. func NewHandler(ctx context.Context, cfg Config, opts ...Option) (*Handler, error) { - pgCfg, err := pgx.ParseConfig(cfg.PostgresURL) - if err != nil { - return nil, fmt.Errorf("failed parsing postgres connection string: %w", err) - } - - connBuilder := func() (*pgx.Conn, error) { - return pgx.ConnectConfig(ctx, pgCfg) + connBuilder := func() (pgConn, error) { + return pglib.NewConn(ctx, cfg.PostgresURL) } - // open a Postgres connection dedicated for replication - copyConfig := pgCfg.Copy() - copyConfig.RuntimeParams["replication"] = "database" - - pgReplicationConn, err := pgconn.ConnectConfig(context.Background(), ©Config.Config) + pgReplicationConn, err := pglib.NewReplicationConn(ctx, cfg.PostgresURL) if err != nil { - return nil, fmt.Errorf("create postgres replication client: %w", err) + return nil, err } h := &Handler{ - logger: loglib.NewNoopLogger(), - pgReplicationConn: pgReplicationConn, - pgConnBuilder: connBuilder, - lsnParser: &LSNParser{}, + logger: loglib.NewNoopLogger(), + pgReplicationConn: pgReplicationConn, + pgReplicationSlotName: cfg.ReplicationSlotName, + pgConnBuilder: connBuilder, + lsnParser: &LSNParser{}, } for _, opt := range opts { @@ -87,12 +100,14 @@ func WithLogger(l loglib.Logger) Option { // (confirmed_flush_lsn), and if there isn't one, it will start replication from // the restart_lsn position. func (h *Handler) StartReplication(ctx context.Context) error { - sysID, err := pglogrepl.IdentifySystem(ctx, h.pgReplicationConn) + sysID, err := h.pgReplicationConn.IdentifySystem(ctx) if err != nil { return fmt.Errorf("identifySystem failed: %w", err) } - h.pgReplicationSlotName = fmt.Sprintf("pgstream_%s_slot", sysID.DBName) + if h.pgReplicationSlotName == "" { + h.pgReplicationSlotName = fmt.Sprintf("pgstream_%s_slot", sysID.DBName) + } logFields := loglib.Fields{ logSystemID: sysID.SystemID, @@ -116,7 +131,7 @@ func (h *Handler) StartReplication(ctx context.Context) error { } h.logger.Trace("replication handler: read last LSN position", logFields, loglib.Fields{ - logLSNPosition: pglogrepl.LSN(startPos), + logLSNPosition: h.lsnParser.ToString(startPos), }) if startPos == 0 { @@ -127,22 +142,15 @@ func (h *Handler) StartReplication(ctx context.Context) error { } h.logger.Trace("replication handler: set start LSN", logFields, loglib.Fields{ - logLSNPosition: pglogrepl.LSN(startPos), + logLSNPosition: h.lsnParser.ToString(startPos), }) - pluginArguments := []string{ - `"include-timestamp" '1'`, - `"format-version" '2'`, - `"write-in-chunks" '1'`, - `"include-lsn" '1'`, - `"include-transaction" '0'`, - } - err = pglogrepl.StartReplication( - ctx, - h.pgReplicationConn, - h.pgReplicationSlotName, - pglogrepl.LSN(startPos), - pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments}) + err = h.pgReplicationConn.StartReplication( + ctx, pglib.ReplicationConfig{ + SlotName: h.pgReplicationSlotName, + StartPos: uint64(startPos), + PluginArguments: pluginArguments, + }) if err != nil { return fmt.Errorf("startReplication: %w", err) } @@ -154,53 +162,29 @@ func (h *Handler) StartReplication(ctx context.Context) error { // ReceiveMessage will listen for messages from the WAL. It returns an error if // an unexpected message is received. -func (h *Handler) ReceiveMessage(ctx context.Context) (replication.Message, error) { - msg, err := h.pgReplicationConn.ReceiveMessage(ctx) +func (h *Handler) ReceiveMessage(ctx context.Context) (*replication.Message, error) { + pgMsg, err := h.pgReplicationConn.ReceiveMessage(ctx) if err != nil { + h.logger.Error(err, "receiving message") return nil, mapPostgresError(err) } - switch msg := msg.(type) { - case *pgproto3.CopyData: - switch msg.Data[0] { - case pglogrepl.PrimaryKeepaliveMessageByteID: - pka, err := pglogrepl.ParsePrimaryKeepaliveMessage(msg.Data[1:]) - if err != nil { - return nil, fmt.Errorf("parse keep alive: %w", err) - } - pkaMessage := PrimaryKeepAliveMessage(pka) - return &pkaMessage, nil - case pglogrepl.XLogDataByteID: - xld, err := pglogrepl.ParseXLogData(msg.Data[1:]) - if err != nil { - return nil, fmt.Errorf("parse xlog data: %w", err) - } - - xldMessage := XLogDataMessage(xld) - return &xldMessage, nil - default: - return nil, fmt.Errorf("%v: %w", msg.Data[0], ErrUnsupportedCopyDataMessage) - } - case *pgproto3.NoticeResponse: - return nil, parseErrNoticeResponse(msg) - default: - // unexpected message (WAL error?) - return nil, fmt.Errorf("unexpected message: %#v", msg) - } + return &replication.Message{ + LSN: replication.LSN(pgMsg.LSN), + Data: pgMsg.WALData, + ServerTime: pgMsg.ServerTime, + ReplyRequested: pgMsg.ReplyRequested, + }, nil } // SyncLSN notifies Postgres how far we have processed in the WAL. func (h *Handler) SyncLSN(ctx context.Context, lsn replication.LSN) error { - err := pglogrepl.SendStandbyStatusUpdate( - ctx, - h.pgReplicationConn, - pglogrepl.StandbyStatusUpdate{WALWritePosition: pglogrepl.LSN(lsn)}, - ) + err := h.pgReplicationConn.SendStandbyStatusUpdate(ctx, uint64(lsn)) if err != nil { return fmt.Errorf("syncLSN: send status update: %w", err) } h.logger.Trace("stored new LSN position", loglib.Fields{ - logLSNPosition: pglogrepl.LSN(lsn).String(), + logLSNPosition: h.lsnParser.ToString(lsn), }) return nil } @@ -237,7 +221,7 @@ func (h *Handler) Close() error { // getRestartLSN returns the absolute earliest possible LSN we can support. If // the consumer's LSN is earlier than this, we cannot (easily) catch the // consumer back up. -func (h *Handler) getRestartLSN(ctx context.Context, conn *pgx.Conn, slotName string) (replication.LSN, error) { +func (h *Handler) getRestartLSN(ctx context.Context, conn pgConn, slotName string) (replication.LSN, error) { var restartLSN string err := conn.QueryRow( ctx, @@ -253,7 +237,7 @@ func (h *Handler) getRestartLSN(ctx context.Context, conn *pgx.Conn, slotName st // getLastSyncedLSN gets the `confirmed_flush_lsn` from PG. This is the last LSN // that the consumer confirmed it had completed. -func (h *Handler) getLastSyncedLSN(ctx context.Context, conn *pgx.Conn) (replication.LSN, error) { +func (h *Handler) getLastSyncedLSN(ctx context.Context, conn pgConn) (replication.LSN, error) { var confirmedFlushLSN string err := conn.QueryRow(ctx, `select confirmed_flush_lsn from pg_replication_slots where slot_name=$1`, h.pgReplicationSlotName).Scan(&confirmedFlushLSN) if err != nil { @@ -262,3 +246,17 @@ func (h *Handler) getLastSyncedLSN(ctx context.Context, conn *pgx.Conn) (replica return h.lsnParser.FromString(confirmedFlushLSN) } + +func mapPostgresError(err error) error { + if errors.Is(err, pglib.ErrConnTimeout) { + return replication.ErrConnTimeout + } + + // ignore warnings + replErr := &pglib.Error{} + if errors.As(err, &replErr) && replErr.Severity == "WARNING" { + return nil + } + + return err +} diff --git a/pkg/wal/replication/postgres/pg_replication_messages.go b/pkg/wal/replication/postgres/pg_replication_messages.go deleted file mode 100644 index 17a9f94..0000000 --- a/pkg/wal/replication/postgres/pg_replication_messages.go +++ /dev/null @@ -1,35 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package postgres - -import ( - "github.com/jackc/pglogrepl" - - "github.com/xataio/pgstream/pkg/wal/replication" -) - -// PrimaryKeepAliveMessage contains no wal data and a flag to indicate if a -// response is requested along with the message metadata (lsn and server time). -type PrimaryKeepAliveMessage pglogrepl.PrimaryKeepaliveMessage - -func (pka *PrimaryKeepAliveMessage) GetData() *replication.MessageData { - return &replication.MessageData{ - LSN: replication.LSN(pka.ServerWALEnd), - ServerTime: pka.ServerTime, - ReplyRequested: pka.ReplyRequested, - } -} - -// XLogDataMessage contains the wal data along with the message metadata (lsn -// and server time) -type XLogDataMessage pglogrepl.XLogData - -func (xld *XLogDataMessage) GetData() *replication.MessageData { - newLSN := xld.WALStart + pglogrepl.LSN(len(xld.WALData)) - return &replication.MessageData{ - LSN: replication.LSN(newLSN), - ServerTime: xld.ServerTime, - ReplyRequested: false, - Data: xld.WALData, - } -} diff --git a/pkg/wal/replication/replication_handler.go b/pkg/wal/replication/replication_handler.go index 557979d..0889e45 100644 --- a/pkg/wal/replication/replication_handler.go +++ b/pkg/wal/replication/replication_handler.go @@ -5,26 +5,21 @@ package replication import ( "context" "errors" - "fmt" "time" ) // Handler manages the replication operations type Handler interface { StartReplication(ctx context.Context) error - ReceiveMessage(ctx context.Context) (Message, error) + ReceiveMessage(ctx context.Context) (*Message, error) SyncLSN(ctx context.Context, lsn LSN) error GetReplicationLag(ctx context.Context) (int64, error) GetLSNParser() LSNParser Close() error } -type Message interface { - GetData() *MessageData -} - -// MessageData is the common data for all replication messages -type MessageData struct { +// Message contains the replication data +type Message struct { LSN LSN Data []byte ServerTime time.Time @@ -40,12 +35,3 @@ type LSNParser interface { type LSN uint64 var ErrConnTimeout = errors.New("connection timeout") - -type Error struct { - Severity string - Msg string -} - -func (e *Error) Error() string { - return fmt.Sprintf("replication error: %s", e.Msg) -} From 532d3fc711f66c7cc209f8e87734f69d7e835460 Mon Sep 17 00:00:00 2001 From: eminano Date: Mon, 1 Jul 2024 17:24:11 +0200 Subject: [PATCH 3/3] Add replication handler tests --- pkg/wal/replication/postgres/helper_test.go | 45 ++ .../postgres/pg_replication_handler_test.go | 481 ++++++++++++++++++ 2 files changed, 526 insertions(+) create mode 100644 pkg/wal/replication/postgres/helper_test.go create mode 100644 pkg/wal/replication/postgres/pg_replication_handler_test.go diff --git a/pkg/wal/replication/postgres/helper_test.go b/pkg/wal/replication/postgres/helper_test.go new file mode 100644 index 0000000..8667a37 --- /dev/null +++ b/pkg/wal/replication/postgres/helper_test.go @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "errors" + "fmt" + "time" +) + +type mockRow struct { + lsn string + lag int64 + scanFn func(args ...any) error +} + +func (m *mockRow) Scan(args ...any) error { + if m.scanFn != nil { + return m.scanFn(args...) + } + + switch arg := args[0].(type) { + case *string: + *arg = m.lsn + case *int64: + *arg = m.lag + default: + return fmt.Errorf("unexpected argument type in scan: %T", args[0]) + } + + return nil +} + +const ( + testDBName = "test-db" + testSlot = "test_slot" + testLSN = uint64(7773397064) + testLSNStr = "1/CF54A048" +) + +var ( + errTest = errors.New("oh noes") + + now = time.Now() +) diff --git a/pkg/wal/replication/postgres/pg_replication_handler_test.go b/pkg/wal/replication/postgres/pg_replication_handler_test.go new file mode 100644 index 0000000..39fc45f --- /dev/null +++ b/pkg/wal/replication/postgres/pg_replication_handler_test.go @@ -0,0 +1,481 @@ +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + pglib "github.com/xataio/pgstream/internal/postgres" + pgmocks "github.com/xataio/pgstream/internal/postgres/mocks" + "github.com/xataio/pgstream/pkg/log" + "github.com/xataio/pgstream/pkg/wal/replication" +) + +func TestHandler_StartReplication(t *testing.T) { + t.Parallel() + + defaultSlot := fmt.Sprintf("pgstream_%s_slot", testDBName) + + tests := []struct { + name string + replicationConn pgReplicationConn + connBuilder func() (pgConn, error) + slotName string + + wantErr error + }{ + { + name: "ok - with last synced LSN", + replicationConn: &pgmocks.ReplicationConn{ + IdentifySystemFn: func(ctx context.Context) (pglib.IdentifySystemResult, error) { + return pglib.IdentifySystemResult{ + DBName: testDBName, + SystemID: "tes-sys-id", + }, nil + }, + StartReplicationFn: func(ctx context.Context, cfg pglib.ReplicationConfig) error { + require.Equal(t, testLSN, cfg.StartPos) + require.Equal(t, testSlot, cfg.SlotName) + require.Equal(t, pluginArguments, cfg.PluginArguments) + return nil + }, + SendStandbyStatusUpdateFn: func(ctx context.Context, lsn uint64) error { + require.Equal(t, testLSN, lsn) + return nil + }, + }, + connBuilder: func() (pgConn, error) { + return &pgmocks.Conn{ + QueryRowFn: func(ctx context.Context, query string, args ...any) pglib.Row { + require.Len(t, args, 1) + require.Equal(t, args[0], testSlot) + switch query { + case "select restart_lsn from pg_replication_slots where slot_name=$1": + return &mockRow{scanFn: func(args ...any) error { return errors.New("restart lsn should not be called") }} + case "select confirmed_flush_lsn from pg_replication_slots where slot_name=$1": + return &mockRow{lsn: testLSNStr} + default: + return &mockRow{scanFn: func(args ...any) error { return fmt.Errorf("unexpected query: %s", query) }} + } + }, + CloseFn: func(ctx context.Context) error { return nil }, + }, nil + }, + slotName: testSlot, + + wantErr: nil, + }, + { + name: "ok - with restart LSN", + replicationConn: &pgmocks.ReplicationConn{ + IdentifySystemFn: func(ctx context.Context) (pglib.IdentifySystemResult, error) { + return pglib.IdentifySystemResult{ + DBName: testDBName, + SystemID: "tes-sys-id", + }, nil + }, + StartReplicationFn: func(ctx context.Context, cfg pglib.ReplicationConfig) error { + require.Equal(t, testLSN, cfg.StartPos) + require.Equal(t, testSlot, cfg.SlotName) + require.Equal(t, pluginArguments, cfg.PluginArguments) + return nil + }, + SendStandbyStatusUpdateFn: func(ctx context.Context, lsn uint64) error { + require.Equal(t, testLSN, lsn) + return nil + }, + }, + connBuilder: func() (pgConn, error) { + return &pgmocks.Conn{ + QueryRowFn: func(ctx context.Context, query string, args ...any) pglib.Row { + require.Len(t, args, 1) + require.Equal(t, args[0], testSlot) + switch query { + case "select restart_lsn from pg_replication_slots where slot_name=$1": + return &mockRow{lsn: testLSNStr} + case "select confirmed_flush_lsn from pg_replication_slots where slot_name=$1": + return &mockRow{lsn: "0/0"} + default: + return &mockRow{scanFn: func(args ...any) error { return fmt.Errorf("unexpected query: %s", query) }} + } + }, + CloseFn: func(ctx context.Context) error { return nil }, + }, nil + }, + slotName: testSlot, + + wantErr: nil, + }, + { + name: "ok - with default slot name", + replicationConn: &pgmocks.ReplicationConn{ + IdentifySystemFn: func(ctx context.Context) (pglib.IdentifySystemResult, error) { + return pglib.IdentifySystemResult{ + DBName: testDBName, + SystemID: "tes-sys-id", + }, nil + }, + StartReplicationFn: func(ctx context.Context, cfg pglib.ReplicationConfig) error { + require.Equal(t, testLSN, cfg.StartPos) + require.Equal(t, defaultSlot, cfg.SlotName) + require.Equal(t, pluginArguments, cfg.PluginArguments) + return nil + }, + SendStandbyStatusUpdateFn: func(ctx context.Context, lsn uint64) error { + require.Equal(t, testLSN, lsn) + return nil + }, + }, + connBuilder: func() (pgConn, error) { + return &pgmocks.Conn{ + QueryRowFn: func(ctx context.Context, query string, args ...any) pglib.Row { + require.Len(t, args, 1) + require.Equal(t, args[0], defaultSlot) + switch query { + case "select restart_lsn from pg_replication_slots where slot_name=$1": + return &mockRow{scanFn: func(args ...any) error { return errors.New("restart lsn should not be called") }} + case "select confirmed_flush_lsn from pg_replication_slots where slot_name=$1": + return &mockRow{lsn: testLSNStr} + default: + return &mockRow{scanFn: func(args ...any) error { return fmt.Errorf("unexpected query: %s", query) }} + } + }, + CloseFn: func(ctx context.Context) error { return nil }, + }, nil + }, + + wantErr: nil, + }, + { + name: "error - identifying system", + replicationConn: &pgmocks.ReplicationConn{ + IdentifySystemFn: func(ctx context.Context) (pglib.IdentifySystemResult, error) { + return pglib.IdentifySystemResult{}, errTest + }, + }, + slotName: testSlot, + + wantErr: errTest, + }, + { + name: "error - creating connection", + replicationConn: &pgmocks.ReplicationConn{ + IdentifySystemFn: func(ctx context.Context) (pglib.IdentifySystemResult, error) { + return pglib.IdentifySystemResult{ + DBName: testDBName, + SystemID: "tes-sys-id", + }, nil + }, + }, + connBuilder: func() (pgConn, error) { + return nil, errTest + }, + slotName: testSlot, + + wantErr: errTest, + }, + { + name: "error - getting last synced LSN", + replicationConn: &pgmocks.ReplicationConn{ + IdentifySystemFn: func(ctx context.Context) (pglib.IdentifySystemResult, error) { + return pglib.IdentifySystemResult{ + DBName: testDBName, + SystemID: "tes-sys-id", + }, nil + }, + }, + connBuilder: func() (pgConn, error) { + return &pgmocks.Conn{ + QueryRowFn: func(ctx context.Context, query string, args ...any) pglib.Row { + require.Len(t, args, 1) + require.Equal(t, args[0], testSlot) + switch query { + case "select confirmed_flush_lsn from pg_replication_slots where slot_name=$1": + return &mockRow{scanFn: func(args ...any) error { return errTest }} + default: + return &mockRow{scanFn: func(args ...any) error { return fmt.Errorf("unexpected query: %s", query) }} + } + }, + CloseFn: func(ctx context.Context) error { return nil }, + }, nil + }, + slotName: testSlot, + + wantErr: errTest, + }, + { + name: "error - getting restart LSN", + replicationConn: &pgmocks.ReplicationConn{ + IdentifySystemFn: func(ctx context.Context) (pglib.IdentifySystemResult, error) { + return pglib.IdentifySystemResult{ + DBName: testDBName, + SystemID: "tes-sys-id", + }, nil + }, + }, + connBuilder: func() (pgConn, error) { + return &pgmocks.Conn{ + QueryRowFn: func(ctx context.Context, query string, args ...any) pglib.Row { + require.Len(t, args, 1) + require.Equal(t, args[0], testSlot) + switch query { + case "select restart_lsn from pg_replication_slots where slot_name=$1": + return &mockRow{scanFn: func(args ...any) error { return errTest }} + case "select confirmed_flush_lsn from pg_replication_slots where slot_name=$1": + return &mockRow{lsn: "0/0"} + default: + return &mockRow{scanFn: func(args ...any) error { return fmt.Errorf("unexpected query: %s", query) }} + } + }, + CloseFn: func(ctx context.Context) error { return nil }, + }, nil + }, + slotName: testSlot, + + wantErr: errTest, + }, + { + name: "error - starting replication", + replicationConn: &pgmocks.ReplicationConn{ + IdentifySystemFn: func(ctx context.Context) (pglib.IdentifySystemResult, error) { + return pglib.IdentifySystemResult{ + DBName: testDBName, + SystemID: "tes-sys-id", + }, nil + }, + StartReplicationFn: func(ctx context.Context, cfg pglib.ReplicationConfig) error { + return errTest + }, + }, + connBuilder: func() (pgConn, error) { + return &pgmocks.Conn{ + QueryRowFn: func(ctx context.Context, query string, args ...any) pglib.Row { + require.Len(t, args, 1) + require.Equal(t, args[0], testSlot) + switch query { + case "select confirmed_flush_lsn from pg_replication_slots where slot_name=$1": + return &mockRow{lsn: testLSNStr} + default: + return &mockRow{scanFn: func(args ...any) error { return fmt.Errorf("unexpected query: %s", query) }} + } + }, + CloseFn: func(ctx context.Context) error { return nil }, + }, nil + }, + slotName: testSlot, + + wantErr: errTest, + }, + { + name: "error - syncing LSN", + replicationConn: &pgmocks.ReplicationConn{ + IdentifySystemFn: func(ctx context.Context) (pglib.IdentifySystemResult, error) { + return pglib.IdentifySystemResult{ + DBName: testDBName, + SystemID: "tes-sys-id", + }, nil + }, + StartReplicationFn: func(ctx context.Context, cfg pglib.ReplicationConfig) error { + require.Equal(t, testLSN, cfg.StartPos) + require.Equal(t, testSlot, cfg.SlotName) + require.Equal(t, pluginArguments, cfg.PluginArguments) + return nil + }, + SendStandbyStatusUpdateFn: func(ctx context.Context, lsn uint64) error { + return errTest + }, + }, + connBuilder: func() (pgConn, error) { + return &pgmocks.Conn{ + QueryRowFn: func(ctx context.Context, query string, args ...any) pglib.Row { + require.Len(t, args, 1) + require.Equal(t, args[0], testSlot) + switch query { + case "select confirmed_flush_lsn from pg_replication_slots where slot_name=$1": + return &mockRow{lsn: testLSNStr} + default: + return &mockRow{scanFn: func(args ...any) error { return fmt.Errorf("unexpected query: %s", query) }} + } + }, + CloseFn: func(ctx context.Context) error { return nil }, + }, nil + }, + slotName: testSlot, + + wantErr: errTest, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + h := Handler{ + logger: log.NewNoopLogger(), + pgReplicationConn: tc.replicationConn, + pgConnBuilder: tc.connBuilder, + pgReplicationSlotName: tc.slotName, + lsnParser: NewLSNParser(), + } + + err := h.StartReplication(context.Background()) + require.ErrorIs(t, err, tc.wantErr) + }) + } +} + +func TestHandler_ReceiveMessage(t *testing.T) { + t.Parallel() + + testData := []byte("test-data") + + tests := []struct { + name string + replicationConn pgReplicationConn + + wantMessage *replication.Message + wantErr error + }{ + { + name: "ok", + replicationConn: &pgmocks.ReplicationConn{ + ReceiveMessageFn: func(ctx context.Context) (*pglib.ReplicationMessage, error) { + return &pglib.ReplicationMessage{ + LSN: testLSN, + ServerTime: now, + WALData: testData, + ReplyRequested: false, + }, nil + }, + }, + + wantMessage: &replication.Message{ + LSN: replication.LSN(testLSN), + Data: testData, + ServerTime: now, + ReplyRequested: false, + }, + wantErr: nil, + }, + { + name: "ok - receiving message, warning notice", + replicationConn: &pgmocks.ReplicationConn{ + ReceiveMessageFn: func(ctx context.Context) (*pglib.ReplicationMessage, error) { + return nil, &pglib.Error{Severity: "WARNING"} + }, + }, + + wantMessage: nil, + wantErr: nil, + }, + { + name: "error - receiving message - timeout", + replicationConn: &pgmocks.ReplicationConn{ + ReceiveMessageFn: func(ctx context.Context) (*pglib.ReplicationMessage, error) { + return nil, pglib.ErrConnTimeout + }, + }, + + wantMessage: nil, + wantErr: replication.ErrConnTimeout, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + h := Handler{ + logger: log.NewNoopLogger(), + pgReplicationConn: tc.replicationConn, + } + + msg, err := h.ReceiveMessage(context.Background()) + require.ErrorIs(t, err, tc.wantErr) + require.Equal(t, tc.wantMessage, msg) + }) + } +} + +func TestHandler_GetReplicationLag(t *testing.T) { + t.Parallel() + + testLag := int64(5) + + tests := []struct { + name string + connBuilder func() (pgConn, error) + + wantLag int64 + wantErr error + }{ + { + name: "ok", + connBuilder: func() (pgConn, error) { + return &pgmocks.Conn{ + QueryRowFn: func(ctx context.Context, query string, args ...any) pglib.Row { + require.Len(t, args, 1) + require.Equal(t, args[0], testSlot) + switch query { + case "SELECT (pg_current_wal_lsn() - confirmed_flush_lsn) FROM pg_replication_slots WHERE slot_name=$1": + return &mockRow{lag: testLag} + default: + return &mockRow{scanFn: func(args ...any) error { return fmt.Errorf("unexpected query: %s", query) }} + } + }, + CloseFn: func(ctx context.Context) error { return nil }, + }, nil + }, + + wantLag: testLag, + wantErr: nil, + }, + { + name: "error - building connection", + connBuilder: func() (pgConn, error) { + return nil, errTest + }, + + wantLag: -1, + wantErr: errTest, + }, + { + name: "error - getting lag", + connBuilder: func() (pgConn, error) { + return &pgmocks.Conn{ + QueryRowFn: func(ctx context.Context, query string, args ...any) pglib.Row { + return &mockRow{scanFn: func(args ...any) error { return errTest }} + }, + CloseFn: func(ctx context.Context) error { return nil }, + }, nil + }, + + wantLag: -1, + wantErr: errTest, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + h := Handler{ + logger: log.NewNoopLogger(), + pgConnBuilder: tc.connBuilder, + pgReplicationSlotName: testSlot, + lsnParser: NewLSNParser(), + } + + lag, err := h.GetReplicationLag(context.Background()) + require.ErrorIs(t, err, tc.wantErr) + require.Equal(t, tc.wantLag, lag) + }) + } +}