From 0c3e59b07a882e25a896f1b0e75a84461cc889cf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 20:29:54 -0500 Subject: [PATCH] Add automatic statement cache --- bench_test.go | 72 ++++++++++++++++++++++++++ conn.go | 136 +++++++++++++++++++++++++++++++++++++++++++------- conn_test.go | 92 ++++++++++++++++++++++++++++++++-- doc.go | 6 +++ go.mod | 4 +- go.sum | 8 +++ query_test.go | 53 ++++++++++++++++++++ 7 files changed, 349 insertions(+), 22 deletions(-) diff --git a/bench_test.go b/bench_test.go index a85de09da..4665301b4 100644 --- a/bench_test.go +++ b/bench_test.go @@ -9,10 +9,82 @@ import ( "testing" "time" + "github.com/jackc/pgconn" + "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" ) +func BenchmarkMinimalUnpreparedSelectWithoutStatementCache(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.BuildPreparedStatementCache = nil + + conn := mustConnect(b, config) + defer closeConn(b, conn) + + var n int64 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := conn.QueryRow(context.Background(), "select $1::int8", i).Scan(&n) + if err != nil { + b.Fatal(err) + } + + if n != int64(i) { + b.Fatalf("expected %d, got %d", i, n) + } + } +} + +func BenchmarkMinimalUnpreparedSelectWithStatementCacheModeDescribe(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.BuildPreparedStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, stmtcache.ModeDescribe, 32) + } + + conn := mustConnect(b, config) + defer closeConn(b, conn) + + var n int64 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := conn.QueryRow(context.Background(), "select $1::int8", i).Scan(&n) + if err != nil { + b.Fatal(err) + } + + if n != int64(i) { + b.Fatalf("expected %d, got %d", i, n) + } + } +} + +func BenchmarkMinimalUnpreparedSelectWithStatementCacheModePrepare(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.BuildPreparedStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, stmtcache.ModePrepare, 32) + } + + conn := mustConnect(b, config) + defer closeConn(b, conn) + + var n int64 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := conn.QueryRow(context.Background(), "select $1::int8", i).Scan(&n) + if err != nil { + b.Fatal(err) + } + + if n != int64(i) { + b.Fatalf("expected %d, got %d", i, n) + } + } +} + func BenchmarkMinimalPreparedSelect(b *testing.B) { conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) diff --git a/conn.go b/conn.go index c6bd55296..bf19c3d13 100644 --- a/conn.go +++ b/conn.go @@ -2,12 +2,14 @@ package pgx import ( "context" + "strconv" "strings" "time" errors "golang.org/x/xerrors" "github.com/jackc/pgconn" + "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgproto3/v2" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4/internal/sanitize" @@ -27,6 +29,10 @@ type ConnConfig struct { Logger Logger LogLevel LogLevel + // BuildPreparedStatementCache creates the stmtcache.Cache implementation for connections created with this config. Set + // to nil to disable automatic prepared statements. + BuildPreparedStatementCache BuildPreparedStatementCacheFunc + // PreferSimpleProtocol disables implicit prepared statement usage. By default pgx automatically uses the extended // protocol. This can improve performance due to being able to use the binary format. It also does not rely on client // side parameter sanitization. However, it does incur two round-trips per query (unless using a prepared statement) @@ -38,12 +44,16 @@ type ConnConfig struct { createdByParseConfig bool // Used to enforce created by ParseConfig rule. } +// BuildPreparedStatementCacheFunc is a function that can be used to create a stmtcache.Cache implementation for connection. +type BuildPreparedStatementCacheFunc func(conn *pgconn.PgConn) stmtcache.Cache + // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access // to multiple database connections from multiple goroutines. type Conn struct { pgConn *pgconn.PgConn config *ConnConfig // config used when establishing this connection preparedStatements map[string]*pgconn.PreparedStatementDescription + stmtcache stmtcache.Cache logger Logger logLevel LogLevel @@ -111,16 +121,58 @@ func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { return connect(ctx, connConfig) } -// ParseConfig creates a ConnConfig from a connection string. See pgconn.ParseConfig for details. +// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig +// does. In addition, it accepts the following options: +// +// statement_cache_capacity +// The maximum size of the automatic statement cache. Set to 0 to disable automatic statement caching. Default: 512. +// +// statement_cache_mode +// Possible values: "prepare" and "describe". "prepare" will create prepared statements on the PostgreSQL server. +// "describe" will use the anonymous prepared statement to describe a statement without creating a statement on the +// server. "describe" is primarily useful when the environment does not allow prepared statements such as when +// running a connection pooler like PgBouncer. Default: "prepare" func ParseConfig(connString string) (*ConnConfig, error) { config, err := pgconn.ParseConfig(connString) if err != nil { return nil, err } + + var buildPreparedStatementCache BuildPreparedStatementCacheFunc + statementCacheCapacity := 512 + statementCacheMode := stmtcache.ModePrepare + if s, ok := config.RuntimeParams["statement_cache_capacity"]; ok { + delete(config.RuntimeParams, "statement_cache_capacity") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, errors.Errorf("cannot parse statement_cache_capacity: %w", err) + } + statementCacheCapacity = int(n) + } + + if s, ok := config.RuntimeParams["statement_cache_mode"]; ok { + delete(config.RuntimeParams, "statement_cache_mode") + switch s { + case "prepare": + statementCacheMode = stmtcache.ModePrepare + case "describe": + statementCacheMode = stmtcache.ModeDescribe + default: + return nil, errors.Errorf("invalid statement_cache_mod: %s", s) + } + } + + if statementCacheCapacity > 0 { + buildPreparedStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, statementCacheMode, statementCacheCapacity) + } + } + connConfig := &ConnConfig{ - Config: *config, - createdByParseConfig: true, - LogLevel: LogLevelInfo, + Config: *config, + createdByParseConfig: true, + LogLevel: LogLevelInfo, + BuildPreparedStatementCache: buildPreparedStatementCache, } return connConfig, nil @@ -165,6 +217,10 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { c.closedChan = make(chan error) c.wbuf = make([]byte, 0, 1024) + if c.config.BuildPreparedStatementCache != nil { + c.stmtcache = c.config.BuildPreparedStatementCache(c.pgConn) + } + // Replication connections can't execute the queries to // populate the c.PgTypes and c.pgsqlAfInet if _, ok := c.pgConn.Config.RuntimeParams["replication"]; ok { @@ -372,6 +428,9 @@ func connInfoFromRows(rows Rows, err error) (map[string]uint32, error) { // is used and the connection must be returned to the same state before any *pgx.Conn methods are again used. func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn } +// StmtCache returns the statement cache used for this connection. +func (c *Conn) StmtCache() stmtcache.Cache { return c.stmtcache } + // Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced // positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { @@ -419,6 +478,18 @@ optionLoop: return c.execSimpleProtocol(ctx, sql, arguments) } + if c.stmtcache != nil { + ps, err := c.stmtcache.Get(ctx, sql) + if err != nil { + return nil, err + } + + if c.stmtcache.Mode() == stmtcache.ModeDescribe { + return c.execParams(ctx, ps, arguments) + } + return c.execPrepared(ctx, ps, arguments) + } + ps, err := c.Prepare(ctx, "", sql) if err != nil { return nil, err @@ -442,18 +513,18 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []i return commandTag, err } -func (c *Conn) execPrepared(ctx context.Context, ps *pgconn.PreparedStatementDescription, arguments []interface{}) (commandTag pgconn.CommandTag, err error) { +func (c *Conn) execParamsAndPreparedPrefix(ps *pgconn.PreparedStatementDescription, arguments []interface{}) error { c.eqb.Reset() args, err := convertDriverValuers(arguments) if err != nil { - return nil, err + return err } for i := range args { err = c.eqb.AppendParam(c.ConnInfo, ps.ParamOIDs[i], args[i]) if err != nil { - return nil, err + return err } } @@ -467,6 +538,25 @@ func (c *Conn) execPrepared(ctx context.Context, ps *pgconn.PreparedStatementDes } } + return nil +} + +func (c *Conn) execParams(ctx context.Context, ps *pgconn.PreparedStatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { + err := c.execParamsAndPreparedPrefix(ps, arguments) + if err != nil { + return nil, err + } + + result := c.pgConn.ExecParams(ctx, ps.SQL, c.eqb.paramValues, ps.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read() + return result.CommandTag, result.Err +} + +func (c *Conn) execPrepared(ctx context.Context, ps *pgconn.PreparedStatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { + err := c.execParamsAndPreparedPrefix(ps, arguments) + if err != nil { + return nil, err + } + result := c.pgConn.ExecPrepared(ctx, ps.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() return result.CommandTag, result.Err } @@ -549,17 +639,25 @@ optionLoop: ps, ok := c.preparedStatements[sql] if !ok { - ps, err = c.pgConn.Prepare(ctx, "", sql, nil) - if err != nil { - rows.fatal(err) - return rows, rows.err - } - - if len(ps.ParamOIDs) != len(args) { - rows.fatal(errors.Errorf("expected %d arguments, got %d", len(ps.ParamOIDs), len(args))) - return rows, rows.err + if c.stmtcache != nil { + ps, err = c.stmtcache.Get(ctx, sql) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + } else { + ps, err = c.pgConn.Prepare(ctx, "", sql, nil) + if err != nil { + rows.fatal(err) + return rows, rows.err + } } } + if len(ps.ParamOIDs) != len(args) { + rows.fatal(errors.Errorf("expected %d arguments, got %d", len(ps.ParamOIDs), len(args))) + return rows, rows.err + } + rows.sql = ps.SQL args, err = convertDriverValuers(args) @@ -597,7 +695,11 @@ optionLoop: resultFormats = c.eqb.resultFormats } - rows.resultReader = c.pgConn.ExecPrepared(ctx, ps.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) + if c.stmtcache != nil && c.stmtcache.Mode() == stmtcache.ModeDescribe { + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, ps.ParamOIDs, c.eqb.paramFormats, resultFormats) + } else { + rows.resultReader = c.pgConn.ExecPrepared(ctx, ps.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) + } return rows, rows.err } diff --git a/conn_test.go b/conn_test.go index f1128416b..f4c453509 100644 --- a/conn_test.go +++ b/conn_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/jackc/pgconn" + "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" @@ -105,6 +106,38 @@ func TestConnectConfigRequiresConnConfigFromParseConfig(t *testing.T) { require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgx.ConnectConfig(context.Background(), config) }) } +func TestParseConfigExtractsStatementCacheOptions(t *testing.T) { + t.Parallel() + + config, err := pgx.ParseConfig("statement_cache_capacity=0") + require.NoError(t, err) + require.Nil(t, config.BuildPreparedStatementCache) + + config, err = pgx.ParseConfig("statement_cache_capacity=42") + require.NoError(t, err) + require.NotNil(t, config.BuildPreparedStatementCache) + c := config.BuildPreparedStatementCache(nil) + require.NotNil(t, c) + require.Equal(t, 42, c.Cap()) + require.Equal(t, stmtcache.ModePrepare, c.Mode()) + + config, err = pgx.ParseConfig("statement_cache_capacity=42 statement_cache_mode=prepare") + require.NoError(t, err) + require.NotNil(t, config.BuildPreparedStatementCache) + c = config.BuildPreparedStatementCache(nil) + require.NotNil(t, c) + require.Equal(t, 42, c.Cap()) + require.Equal(t, stmtcache.ModePrepare, c.Mode()) + + config, err = pgx.ParseConfig("statement_cache_capacity=42 statement_cache_mode=describe") + require.NoError(t, err) + require.NotNil(t, config.BuildPreparedStatementCache) + c = config.BuildPreparedStatementCache(nil) + require.NotNil(t, c) + require.Equal(t, 42, c.Cap()) + require.Equal(t, stmtcache.ModeDescribe, c.Mode()) +} + func TestExec(t *testing.T) { t.Parallel() @@ -285,6 +318,56 @@ func TestExecExtendedProtocol(t *testing.T) { ensureConnValid(t, conn) } +func TestExecPreparedStatementCacheModes(t *testing.T) { + t.Parallel() + + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + + tests := []struct { + name string + buildPreparedStatementCache pgx.BuildPreparedStatementCacheFunc + }{ + { + name: "disabled", + buildPreparedStatementCache: nil, + }, + { + name: "prepare", + buildPreparedStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, stmtcache.ModePrepare, 32) + }, + }, + { + name: "describe", + buildPreparedStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, stmtcache.ModeDescribe, 32) + }, + }, + } + + for _, tt := range tests { + func() { + config.BuildPreparedStatementCache = tt.buildPreparedStatementCache + conn := mustConnect(t, config) + defer closeConn(t, conn) + + commandTag, err := conn.Exec(context.Background(), "select 1") + assert.NoError(t, err, tt.name) + assert.Equal(t, "SELECT 1", string(commandTag), tt.name) + + commandTag, err = conn.Exec(context.Background(), "select 1 union all select 1") + assert.NoError(t, err, tt.name) + assert.Equal(t, "SELECT 2", string(commandTag), tt.name) + + commandTag, err = conn.Exec(context.Background(), "select 1") + assert.NoError(t, err, tt.name) + assert.Equal(t, "SELECT 1", string(commandTag), tt.name) + + ensureConnValid(t, conn) + }() + } +} + func TestExecSimpleProtocol(t *testing.T) { t.Parallel() @@ -652,9 +735,12 @@ func TestCatchSimultaneousConnectionQueries(t *testing.T) { } defer rows1.Close() - _, err = conn.Query(context.Background(), "select generate_series(1,$1)", 10) - if !errors.Is(err, pgconn.ErrConnBusy) { - t.Fatalf("conn.Query should have failed with pgconn.ErrConnBusy, but it was %v", err) + rows2, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) + require.NoError(t, err) + require.NotNil(t, rows2) + require.False(t, rows2.Next()) + if !errors.Is(rows2.Err(), pgconn.ErrConnBusy) { + t.Fatalf("conn.Query should have failed with pgconn.ErrConnBusy, but it was %v", rows2.Err()) } } diff --git a/doc.go b/doc.go index b22c870ae..ba2ed2bc7 100644 --- a/doc.go +++ b/doc.go @@ -179,6 +179,12 @@ can create a transaction with a specified isolation level. return err } +Prepared Statements + +Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx +includes an automatic statement cache by default. Queries run through the normal Query, QueryRow, and Exec functions are +automatically prepared on first execution and the prepared statement is reused on subsequent executions. + Copy Protocol Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL diff --git a/go.mod b/go.mod index 9ee7e1724..39c738e02 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/cockroachdb/apd v1.1.0 github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f // indirect github.com/go-stack/stack v1.8.0 // indirect - github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb + github.com/jackc/pgconn v0.0.0-20190825004843-78abbdf1d7ee github.com/jackc/pgio v1.0.0 github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90 @@ -28,7 +28,7 @@ require ( golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7 // indirect golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a // indirect golang.org/x/text v0.3.2 // indirect - golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f // indirect + golang.org/x/tools v0.0.0-20190824210100-c2567a220953 // indirect golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec ) diff --git a/go.sum b/go.sum index 4b937b9e4..ff02f7237 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,12 @@ github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3 h1:ZFYpB74Kq8xE9gmfxC github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb h1:d6GP9szHvXVopAOAnZ7WhRnF3Xdxrylmm/9jnfmW4Ag= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190824212754-2209d2e36aea h1:FwCceMjr3vnfVyl2EG3F0TKILOVs0ly8Z8EbXe72WAE= +github.com/jackc/pgconn v0.0.0-20190824212754-2209d2e36aea/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190824221829-beba629bb5d5 h1:fGy7MTsuLbREyDs7o1m03cGEgwMrKyUP488Z9zlmR/k= +github.com/jackc/pgconn v0.0.0-20190824221829-beba629bb5d5/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190825004843-78abbdf1d7ee h1:uHUd7Cnu7QjzOqOWj6MYqz8zvNGoDZG1tK6jQASP2j0= +github.com/jackc/pgconn v0.0.0-20190825004843-78abbdf1d7ee/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -50,6 +56,7 @@ github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0 h1:/5u4a+KGJptBRqGzPvYQL9p0d/tPR4S31+Tnzj9lEO4= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-colorable v0.1.1 h1:G1f5SKeVxmagw/IyvzvtZE4Gybcc4Tr1tf7I8z0XgOg= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= @@ -119,6 +126,7 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20190824210100-c2567a220953/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= diff --git a/query_test.go b/query_test.go index 91cc744ab..0ced94632 100644 --- a/query_test.go +++ b/query_test.go @@ -13,11 +13,13 @@ import ( "github.com/cockroachdb/apd" "github.com/jackc/pgconn" + "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgtype" satori "github.com/jackc/pgtype/ext/satori-uuid" "github.com/jackc/pgx/v4" uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" errors "golang.org/x/xerrors" ) @@ -1563,3 +1565,54 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { ensureConnValid(t, conn) } + +func TestQueryPreparedStatementCacheModes(t *testing.T) { + t.Parallel() + + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + + tests := []struct { + name string + buildPreparedStatementCache pgx.BuildPreparedStatementCacheFunc + }{ + { + name: "disabled", + buildPreparedStatementCache: nil, + }, + { + name: "prepare", + buildPreparedStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, stmtcache.ModePrepare, 32) + }, + }, + { + name: "describe", + buildPreparedStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, stmtcache.ModeDescribe, 32) + }, + }, + } + + for _, tt := range tests { + func() { + config.BuildPreparedStatementCache = tt.buildPreparedStatementCache + conn := mustConnect(t, config) + defer closeConn(t, conn) + + var n int + err := conn.QueryRow(context.Background(), "select 1").Scan(&n) + assert.NoError(t, err, tt.name) + assert.Equal(t, 1, n, tt.name) + + err = conn.QueryRow(context.Background(), "select 2").Scan(&n) + assert.NoError(t, err, tt.name) + assert.Equal(t, 2, n, tt.name) + + err = conn.QueryRow(context.Background(), "select 1").Scan(&n) + assert.NoError(t, err, tt.name) + assert.Equal(t, 1, n, tt.name) + + ensureConnValid(t, conn) + }() + } +}