Skip to content

Commit

Permalink
Add emit_methods_with_db_argument config option (#1279)
Browse files Browse the repository at this point in the history
Support a new configuration parameter `emit_methods_with_db_argument `that modifies the generated method sets to provide methods like shown in WidgetInserter. If set to true, the generated *Queries omits storing a DBTX as a struct field and requires it be passed in to all method calls. In doing so, it allows callers to easily provide the connection for standalone use or for use as part of a broader transaction and makes it easy for the surrounding code to use a narrowly defined interface.
  • Loading branch information
danielmmetz authored Nov 20, 2021
1 parent 2aa00e2 commit 768ccb6
Show file tree
Hide file tree
Showing 30 changed files with 522 additions and 96 deletions.
3 changes: 3 additions & 0 deletions docs/reference/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ packages:
emit_json_tags: true
emit_result_struct_pointers: false
emit_params_struct_pointers: false
emit_methods_with_db_argument: false
json_tags_case_style: "camel"
output_db_file_name: "db.go"
output_models_file_name: "models.go"
Expand Down Expand Up @@ -57,6 +58,8 @@ Each package document has the following keys:
- If true, query results are returned as pointers to structs. Queries returning multiple results are returned as slices of pointers. Defaults to `false`.
- `emit_params_struct_pointers`:
- If true, parameters are passed as pointers to structs. Defaults to `false`.
- `emit_methods_with_db_argument`:
- If true, generated methods will accept a DBTX argument instead of storing a DBTX on the `*Queries` struct. Defaults to `false`.
- `json_tags_case_style`:
- `camel` for camelCase, `pascal` for PascalCase, `snake` for snake_case or `none` to use the column name in the DB. Defaults to `none`.
- `output_db_file_name`:
Expand Down
5 changes: 5 additions & 0 deletions internal/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer
return nil, err
}

if err := config.Validate(conf); err != nil {
fmt.Fprintf(stderr, "error validating %s: %s\n", base, err)
return nil, err
}

output := map[string]string{}
errored := false

Expand Down
36 changes: 19 additions & 17 deletions internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ type tmplCtx struct {
// TODO: Race conditions
SourceName string

EmitJSONTags bool
EmitDBTags bool
EmitPreparedQueries bool
EmitInterface bool
EmitEmptySlices bool
EmitJSONTags bool
EmitDBTags bool
EmitPreparedQueries bool
EmitInterface bool
EmitEmptySlices bool
EmitMethodsWithDBArgument bool
}

func (t *tmplCtx) OutputQuery(sourceName string) bool {
Expand Down Expand Up @@ -76,18 +77,19 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct,

golang := settings.Go
tctx := tmplCtx{
Settings: settings.Global,
EmitInterface: golang.EmitInterface,
EmitJSONTags: golang.EmitJSONTags,
EmitDBTags: golang.EmitDBTags,
EmitPreparedQueries: golang.EmitPreparedQueries,
EmitEmptySlices: golang.EmitEmptySlices,
SQLPackage: SQLPackageFromString(golang.SQLPackage),
Q: "`",
Package: golang.Package,
GoQueries: queries,
Enums: enums,
Structs: structs,
Settings: settings.Global,
EmitInterface: golang.EmitInterface,
EmitJSONTags: golang.EmitJSONTags,
EmitDBTags: golang.EmitDBTags,
EmitPreparedQueries: golang.EmitPreparedQueries,
EmitEmptySlices: golang.EmitEmptySlices,
EmitMethodsWithDBArgument: golang.EmitMethodsWithDBArgument,
SQLPackage: SQLPackageFromString(golang.SQLPackage),
Q: "`",
Package: golang.Package,
GoQueries: queries,
Enums: enums,
Structs: structs,
}

output := map[string]string{}
Expand Down
11 changes: 10 additions & 1 deletion internal/codegen/golang/templates/pgx/dbCode.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,26 @@ type DBTX interface {
QueryRow(context.Context, string, ...interface{}) pgx.Row
}

{{ if .EmitMethodsWithDBArgument}}
func New() *Queries {
return &Queries{}
{{- else -}}
func New(db DBTX) *Queries {
return &Queries{db: db}
{{- end}}
}

type Queries struct {
{{if not .EmitMethodsWithDBArgument}}
db DBTX
{{end}}
}

{{if not .EmitMethodsWithDBArgument}}
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
return &Queries{
db: tx,
}
}
{{end}}
{{end}}
{{end}}
23 changes: 17 additions & 6 deletions internal/codegen/golang/templates/pgx/interfaceCode.tmpl
Original file line number Diff line number Diff line change
@@ -1,23 +1,34 @@
{{define "interfaceCodePgx"}}
type Querier interface {
{{- $dbtxParam := .EmitMethodsWithDBArgument -}}
{{- range .GoQueries}}
{{- if eq .Cmd ":one"}}
{{- if and (eq .Cmd ":one") ($dbtxParam) }}
{{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error)
{{- else if eq .Cmd ":one" }}
{{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error)
{{- end}}
{{- if eq .Cmd ":many"}}
{{- if and (eq .Cmd ":many") ($dbtxParam) }}
{{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error)
{{- else if eq .Cmd ":many" }}
{{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error)
{{- end}}
{{- if eq .Cmd ":exec"}}
{{- if and (eq .Cmd ":exec") ($dbtxParam) }}
{{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error
{{- else if eq .Cmd ":exec" }}
{{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error
{{- end}}
{{- if eq .Cmd ":execrows"}}
{{- if and (eq .Cmd ":execrows") ($dbtxParam) }}
{{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error)
{{- else if eq .Cmd ":execrows" }}
{{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error)
{{- end}}
{{- if eq .Cmd ":execresult"}}
{{- if and (eq .Cmd ":execresult") ($dbtxParam) }}
{{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (pgconn.CommandTag, error)
{{- else if eq .Cmd ":execresult" }}
{{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error)
{{- end}}
{{- end}}
}

var _ Querier = (*Queries)(nil)
{{end}}
{{end}}
27 changes: 26 additions & 1 deletion internal/codegen/golang/templates/pgx/queryCode.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@ type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}}
{{if eq .Cmd ":one"}}
{{range .Comments}}//{{.}}
{{end -}}
{{- if $.EmitMethodsWithDBArgument}}
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) {
row := db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- else -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) {
row := q.db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
var {{.Ret.Name}} {{.Ret.Type}}
err := row.Scan({{.Ret.Scan}})
return {{.Ret.ReturnName}}, err
Expand All @@ -33,8 +38,13 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.De
{{if eq .Cmd ":many"}}
{{range .Comments}}//{{.}}
{{end -}}
{{- if $.EmitMethodsWithDBArgument}}
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) {
rows, err := db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- else}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) {
rows, err := q.db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
if err != nil {
return nil, err
}
Expand All @@ -61,17 +71,27 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.
{{if eq .Cmd ":exec"}}
{{range .Comments}}//{{.}}
{{end -}}
{{- if $.EmitMethodsWithDBArgument}}
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error {
_, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- else}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error {
_, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
return err
}
{{end}}

{{if eq .Cmd ":execrows"}}
{{range .Comments}}//{{.}}
{{end -}}
{{if $.EmitMethodsWithDBArgument}}
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) {
result, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- else -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) {
result, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
if err != nil {
return 0, err
}
Expand All @@ -82,11 +102,16 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, er
{{if eq .Cmd ":execresult"}}
{{range .Comments}}//{{.}}
{{end -}}
{{if $.EmitMethodsWithDBArgument}}
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (pgconn.CommandTag, error) {
return db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{else}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) {
return q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
}
{{end}}

{{end}}
{{end}}
{{end}}
{{end}}
11 changes: 10 additions & 1 deletion internal/codegen/golang/templates/stdlib/dbCode.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@ type DBTX interface {
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}

{{ if .EmitMethodsWithDBArgument}}
func New() *Queries {
return &Queries{}
{{- else -}}
func New(db DBTX) *Queries {
return &Queries{db: db}
{{- end}}
}

{{if .EmitPreparedQueries}}
Expand Down Expand Up @@ -72,7 +77,9 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar
{{end}}

type Queries struct {
{{- if not .EmitMethodsWithDBArgument}}
db DBTX
{{- end}}

{{- if .EmitPreparedQueries}}
tx *sql.Tx
Expand All @@ -82,6 +89,7 @@ type Queries struct {
{{- end}}
}

{{if not .EmitMethodsWithDBArgument}}
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
db: tx,
Expand All @@ -93,4 +101,5 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
{{- end}}
}
}
{{end}}
{{end}}
{{end}}
23 changes: 17 additions & 6 deletions internal/codegen/golang/templates/stdlib/interfaceCode.tmpl
Original file line number Diff line number Diff line change
@@ -1,23 +1,34 @@
{{define "interfaceCodeStd"}}
type Querier interface {
{{- $dbtxParam := .EmitMethodsWithDBArgument -}}
{{- range .GoQueries}}
{{- if eq .Cmd ":one"}}
{{- if and (eq .Cmd ":one") ($dbtxParam) }}
{{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error)
{{- else if eq .Cmd ":one"}}
{{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error)
{{- end}}
{{- if eq .Cmd ":many"}}
{{- if and (eq .Cmd ":many") ($dbtxParam) }}
{{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error)
{{- else if eq .Cmd ":many"}}
{{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error)
{{- end}}
{{- if eq .Cmd ":exec"}}
{{- if and (eq .Cmd ":exec") ($dbtxParam) }}
{{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error
{{- else if eq .Cmd ":exec"}}
{{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error
{{- end}}
{{- if eq .Cmd ":execrows"}}
{{- if and (eq .Cmd ":execrows") ($dbtxParam) }}
{{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error)
{{- else if eq .Cmd ":execrows"}}
{{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error)
{{- end}}
{{- if eq .Cmd ":execresult"}}
{{- if and (eq .Cmd ":execresult") ($dbtxParam) }}
{{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (sql.Result, error)
{{- else if eq .Cmd ":execresult"}}
{{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (sql.Result, error)
{{- end}}
{{- end}}
}

var _ Querier = (*Queries)(nil)
{{end}}
{{end}}
Loading

0 comments on commit 768ccb6

Please sign in to comment.