Skip to content

Commit

Permalink
Merge pull request #52 from xataio/add-replication-handler-tests
Browse files Browse the repository at this point in the history
Add replication handler tests
  • Loading branch information
eminano authored Jul 3, 2024
2 parents f21b09f + 532d3fc commit f5817f5
Show file tree
Hide file tree
Showing 17 changed files with 898 additions and 222 deletions.
19 changes: 19 additions & 0 deletions internal/postgres/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// SPDX-License-Identifier: Apache-2.0

package postgres

import (
"errors"

"github.com/jackc/pgx/v5/pgconn"
)

// ErrConnTimeout is the sentinel error that mapError returns in place of any
// error that pgconn.Timeout reports as a timeout. Compare against it with
// errors.Is.
var ErrConnTimeout = errors.New("connection timeout")

// mapError normalises errors coming from the pgx/pgconn layer. Any error that
// pgconn.Timeout classifies as a timeout is replaced by ErrConnTimeout; every
// other value (including nil) is handed back untouched.
func mapError(err error) error {
	if !pgconn.Timeout(err) {
		return err
	}

	return ErrConnTimeout
}
24 changes: 24 additions & 0 deletions internal/postgres/mocks/mock_pg_conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// SPDX-License-Identifier: Apache-2.0

package mocks

import (
"context"

"github.com/jackc/pgx/v5/pgconn"
"github.com/xataio/pgstream/internal/postgres"
)

// Conn is a function-backed test double for a postgres connection. Each method
// defined on it forwards to the *Fn field of the same name, so tests configure
// behaviour by assigning those fields.
type Conn struct {
// QueryRowFn is invoked by QueryRow.
QueryRowFn func(ctx context.Context, query string, args ...any) postgres.Row
// ExecFn has no corresponding method in the visible part of this file.
// NOTE(review): presumably reserved for an Exec method — confirm.
ExecFn func(context.Context, string, ...any) (pgconn.CommandTag, error)
// CloseFn is invoked by Close.
CloseFn func(context.Context) error
}

// QueryRow forwards its arguments to QueryRowFn and returns whatever that
// function produces. QueryRowFn must be set; calling a nil func panics.
func (m *Conn) QueryRow(ctx context.Context, query string, args ...any) postgres.Row {
	row := m.QueryRowFn(ctx, query, args...)
	return row
}

// Close forwards ctx to CloseFn and returns its error. CloseFn must be set;
// calling a nil func panics.
func (m *Conn) Close(ctx context.Context) error {
	closeErr := m.CloseFn(ctx)
	return closeErr
}
37 changes: 37 additions & 0 deletions internal/postgres/mocks/mock_pg_replication_conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// SPDX-License-Identifier: Apache-2.0

package mocks

import (
"context"

"github.com/xataio/pgstream/internal/postgres"
)

// ReplicationConn is a function-backed test double for a postgres replication
// connection. Every method defined on it forwards to the *Fn field of the same
// name, so tests configure behaviour by assigning those fields.
type ReplicationConn struct {
// IdentifySystemFn is invoked by IdentifySystem.
IdentifySystemFn func(ctx context.Context) (postgres.IdentifySystemResult, error)
// StartReplicationFn is invoked by StartReplication.
StartReplicationFn func(ctx context.Context, cfg postgres.ReplicationConfig) error
// SendStandbyStatusUpdateFn is invoked by SendStandbyStatusUpdate.
SendStandbyStatusUpdateFn func(ctx context.Context, lsn uint64) error
// ReceiveMessageFn is invoked by ReceiveMessage.
ReceiveMessageFn func(ctx context.Context) (*postgres.ReplicationMessage, error)
// CloseFn is invoked by Close.
CloseFn func(ctx context.Context) error
}

// IdentifySystem forwards ctx to IdentifySystemFn and returns its result and
// error unchanged. IdentifySystemFn must be set; calling a nil func panics.
func (m *ReplicationConn) IdentifySystem(ctx context.Context) (postgres.IdentifySystemResult, error) {
	res, err := m.IdentifySystemFn(ctx)
	return res, err
}

// StartReplication forwards ctx and cfg to StartReplicationFn and returns its
// error. StartReplicationFn must be set; calling a nil func panics.
func (m *ReplicationConn) StartReplication(ctx context.Context, cfg postgres.ReplicationConfig) error {
	startErr := m.StartReplicationFn(ctx, cfg)
	return startErr
}

// SendStandbyStatusUpdate forwards ctx and lsn to SendStandbyStatusUpdateFn
// and returns its error. SendStandbyStatusUpdateFn must be set; calling a nil
// func panics.
func (m *ReplicationConn) SendStandbyStatusUpdate(ctx context.Context, lsn uint64) error {
	updateErr := m.SendStandbyStatusUpdateFn(ctx, lsn)
	return updateErr
}

// ReceiveMessage forwards ctx to ReceiveMessageFn and returns its message and
// error unchanged. ReceiveMessageFn must be set; calling a nil func panics.
func (m *ReplicationConn) ReceiveMessage(ctx context.Context) (*postgres.ReplicationMessage, error) {
	msg, err := m.ReceiveMessageFn(ctx)
	return msg, err
}

// Close forwards ctx to CloseFn and returns its error. CloseFn must be set;
// calling a nil func panics.
func (m *ReplicationConn) Close(ctx context.Context) error {
	closeErr := m.CloseFn(ctx)
	return closeErr
}
40 changes: 40 additions & 0 deletions internal/postgres/pg_conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// SPDX-License-Identifier: Apache-2.0

package postgres

import (
"context"
"fmt"

"github.com/jackc/pgx/v5"
)

// Conn wraps a single *pgx.Conn. Its methods delegate to that connection, and
// Close additionally passes the resulting error through mapError.
type Conn struct {
// conn is the underlying pgx connection, set by NewConn.
conn *pgx.Conn
}

// Row is the value returned by Conn.QueryRow. It embeds pgx.Row and declares
// no methods of its own, so its method set is exactly that of pgx.Row.
type Row interface {
pgx.Row
}

// NewConn parses url as a pgx connection string and opens a connection with
// it, using ctx for the connect call.
//
// Both the parse error and the connect error are passed through mapError and
// wrapped with %w before being returned; on any error the returned *Conn is
// nil.
func NewConn(ctx context.Context, url string) (*Conn, error) {
	cfg, parseErr := pgx.ParseConfig(url)
	if parseErr != nil {
		return nil, fmt.Errorf("failed parsing postgres connection string: %w", mapError(parseErr))
	}

	pgxConn, connErr := pgx.ConnectConfig(ctx, cfg)
	if connErr != nil {
		return nil, fmt.Errorf("failed to connect to postgres: %w", mapError(connErr))
	}

	wrapped := &Conn{conn: pgxConn}
	return wrapped, nil
}

// QueryRow delegates to the underlying pgx connection's QueryRow and returns
// the resulting row as a Row.
func (c *Conn) QueryRow(ctx context.Context, query string, args ...any) Row {
	var row Row = c.conn.QueryRow(ctx, query, args...)
	return row
}

// Close closes the underlying pgx connection and returns the resulting error
// after passing it through mapError.
func (c *Conn) Close(ctx context.Context) error {
	closeErr := c.conn.Close(ctx)
	return mapError(closeErr)
}
138 changes: 138 additions & 0 deletions internal/postgres/pg_replication_conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// SPDX-License-Identifier: Apache-2.0

package postgres

import (
"context"
"errors"
"fmt"
"time"

"github.com/jackc/pglogrepl"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgproto3"
)

// ReplicationConn wraps a *pgconn.PgConn that NewReplicationConn opens with
// the "replication" runtime parameter set to "database". Its methods drive the
// replication commands through pglogrepl.
type ReplicationConn struct {
// conn is the underlying low-level connection, set by NewReplicationConn.
conn *pgconn.PgConn
}

// ReplicationConfig carries the arguments ReplicationConn.StartReplication
// passes to pglogrepl.StartReplication.
type ReplicationConfig struct {
// SlotName is passed as the slot name argument.
SlotName string
// StartPos is converted to a pglogrepl.LSN and passed as the start position.
StartPos uint64
// PluginArguments is passed as StartReplicationOptions.PluginArgs.
PluginArguments []string
}

// ReplicationMessage is the decoded form of a CopyData message returned by
// ReplicationConn.ReceiveMessage. It is built from either a primary keepalive
// or an XLogData payload.
type ReplicationMessage struct {
// LSN is ServerWALEnd for keepalives, and WALStart plus the length of
// WALData for XLogData messages.
LSN uint64
// ServerTime is the ServerTime field of the parsed message.
ServerTime time.Time
// WALData is the XLogData payload; it is left nil for keepalives.
WALData []byte
// ReplyRequested is copied from keepalives; it is left false for XLogData
// messages.
ReplyRequested bool
}

// IdentifySystemResult is a package-local named type sharing its underlying
// type with pglogrepl.IdentifySystemResult; ReplicationConn.IdentifySystem
// converts the pglogrepl value to it before returning.
type IdentifySystemResult pglogrepl.IdentifySystemResult

// ErrUnsupportedCopyDataMessage is returned (wrapped) by
// ReplicationConn.ReceiveMessage for CopyData payloads it cannot handle, in
// particular when the leading byte is neither
// pglogrepl.PrimaryKeepaliveMessageByteID nor pglogrepl.XLogDataByteID.
// Compare against it with errors.Is.
var ErrUnsupportedCopyDataMessage = errors.New("unsupported copy data message")

// NewReplicationConn parses url as a pgx connection string, sets the
// "replication" runtime parameter to "database" and opens a low-level pgconn
// connection with the resulting configuration.
//
// ctx is used for the connect call, so cancelling it or letting its deadline
// expire aborts the connection attempt.
//
// A parse error is wrapped as is; a connect error is passed through mapError
// and then wrapped. On any error the returned *ReplicationConn is nil.
func NewReplicationConn(ctx context.Context, url string) (*ReplicationConn, error) {
	pgCfg, err := pgx.ParseConfig(url)
	if err != nil {
		return nil, fmt.Errorf("failed parsing postgres connection string: %w", err)
	}

	pgCfg.RuntimeParams["replication"] = "database"

	// Use the caller's ctx rather than context.Background() so that
	// cancellation and deadlines apply to the connection attempt.
	conn, err := pgconn.ConnectConfig(ctx, &pgCfg.Config)
	if err != nil {
		return nil, fmt.Errorf("create postgres replication client: %w", mapError(err))
	}

	return &ReplicationConn{
		conn: conn,
	}, nil
}

// IdentifySystem calls pglogrepl.IdentifySystem on the underlying connection.
// The result is converted to IdentifySystemResult and the error is passed
// through mapError; both are returned regardless of whether the call failed.
func (c *ReplicationConn) IdentifySystem(ctx context.Context) (IdentifySystemResult, error) {
	sysInfo, identifyErr := pglogrepl.IdentifySystem(ctx, c.conn)
	result := IdentifySystemResult(sysInfo)
	return result, mapError(identifyErr)
}

// StartReplication calls pglogrepl.StartReplication on the underlying
// connection using the slot name, start position and plugin arguments from
// cfg. The resulting error is passed through mapError.
func (c *ReplicationConn) StartReplication(ctx context.Context, cfg ReplicationConfig) error {
	startLSN := pglogrepl.LSN(cfg.StartPos)
	opts := pglogrepl.StartReplicationOptions{PluginArgs: cfg.PluginArguments}

	startErr := pglogrepl.StartReplication(ctx, c.conn, cfg.SlotName, startLSN, opts)
	return mapError(startErr)
}

// SendStandbyStatusUpdate sends a pglogrepl standby status update whose
// WALWritePosition is lsn; no other field of the update is set here. The
// resulting error is passed through mapError.
func (c *ReplicationConn) SendStandbyStatusUpdate(ctx context.Context, lsn uint64) error {
	status := pglogrepl.StandbyStatusUpdate{WALWritePosition: pglogrepl.LSN(lsn)}

	sendErr := pglogrepl.SendStandbyStatusUpdate(ctx, c.conn, status)
	return mapError(sendErr)
}

// ReceiveMessage reads the next message from the underlying connection and
// decodes it into a ReplicationMessage.
//
// Handling by message type:
//   - CopyData carrying a primary keepalive: LSN is the server WAL end, and
//     ServerTime and ReplyRequested are copied from the keepalive.
//   - CopyData carrying XLogData: LSN is WALStart plus the length of the WAL
//     data, and ServerTime and WALData are copied from the payload.
//   - CopyData with an empty payload or any other leading byte: an error
//     wrapping ErrUnsupportedCopyDataMessage.
//   - NoticeResponse: an *Error built by parseErrNoticeResponse.
//   - anything else: an "unexpected message" error.
//
// Errors from the underlying receive call are passed through mapError. The
// returned message is nil whenever the error is non-nil.
func (c *ReplicationConn) ReceiveMessage(ctx context.Context) (*ReplicationMessage, error) {
	msg, err := c.conn.ReceiveMessage(ctx)
	if err != nil {
		return nil, mapError(err)
	}

	switch msg := msg.(type) {
	case *pgproto3.CopyData:
		// The first byte identifies the payload kind; an empty payload has
		// none, and indexing it would panic.
		if len(msg.Data) == 0 {
			return nil, fmt.Errorf("empty copy data payload: %w", ErrUnsupportedCopyDataMessage)
		}

		switch msg.Data[0] {
		case pglogrepl.PrimaryKeepaliveMessageByteID:
			pka, err := pglogrepl.ParsePrimaryKeepaliveMessage(msg.Data[1:])
			if err != nil {
				return nil, fmt.Errorf("parse keep alive: %w", err)
			}
			return &ReplicationMessage{
				LSN:            uint64(pka.ServerWALEnd),
				ServerTime:     pka.ServerTime,
				ReplyRequested: pka.ReplyRequested,
			}, nil
		case pglogrepl.XLogDataByteID:
			xld, err := pglogrepl.ParseXLogData(msg.Data[1:])
			if err != nil {
				return nil, fmt.Errorf("parse xlog data: %w", err)
			}

			return &ReplicationMessage{
				// Report the position just past this chunk of WAL data.
				LSN:        uint64(xld.WALStart) + uint64(len(xld.WALData)),
				ServerTime: xld.ServerTime,
				WALData:    xld.WALData,
			}, nil
		default:
			return nil, fmt.Errorf("%v: %w", msg.Data[0], ErrUnsupportedCopyDataMessage)
		}
	case *pgproto3.NoticeResponse:
		return nil, parseErrNoticeResponse(msg)
	default:
		// unexpected message (WAL error?)
		return nil, fmt.Errorf("unexpected message: %#v", msg)
	}
}

// Close closes the underlying pgconn connection and returns the resulting
// error after passing it through mapError.
func (c *ReplicationConn) Close(ctx context.Context) error {
	closeErr := c.conn.Close(ctx)
	return mapError(closeErr)
}

// Error is the error type produced by parseErrNoticeResponse for a postgres
// notice response received on a replication connection.
type Error struct {
	// Severity is the severity field of the notice response.
	Severity string
	// Msg is the formatted description of the notice response.
	Msg string
}

// Error implements the error interface. The result is Msg prefixed with
// "replication error: "; Severity is not part of the string.
func (e *Error) Error() string {
	const format = "replication error: %s"
	return fmt.Sprintf(format, e.Msg)
}

// parseErrNoticeResponse converts a postgres notice response into an *Error.
// Severity is copied as is, and Msg is a single-line summary of the severity,
// code, message, detail, schema, table and column fields.
func parseErrNoticeResponse(errMsg *pgproto3.NoticeResponse) error {
	summary := fmt.Sprintf(
		"replication notice response: severity: %s, code: %s, message: %s, detail: %s, schemaName: %s, tableName: %s, columnName: %s",
		errMsg.Severity,
		errMsg.Code,
		errMsg.Message,
		errMsg.Detail,
		errMsg.SchemaName,
		errMsg.TableName,
		errMsg.ColumnName,
	)

	return &Error{Severity: errMsg.Severity, Msg: summary}
}
30 changes: 11 additions & 19 deletions pkg/wal/listener/postgres/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,25 @@ func newMockReplicationHandler() *replicationmocks.Handler {
StartReplicationFn: func(context.Context) error { return nil },
GetLSNParserFn: func() replication.LSNParser { return newMockLSNParser() },
SyncLSNFn: func(ctx context.Context, lsn replication.LSN) error { return nil },
ReceiveMessageFn: func(ctx context.Context, i uint64) (replication.Message, error) {
ReceiveMessageFn: func(ctx context.Context, i uint64) (*replication.Message, error) {
return newMockMessage(), nil
},
}
}

func newMockMessage() *replicationmocks.Message {
return &replicationmocks.Message{
GetDataFn: func() *replication.MessageData {
return &replication.MessageData{
LSN: testLSN,
Data: []byte("test-data"),
ReplyRequested: false,
ServerTime: time.Now(),
}
},
func newMockMessage() *replication.Message {
return &replication.Message{
LSN: testLSN,
Data: []byte("test-data"),
ReplyRequested: false,
ServerTime: time.Now(),
}
}

func newMockKeepAliveMessage(replyRequested bool) *replicationmocks.Message {
return &replicationmocks.Message{
GetDataFn: func() *replication.MessageData {
return &replication.MessageData{
LSN: testLSN,
ReplyRequested: replyRequested,
}
},
func newMockKeepAliveMessage(replyRequested bool) *replication.Message {
return &replication.Message{
LSN: testLSN,
ReplyRequested: replyRequested,
}
}

Expand Down
26 changes: 12 additions & 14 deletions pkg/wal/listener/postgres/wal_pg_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type Listener struct {

type replicationHandler interface {
StartReplication(ctx context.Context) error
ReceiveMessage(ctx context.Context) (replication.Message, error)
ReceiveMessage(ctx context.Context) (*replication.Message, error)
GetLSNParser() replication.LSNParser
Close() error
}
Expand Down Expand Up @@ -89,46 +89,44 @@ func (l *Listener) listen(ctx context.Context) error {
default:
msg, err := l.replicationHandler.ReceiveMessage(ctx)
if err != nil {
replErr := &replication.Error{}
if errors.Is(err, replication.ErrConnTimeout) || (errors.As(err, &replErr) && replErr.Severity == "WARNING") {
if errors.Is(err, replication.ErrConnTimeout) {
continue
}
return fmt.Errorf("receiving message: %w", err)
}

msgData := msg.GetData()
if msgData == nil {
if msg == nil {
continue
}

l.logger.Trace("", loglib.Fields{
"wal_end": l.lsnParser.ToString(msgData.LSN),
"server_time": msgData.ServerTime,
"wal_data": msgData.Data,
"wal_end": l.lsnParser.ToString(msg.LSN),
"server_time": msg.ServerTime,
"wal_data": msg.Data,
})

if err := l.processWALEvent(ctx, msgData); err != nil {
if err := l.processWALEvent(ctx, msg); err != nil {
return err
}
}
}
}

func (l *Listener) processWALEvent(ctx context.Context, msgData *replication.MessageData) error {
func (l *Listener) processWALEvent(ctx context.Context, msg *replication.Message) error {
// if there's no data, it's a keep alive. If a reply is not requested,
// no need to process this message.
if msgData.Data == nil && !msgData.ReplyRequested {
if msg.Data == nil && !msg.ReplyRequested {
return nil
}

event := &wal.Event{}
if msgData.Data != nil {
if msg.Data != nil {
event.Data = &wal.Data{}
if err := l.walDataDeserialiser(msgData.Data, event.Data); err != nil {
if err := l.walDataDeserialiser(msg.Data, event.Data); err != nil {
return fmt.Errorf("error unmarshaling wal data: %w", err)
}
}
event.CommitPosition = wal.CommitPosition(l.lsnParser.ToString(msgData.LSN))
event.CommitPosition = wal.CommitPosition(l.lsnParser.ToString(msg.LSN))

return l.processEvent(ctx, event)
}
Loading

0 comments on commit f5817f5

Please sign in to comment.