Skip to content

Commit

Permalink
Add automatic statement cache
Browse files Browse the repository at this point in the history
  • Loading branch information
jackc committed Aug 25, 2019
1 parent 180dfe6 commit 0c3e59b
Show file tree
Hide file tree
Showing 7 changed files with 349 additions and 22 deletions.
72 changes: 72 additions & 0 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,82 @@ import (
"testing"
"time"

"github.com/jackc/pgconn"
"github.com/jackc/pgconn/stmtcache"
"github.com/jackc/pgtype"
"github.com/jackc/pgx/v4"
)

func BenchmarkMinimalUnpreparedSelectWithoutStatementCache(b *testing.B) {
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
config.BuildPreparedStatementCache = nil

conn := mustConnect(b, config)
defer closeConn(b, conn)

var n int64

b.ResetTimer()
for i := 0; i < b.N; i++ {
err := conn.QueryRow(context.Background(), "select $1::int8", i).Scan(&n)
if err != nil {
b.Fatal(err)
}

if n != int64(i) {
b.Fatalf("expected %d, got %d", i, n)
}
}
}

func BenchmarkMinimalUnpreparedSelectWithStatementCacheModeDescribe(b *testing.B) {
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
config.BuildPreparedStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
}

conn := mustConnect(b, config)
defer closeConn(b, conn)

var n int64

b.ResetTimer()
for i := 0; i < b.N; i++ {
err := conn.QueryRow(context.Background(), "select $1::int8", i).Scan(&n)
if err != nil {
b.Fatal(err)
}

if n != int64(i) {
b.Fatalf("expected %d, got %d", i, n)
}
}
}

func BenchmarkMinimalUnpreparedSelectWithStatementCacheModePrepare(b *testing.B) {
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
config.BuildPreparedStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
return stmtcache.New(conn, stmtcache.ModePrepare, 32)
}

conn := mustConnect(b, config)
defer closeConn(b, conn)

var n int64

b.ResetTimer()
for i := 0; i < b.N; i++ {
err := conn.QueryRow(context.Background(), "select $1::int8", i).Scan(&n)
if err != nil {
b.Fatal(err)
}

if n != int64(i) {
b.Fatalf("expected %d, got %d", i, n)
}
}
}

func BenchmarkMinimalPreparedSelect(b *testing.B) {
conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
defer closeConn(b, conn)
Expand Down
136 changes: 119 additions & 17 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package pgx

import (
"context"
"strconv"
"strings"
"time"

errors "golang.org/x/xerrors"

"github.com/jackc/pgconn"
"github.com/jackc/pgconn/stmtcache"
"github.com/jackc/pgproto3/v2"
"github.com/jackc/pgtype"
"github.com/jackc/pgx/v4/internal/sanitize"
Expand All @@ -27,6 +29,10 @@ type ConnConfig struct {
Logger Logger
LogLevel LogLevel

// BuildPreparedStatementCache creates the stmtcache.Cache implementation for connections created with this config. Set
// to nil to disable automatic prepared statements.
BuildPreparedStatementCache BuildPreparedStatementCacheFunc

// PreferSimpleProtocol disables implicit prepared statement usage. By default pgx automatically uses the extended
// protocol. This can improve performance due to being able to use the binary format. It also does not rely on client
// side parameter sanitization. However, it does incur two round-trips per query (unless using a prepared statement)
Expand All @@ -38,12 +44,16 @@ type ConnConfig struct {
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}

// BuildPreparedStatementCacheFunc is a function that can be used to create a stmtcache.Cache implementation for connection.
type BuildPreparedStatementCacheFunc func(conn *pgconn.PgConn) stmtcache.Cache

// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access
// to multiple database connections from multiple goroutines.
type Conn struct {
pgConn *pgconn.PgConn
config *ConnConfig // config used when establishing this connection
preparedStatements map[string]*pgconn.PreparedStatementDescription
stmtcache stmtcache.Cache
logger Logger
logLevel LogLevel

Expand Down Expand Up @@ -111,16 +121,58 @@ func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
return connect(ctx, connConfig)
}

// ParseConfig creates a ConnConfig from a connection string. See pgconn.ParseConfig for details.
// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig
// does. In addition, it accepts the following options:
//
// statement_cache_capacity
// The maximum size of the automatic statement cache. Set to 0 to disable automatic statement caching. Default: 512.
//
// statement_cache_mode
// Possible values: "prepare" and "describe". "prepare" will create prepared statements on the PostgreSQL server.
// "describe" will use the anonymous prepared statement to describe a statement without creating a statement on the
// server. "describe" is primarily useful when the environment does not allow prepared statements such as when
// running a connection pooler like PgBouncer. Default: "prepare"
func ParseConfig(connString string) (*ConnConfig, error) {
config, err := pgconn.ParseConfig(connString)
if err != nil {
return nil, err
}

var buildPreparedStatementCache BuildPreparedStatementCacheFunc
statementCacheCapacity := 512
statementCacheMode := stmtcache.ModePrepare
if s, ok := config.RuntimeParams["statement_cache_capacity"]; ok {
delete(config.RuntimeParams, "statement_cache_capacity")
n, err := strconv.ParseInt(s, 10, 32)
if err != nil {
return nil, errors.Errorf("cannot parse statement_cache_capacity: %w", err)
}
statementCacheCapacity = int(n)
}

if s, ok := config.RuntimeParams["statement_cache_mode"]; ok {
delete(config.RuntimeParams, "statement_cache_mode")
switch s {
case "prepare":
statementCacheMode = stmtcache.ModePrepare
case "describe":
statementCacheMode = stmtcache.ModeDescribe
default:
return nil, errors.Errorf("invalid statement_cache_mod: %s", s)
}
}

if statementCacheCapacity > 0 {
buildPreparedStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
return stmtcache.New(conn, statementCacheMode, statementCacheCapacity)
}
}

connConfig := &ConnConfig{
Config: *config,
createdByParseConfig: true,
LogLevel: LogLevelInfo,
Config: *config,
createdByParseConfig: true,
LogLevel: LogLevelInfo,
BuildPreparedStatementCache: buildPreparedStatementCache,
}

return connConfig, nil
Expand Down Expand Up @@ -165,6 +217,10 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
c.closedChan = make(chan error)
c.wbuf = make([]byte, 0, 1024)

if c.config.BuildPreparedStatementCache != nil {
c.stmtcache = c.config.BuildPreparedStatementCache(c.pgConn)
}

// Replication connections can't execute the queries to
// populate the c.PgTypes and c.pgsqlAfInet
if _, ok := c.pgConn.Config.RuntimeParams["replication"]; ok {
Expand Down Expand Up @@ -372,6 +428,9 @@ func connInfoFromRows(rows Rows, err error) (map[string]uint32, error) {
// is used and the connection must be returned to the same state before any *pgx.Conn methods are again used.
func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn }

// StmtCache returns the statement cache used for this connection.
func (c *Conn) StmtCache() stmtcache.Cache { return c.stmtcache }

// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
// positionally from the sql string as $1, $2, etc.
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) {
Expand Down Expand Up @@ -419,6 +478,18 @@ optionLoop:
return c.execSimpleProtocol(ctx, sql, arguments)
}

if c.stmtcache != nil {
ps, err := c.stmtcache.Get(ctx, sql)
if err != nil {
return nil, err
}

if c.stmtcache.Mode() == stmtcache.ModeDescribe {
return c.execParams(ctx, ps, arguments)
}
return c.execPrepared(ctx, ps, arguments)
}

ps, err := c.Prepare(ctx, "", sql)
if err != nil {
return nil, err
Expand All @@ -442,18 +513,18 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []i
return commandTag, err
}

func (c *Conn) execPrepared(ctx context.Context, ps *pgconn.PreparedStatementDescription, arguments []interface{}) (commandTag pgconn.CommandTag, err error) {
func (c *Conn) execParamsAndPreparedPrefix(ps *pgconn.PreparedStatementDescription, arguments []interface{}) error {
c.eqb.Reset()

args, err := convertDriverValuers(arguments)
if err != nil {
return nil, err
return err
}

for i := range args {
err = c.eqb.AppendParam(c.ConnInfo, ps.ParamOIDs[i], args[i])
if err != nil {
return nil, err
return err
}
}

Expand All @@ -467,6 +538,25 @@ func (c *Conn) execPrepared(ctx context.Context, ps *pgconn.PreparedStatementDes
}
}

return nil
}

func (c *Conn) execParams(ctx context.Context, ps *pgconn.PreparedStatementDescription, arguments []interface{}) (pgconn.CommandTag, error) {
err := c.execParamsAndPreparedPrefix(ps, arguments)
if err != nil {
return nil, err
}

result := c.pgConn.ExecParams(ctx, ps.SQL, c.eqb.paramValues, ps.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read()
return result.CommandTag, result.Err
}

func (c *Conn) execPrepared(ctx context.Context, ps *pgconn.PreparedStatementDescription, arguments []interface{}) (pgconn.CommandTag, error) {
err := c.execParamsAndPreparedPrefix(ps, arguments)
if err != nil {
return nil, err
}

result := c.pgConn.ExecPrepared(ctx, ps.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read()
return result.CommandTag, result.Err
}
Expand Down Expand Up @@ -549,17 +639,25 @@ optionLoop:

ps, ok := c.preparedStatements[sql]
if !ok {
ps, err = c.pgConn.Prepare(ctx, "", sql, nil)
if err != nil {
rows.fatal(err)
return rows, rows.err
}

if len(ps.ParamOIDs) != len(args) {
rows.fatal(errors.Errorf("expected %d arguments, got %d", len(ps.ParamOIDs), len(args)))
return rows, rows.err
if c.stmtcache != nil {
ps, err = c.stmtcache.Get(ctx, sql)
if err != nil {
rows.fatal(err)
return rows, rows.err
}
} else {
ps, err = c.pgConn.Prepare(ctx, "", sql, nil)
if err != nil {
rows.fatal(err)
return rows, rows.err
}
}
}
if len(ps.ParamOIDs) != len(args) {
rows.fatal(errors.Errorf("expected %d arguments, got %d", len(ps.ParamOIDs), len(args)))
return rows, rows.err
}

rows.sql = ps.SQL

args, err = convertDriverValuers(args)
Expand Down Expand Up @@ -597,7 +695,11 @@ optionLoop:
resultFormats = c.eqb.resultFormats
}

rows.resultReader = c.pgConn.ExecPrepared(ctx, ps.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats)
if c.stmtcache != nil && c.stmtcache.Mode() == stmtcache.ModeDescribe {
rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, ps.ParamOIDs, c.eqb.paramFormats, resultFormats)
} else {
rows.resultReader = c.pgConn.ExecPrepared(ctx, ps.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats)
}

return rows, rows.err
}
Expand Down
Loading

0 comments on commit 0c3e59b

Please sign in to comment.