diff --git a/docs/reference/config.md b/docs/reference/config.md index 43d09363de..4907f5ebdd 100644 --- a/docs/reference/config.md +++ b/docs/reference/config.md @@ -19,6 +19,7 @@ packages: emit_json_tags: true emit_result_struct_pointers: false emit_params_struct_pointers: false + emit_methods_with_db_argument: false json_tags_case_style: "camel" output_db_file_name: "db.go" output_models_file_name: "models.go" @@ -57,6 +58,8 @@ Each package document has the following keys: - If true, query results are returned as pointers to structs. Queries returning multiple results are returned as slices of pointers. Defaults to `false`. - `emit_params_struct_pointers`: - If true, parameters are passed as pointers to structs. Defaults to `false`. +- `emit_methods_with_db_argument`: + - If true, generated methods will accept a DBTX argument instead of storing a DBTX on the `*Queries` struct. Defaults to `false`. - `json_tags_case_style`: - `camel` for camelCase, `pascal` for PascalCase, `snake` for snake_case or `none` to use the column name in the DB. Defaults to `none`. - `output_db_file_name`: diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index 0ed7a4745c..41e24d1dd9 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -99,6 +99,11 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer return nil, err } + if err := config.Validate(conf); err != nil { + fmt.Fprintf(stderr, "error validating %s: %s\n", base, err) + return nil, err + } + output := map[string]string{} errored := false diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 8c9c4ac882..cf5ef2f3e3 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -31,11 +31,12 @@ type tmplCtx struct { // TODO: Race conditions SourceName string - EmitJSONTags bool - EmitDBTags bool - EmitPreparedQueries bool - EmitInterface bool - EmitEmptySlices bool + EmitJSONTags bool + EmitDBTags bool + EmitPreparedQueries bool + EmitInterface bool + EmitEmptySlices bool + EmitMethodsWithDBArgument bool } func (t *tmplCtx) OutputQuery(sourceName string) bool { @@ -76,18 +77,19 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct, golang := settings.Go tctx := tmplCtx{ - Settings: settings.Global, - EmitInterface: golang.EmitInterface, - EmitJSONTags: golang.EmitJSONTags, - EmitDBTags: golang.EmitDBTags, - EmitPreparedQueries: golang.EmitPreparedQueries, - EmitEmptySlices: golang.EmitEmptySlices, - SQLPackage: SQLPackageFromString(golang.SQLPackage), - Q: "`", - Package: golang.Package, - GoQueries: queries, - Enums: enums, - Structs: structs, + Settings: settings.Global, + EmitInterface: golang.EmitInterface, + EmitJSONTags: golang.EmitJSONTags, + EmitDBTags: golang.EmitDBTags, + EmitPreparedQueries: golang.EmitPreparedQueries, + EmitEmptySlices: golang.EmitEmptySlices, + EmitMethodsWithDBArgument: golang.EmitMethodsWithDBArgument, + SQLPackage: SQLPackageFromString(golang.SQLPackage), + Q: "`", + Package: golang.Package, + GoQueries: queries, + Enums: enums, + Structs: structs, } output := map[string]string{} diff --git a/internal/codegen/golang/templates/pgx/dbCode.tmpl b/internal/codegen/golang/templates/pgx/dbCode.tmpl index 4334f54d6a..dbfde50d50 100644 --- a/internal/codegen/golang/templates/pgx/dbCode.tmpl +++ b/internal/codegen/golang/templates/pgx/dbCode.tmpl @@ -6,17 +6,26 @@ type DBTX interface { QueryRow(context.Context, string, ...interface{}) pgx.Row } +{{ if .EmitMethodsWithDBArgument}} +func New() *Queries { + return &Queries{} +{{- else -}} func New(db DBTX) *Queries { return &Queries{db: db} +{{- end}} } type Queries struct { + {{if not .EmitMethodsWithDBArgument}} db DBTX + {{end}} } +{{if not .EmitMethodsWithDBArgument}} func (q *Queries) WithTx(tx pgx.Tx) *Queries { return &Queries{ db: tx, } } -{{end}} \ No newline at end of file +{{end}} +{{end}} diff --git a/internal/codegen/golang/templates/pgx/interfaceCode.tmpl b/internal/codegen/golang/templates/pgx/interfaceCode.tmpl index a7e48974d3..d8ab4d7bb9 100644 --- a/internal/codegen/golang/templates/pgx/interfaceCode.tmpl +++ b/internal/codegen/golang/templates/pgx/interfaceCode.tmpl @@ -1,23 +1,34 @@ {{define "interfaceCodePgx"}} type Querier interface { + {{- $dbtxParam := .EmitMethodsWithDBArgument -}} {{- range .GoQueries}} - {{- if eq .Cmd ":one"}} + {{- if and (eq .Cmd ":one") ($dbtxParam) }} + {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) + {{- else if eq .Cmd ":one" }} {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) {{- end}} - {{- if eq .Cmd ":many"}} + {{- if and (eq .Cmd ":many") ($dbtxParam) }} + {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) + {{- else if eq .Cmd ":many" }} {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) {{- end}} - {{- if eq .Cmd ":exec"}} + {{- if and (eq .Cmd ":exec") ($dbtxParam) }} + {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error + {{- else if eq .Cmd ":exec" }} {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error {{- end}} - {{- if eq .Cmd ":execrows"}} + {{- if and (eq .Cmd ":execrows") ($dbtxParam) }} + {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) + {{- else if eq .Cmd ":execrows" }} {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) {{- end}} - {{- if eq .Cmd ":execresult"}} + {{- if and (eq .Cmd ":execresult") ($dbtxParam) }} + {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (pgconn.CommandTag, error) + {{- else if eq .Cmd ":execresult" }} {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) {{- end}} {{- end}} } var _ Querier = (*Queries)(nil) -{{end}} \ No newline at end of file +{{end}} diff --git a/internal/codegen/golang/templates/pgx/queryCode.tmpl b/internal/codegen/golang/templates/pgx/queryCode.tmpl index f5c9e3ed58..21575b3e70 100644 --- a/internal/codegen/golang/templates/pgx/queryCode.tmpl +++ b/internal/codegen/golang/templates/pgx/queryCode.tmpl @@ -22,8 +22,13 @@ type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} {{if eq .Cmd ":one"}} {{range .Comments}}//{{.}} {{end -}} +{{- if $.EmitMethodsWithDBArgument}} +func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { + row := db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) +{{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { row := q.db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) +{{- end}} var {{.Ret.Name}} {{.Ret.Type}} err := row.Scan({{.Ret.Scan}}) return {{.Ret.ReturnName}}, err @@ -33,8 +38,13 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.De {{if eq .Cmd ":many"}} {{range .Comments}}//{{.}} {{end -}} +{{- if $.EmitMethodsWithDBArgument}} +func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { + rows, err := db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) +{{- else}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { rows, err := q.db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) +{{- end}} if err != nil { return nil, err } @@ -61,8 +71,13 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret. {{if eq .Cmd ":exec"}} {{range .Comments}}//{{.}} {{end -}} +{{- if $.EmitMethodsWithDBArgument}} +func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error { + _, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) +{{- else}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error { _, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) +{{- end}} return err } {{end}} @@ -70,8 +85,13 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error { {{if eq .Cmd ":execrows"}} {{range .Comments}}//{{.}} {{end -}} +{{if $.EmitMethodsWithDBArgument}} +func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) { + result, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) +{{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) { result, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) +{{- end}} if err != nil { return 0, err } @@ -82,11 +102,16 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, er {{if eq .Cmd ":execresult"}} {{range .Comments}}//{{.}} {{end -}} +{{if $.EmitMethodsWithDBArgument}} +func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (pgconn.CommandTag, error) { + return db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) +{{else}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) { return q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) +{{- end}} } {{end}} {{end}} {{end}} -{{end}} \ No newline at end of file +{{end}} diff --git a/internal/codegen/golang/templates/stdlib/dbCode.tmpl b/internal/codegen/golang/templates/stdlib/dbCode.tmpl index 08710a62d9..7433d522f6 100644 --- a/internal/codegen/golang/templates/stdlib/dbCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/dbCode.tmpl @@ -6,8 +6,13 @@ type DBTX interface { QueryRowContext(context.Context, string, ...interface{}) *sql.Row } +{{ if .EmitMethodsWithDBArgument}} +func New() *Queries { + return &Queries{} +{{- else -}} func New(db DBTX) *Queries { return &Queries{db: db} +{{- end}} } {{if .EmitPreparedQueries}} @@ -72,7 +77,9 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar {{end}} type Queries struct { + {{- if not .EmitMethodsWithDBArgument}} db DBTX + {{- end}} {{- if .EmitPreparedQueries}} tx *sql.Tx @@ -82,6 +89,7 @@ type Queries struct { {{- end}} } +{{if not .EmitMethodsWithDBArgument}} func (q *Queries) WithTx(tx *sql.Tx) *Queries { return &Queries{ db: tx, @@ -93,4 +101,5 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { {{- end}} } } -{{end}} \ No newline at end of file +{{end}} +{{end}} diff --git a/internal/codegen/golang/templates/stdlib/interfaceCode.tmpl b/internal/codegen/golang/templates/stdlib/interfaceCode.tmpl index 5705c69fd1..67c31b817a 100644 --- a/internal/codegen/golang/templates/stdlib/interfaceCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/interfaceCode.tmpl @@ -1,23 +1,34 @@ {{define "interfaceCodeStd"}} type Querier interface { + {{- $dbtxParam := .EmitMethodsWithDBArgument -}} {{- range .GoQueries}} - {{- if eq .Cmd ":one"}} + {{- if and (eq .Cmd ":one") ($dbtxParam) }} + {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) + {{- else if eq .Cmd ":one"}} {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) {{- end}} - {{- if eq .Cmd ":many"}} + {{- if and (eq .Cmd ":many") ($dbtxParam) }} + {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) + {{- else if eq .Cmd ":many"}} {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) {{- end}} - {{- if eq .Cmd ":exec"}} + {{- if and (eq .Cmd ":exec") ($dbtxParam) }} + {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error + {{- else if eq .Cmd ":exec"}} {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error {{- end}} - {{- if eq .Cmd ":execrows"}} + {{- if and (eq .Cmd ":execrows") ($dbtxParam) }} + {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) + {{- else if eq .Cmd ":execrows"}} {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) {{- end}} - {{- if eq .Cmd ":execresult"}} + {{- if and (eq .Cmd ":execresult") ($dbtxParam) }} + {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (sql.Result, error) + {{- else if eq .Cmd ":execresult"}} {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (sql.Result, error) {{- end}} {{- end}} } var _ Querier = (*Queries)(nil) -{{end}} \ No newline at end of file +{{end}} diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl index 1b06f4e1a8..9a23ed3fba 100644 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/queryCode.tmpl @@ -22,10 +22,16 @@ type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} {{if eq .Cmd ":one"}} {{range .Comments}}//{{.}} {{end -}} +{{- if $.EmitMethodsWithDBArgument -}} +func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { +{{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { +{{- end -}} {{- if $.EmitPreparedQueries}} row := q.queryRow(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}}) - {{- else}} + {{- else if $.EmitMethodsWithDBArgument -}} + row := db.QueryRowContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- else -}} row := q.db.QueryRowContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) {{- end}} var {{.Ret.Name}} {{.Ret.Type}} @@ -37,10 +43,16 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.De {{if eq .Cmd ":many"}} {{range .Comments}}//{{.}} {{end -}} +{{- if $.EmitMethodsWithDBArgument -}} +func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { +{{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { +{{- end -}} {{- if $.EmitPreparedQueries}} rows, err := q.query(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}}) - {{- else}} + {{- else if $.EmitMethodsWithDBArgument -}} + rows, err := db.QueryContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- else -}} rows, err := q.db.QueryContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) {{- end}} if err != nil { @@ -72,10 +84,16 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret. {{if eq .Cmd ":exec"}} {{range .Comments}}//{{.}} {{end -}} +{{- if $.EmitMethodsWithDBArgument -}} +func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error { +{{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error { +{{- end -}} {{- if $.EmitPreparedQueries}} _, err := q.exec(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}}) - {{- else}} + {{- else if $.EmitMethodsWithDBArgument -}} + _, err := db.ExecContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- else -}} _, err := q.db.ExecContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) {{- end}} return err @@ -85,10 +103,16 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error { {{if eq .Cmd ":execrows"}} {{range .Comments}}//{{.}} {{end -}} +{{- if $.EmitMethodsWithDBArgument -}} +func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) { +{{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) { +{{end -}} {{- if $.EmitPreparedQueries}} result, err := q.exec(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}}) - {{- else}} + {{- else if $.EmitMethodsWithDBArgument -}} + result, err := db.ExecContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- else -}} result, err := q.db.ExecContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) {{- end}} if err != nil { @@ -101,10 +125,16 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, er {{if eq .Cmd ":execresult"}} {{range .Comments}}//{{.}} {{end -}} +{{- if $.EmitMethodsWithDBArgument -}} +func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (sql.Result, error) { +{{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (sql.Result, error) { +{{- end -}} {{- if $.EmitPreparedQueries}} return q.exec(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}}) - {{- else}} + {{- else if $.EmitMethodsWithDBArgument -}} + return db.ExecContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- else -}} return q.db.ExecContext(ctx, {{.ConstantName}}, {{.Arg.Params}}) {{- end}} } @@ -112,4 +142,4 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (sql.Resul {{end}} {{end}} -{{end}} \ No newline at end of file +{{end}} diff --git a/internal/config/config.go b/internal/config/config.go index c7b1fbbe60..3ef4a26300 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -110,25 +110,26 @@ type SQLGen struct { } type SQLGo struct { - EmitInterface bool `json:"emit_interface" yaml:"emit_interface"` - EmitJSONTags bool `json:"emit_json_tags" yaml:"emit_json_tags"` - EmitDBTags bool `json:"emit_db_tags" yaml:"emit_db_tags"` - EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries"` - EmitExactTableNames bool `json:"emit_exact_table_names,omitempty" yaml:"emit_exact_table_names"` - EmitEmptySlices bool `json:"emit_empty_slices,omitempty" yaml:"emit_empty_slices"` - EmitExportedQueries bool `json:"emit_exported_queries" yaml:"emit_exported_queries"` - EmitResultStructPointers bool `json:"emit_result_struct_pointers" yaml:"emit_result_struct_pointers"` - EmitParamsStructPointers bool `json:"emit_params_struct_pointers" yaml:"emit_params_struct_pointers"` - JSONTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"` - Package string `json:"package" yaml:"package"` - Out string `json:"out" yaml:"out"` - Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` - Rename map[string]string `json:"rename,omitempty" yaml:"rename"` - SQLPackage string `json:"sql_package" yaml:"sql_package"` - OutputDBFileName string `json:"output_db_file_name,omitempty" yaml:"output_db_file_name"` - OutputModelsFileName string `json:"output_models_file_name,omitempty" yaml:"output_models_file_name"` - OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"` - OutputFilesSuffix string `json:"output_files_suffix,omitempty" yaml:"output_files_suffix"` + EmitInterface bool `json:"emit_interface" yaml:"emit_interface"` + EmitJSONTags bool `json:"emit_json_tags" yaml:"emit_json_tags"` + EmitDBTags bool `json:"emit_db_tags" yaml:"emit_db_tags"` + EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries"` + EmitExactTableNames bool `json:"emit_exact_table_names,omitempty" yaml:"emit_exact_table_names"` + EmitEmptySlices bool `json:"emit_empty_slices,omitempty" yaml:"emit_empty_slices"` + EmitExportedQueries bool `json:"emit_exported_queries" yaml:"emit_exported_queries"` + EmitResultStructPointers bool `json:"emit_result_struct_pointers" yaml:"emit_result_struct_pointers"` + EmitParamsStructPointers bool `json:"emit_params_struct_pointers" yaml:"emit_params_struct_pointers"` + EmitMethodsWithDBArgument bool `json:"emit_methods_with_db_argument,omitempty" yaml:"emit_methods_with_db_argument"` + JSONTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"` + Package string `json:"package" yaml:"package"` + Out string `json:"out" yaml:"out"` + Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` + Rename map[string]string `json:"rename,omitempty" yaml:"rename"` + SQLPackage string `json:"sql_package" yaml:"sql_package"` + OutputDBFileName string `json:"output_db_file_name,omitempty" yaml:"output_db_file_name"` + OutputModelsFileName string `json:"output_models_file_name,omitempty" yaml:"output_models_file_name"` + OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"` + OutputFilesSuffix string `json:"output_files_suffix,omitempty" yaml:"output_files_suffix"` } type SQLKotlin struct { @@ -324,6 +325,19 @@ func ParseConfig(rd io.Reader) (Config, error) { } } +func Validate(c Config) error { + for _, sql := range c.SQL { + sqlGo := sql.Gen.Go + if sqlGo == nil { + continue + } + if sqlGo.EmitMethodsWithDBArgument && sqlGo.EmitPreparedQueries { + return fmt.Errorf("invalid config: emit_methods_with_db_argument and emit_prepared_queries settings are mutually exclusive") + } + } + return nil +} + type CombinedSettings struct { Global Config Package SQL diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 3249d93398..9250629c2e 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -64,6 +64,21 @@ func TestBadConfigs(t *testing.T) { } } +func TestInvalidConfig(t *testing.T) { + err := Validate(Config{ + SQL: []SQL{{ + Gen: SQLGen{ + Go: &SQLGo{ + EmitMethodsWithDBArgument: true, + EmitPreparedQueries: true, + }, + }, + }}}) + if err == nil { + t.Errorf("expected err; got nil") + } +} + func TestTypeOverrides(t *testing.T) { for _, test := range []struct { override Override diff --git a/internal/config/v_one.go b/internal/config/v_one.go index 8cb86c68c8..67701d4213 100644 --- a/internal/config/v_one.go +++ b/internal/config/v_one.go @@ -16,27 +16,28 @@ type V1GenerateSettings struct { } type v1PackageSettings struct { - Name string `json:"name" yaml:"name"` - Engine Engine `json:"engine,omitempty" yaml:"engine"` - Path string `json:"path" yaml:"path"` - Schema Paths `json:"schema" yaml:"schema"` - Queries Paths `json:"queries" yaml:"queries"` - EmitInterface bool `json:"emit_interface" yaml:"emit_interface"` - EmitJSONTags bool `json:"emit_json_tags" yaml:"emit_json_tags"` - EmitDBTags bool `json:"emit_db_tags" yaml:"emit_db_tags"` - EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries"` - EmitExactTableNames bool `json:"emit_exact_table_names,omitempty" yaml:"emit_exact_table_names"` - EmitEmptySlices bool `json:"emit_empty_slices,omitempty" yaml:"emit_empty_slices"` - EmitExportedQueries bool `json:"emit_exported_queries,omitempty" yaml:"emit_exported_queries"` - EmitResultStructPointers bool `json:"emit_result_struct_pointers" yaml:"emit_result_struct_pointers"` - EmitParamsStructPointers bool `json:"emit_params_struct_pointers" yaml:"emit_params_struct_pointers"` - JSONTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"` - SQLPackage string `json:"sql_package" yaml:"sql_package"` - Overrides []Override `json:"overrides" yaml:"overrides"` - OutputDBFileName string `json:"output_db_file_name,omitempty" yaml:"output_db_file_name"` - OutputModelsFileName string `json:"output_models_file_name,omitempty" yaml:"output_models_file_name"` - OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"` - OutputFilesSuffix string `json:"output_files_suffix,omitempty" yaml:"output_files_suffix"` + Name string `json:"name" yaml:"name"` + Engine Engine `json:"engine,omitempty" yaml:"engine"` + Path string `json:"path" yaml:"path"` + Schema Paths `json:"schema" yaml:"schema"` + Queries Paths `json:"queries" yaml:"queries"` + EmitInterface bool `json:"emit_interface" yaml:"emit_interface"` + EmitJSONTags bool `json:"emit_json_tags" yaml:"emit_json_tags"` + EmitDBTags bool `json:"emit_db_tags" yaml:"emit_db_tags"` + EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries"` + EmitExactTableNames bool `json:"emit_exact_table_names,omitempty" yaml:"emit_exact_table_names"` + EmitEmptySlices bool `json:"emit_empty_slices,omitempty" yaml:"emit_empty_slices"` + EmitExportedQueries bool `json:"emit_exported_queries,omitempty" yaml:"emit_exported_queries"` + EmitResultStructPointers bool `json:"emit_result_struct_pointers" yaml:"emit_result_struct_pointers"` + EmitParamsStructPointers bool `json:"emit_params_struct_pointers" yaml:"emit_params_struct_pointers"` + EmitMethodsWithDBArgument bool `json:"emit_methods_with_db_argument" yaml:"emit_methods_with_db_argument"` + JSONTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"` + SQLPackage string `json:"sql_package" yaml:"sql_package"` + Overrides []Override `json:"overrides" yaml:"overrides"` + OutputDBFileName string `json:"output_db_file_name,omitempty" yaml:"output_db_file_name"` + OutputModelsFileName string `json:"output_models_file_name,omitempty" yaml:"output_models_file_name"` + OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"` + OutputFilesSuffix string `json:"output_files_suffix,omitempty" yaml:"output_files_suffix"` } func v1ParseConfig(rd io.Reader) (Config, error) { @@ -112,24 +113,25 @@ func (c *V1GenerateSettings) Translate() Config { Queries: pkg.Queries, Gen: SQLGen{ Go: &SQLGo{ - EmitInterface: pkg.EmitInterface, - EmitJSONTags: pkg.EmitJSONTags, - EmitDBTags: pkg.EmitDBTags, - EmitPreparedQueries: pkg.EmitPreparedQueries, - EmitExactTableNames: pkg.EmitExactTableNames, - EmitEmptySlices: pkg.EmitEmptySlices, - EmitExportedQueries: pkg.EmitExportedQueries, - EmitResultStructPointers: pkg.EmitResultStructPointers, - EmitParamsStructPointers: pkg.EmitParamsStructPointers, - Package: pkg.Name, - Out: pkg.Path, - SQLPackage: pkg.SQLPackage, - Overrides: pkg.Overrides, - JSONTagsCaseStyle: pkg.JSONTagsCaseStyle, - OutputDBFileName: pkg.OutputDBFileName, - OutputModelsFileName: pkg.OutputModelsFileName, - OutputQuerierFileName: pkg.OutputQuerierFileName, - OutputFilesSuffix: pkg.OutputFilesSuffix, + EmitInterface: pkg.EmitInterface, + EmitJSONTags: pkg.EmitJSONTags, + EmitDBTags: pkg.EmitDBTags, + EmitPreparedQueries: pkg.EmitPreparedQueries, + EmitExactTableNames: pkg.EmitExactTableNames, + EmitEmptySlices: pkg.EmitEmptySlices, + EmitExportedQueries: pkg.EmitExportedQueries, + EmitResultStructPointers: pkg.EmitResultStructPointers, + EmitParamsStructPointers: pkg.EmitParamsStructPointers, + EmitMethodsWithDBArgument: pkg.EmitMethodsWithDBArgument, + Package: pkg.Name, + Out: pkg.Path, + SQLPackage: pkg.SQLPackage, + Overrides: pkg.Overrides, + JSONTagsCaseStyle: pkg.JSONTagsCaseStyle, + OutputDBFileName: pkg.OutputDBFileName, + OutputModelsFileName: pkg.OutputModelsFileName, + OutputQuerierFileName: pkg.OutputQuerierFileName, + OutputFilesSuffix: pkg.OutputFilesSuffix, }, }, }) diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/go/db.go b/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/go/db.go new file mode 100644 index 0000000000..151985bc5b --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/go/db.go @@ -0,0 +1,22 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New() *Queries { + return &Queries{} +} + +type Queries struct { +} diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/go/models.go b/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/go/models.go new file mode 100644 index 0000000000..4c02540e98 --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/go/models.go @@ -0,0 +1,14 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" +) + +type User struct { + ID int32 + FirstName string + LastName sql.NullString + Age int32 +} diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/go/query.sql.go b/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/go/query.sql.go new file mode 100644 index 0000000000..dedcf2bb08 --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/go/query.sql.go @@ -0,0 +1,40 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" +) + +const getAll = `-- name: GetAll :many +SELECT id, first_name, last_name, age FROM users +` + +func (q *Queries) GetAll(ctx context.Context, db DBTX) ([]User, error) { + rows, err := db.QueryContext(ctx, getAll) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan( + &i.ID, + &i.FirstName, + &i.LastName, + &i.Age, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/query.sql b/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/query.sql new file mode 100644 index 0000000000..e2f85e2a9a --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/query.sql @@ -0,0 +1,2 @@ +/* name: GetAll :many */ +SELECT * FROM users; diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/schema.sql b/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/schema.sql new file mode 100644 index 0000000000..3e36d6cdf7 --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/schema.sql @@ -0,0 +1,6 @@ +CREATE TABLE users ( + id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + first_name varchar(255) NOT NULL, + last_name varchar(255), + age integer NOT NULL +); diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/sqlc.json b/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/sqlc.json new file mode 100644 index 0000000000..f27ded3934 --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/mysql/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "schema.sql", + "queries": "query.sql", + "engine": "mysql", + "emit_methods_with_db_argument": true + } + ] +} diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/go/db.go b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/go/db.go new file mode 100644 index 0000000000..ae4ac1974d --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/go/db.go @@ -0,0 +1,23 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New() *Queries { + return &Queries{} +} + +type Queries struct { +} diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/go/models.go b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/go/models.go new file mode 100644 index 0000000000..4c02540e98 --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/go/models.go @@ -0,0 +1,14 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" +) + +type User struct { + ID int32 + FirstName string + LastName sql.NullString + Age int32 +} diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/go/query.sql.go new file mode 100644 index 0000000000..978ca7f7c6 --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/go/query.sql.go @@ -0,0 +1,37 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" +) + +const getAll = `-- name: GetAll :many +SELECT id, first_name, last_name, age FROM users +` + +func (q *Queries) GetAll(ctx context.Context, db DBTX) ([]User, error) { + rows, err := db.Query(ctx, getAll) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan( + &i.ID, + &i.FirstName, + &i.LastName, + &i.Age, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/query.sql b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/query.sql new file mode 100644 index 0000000000..237b20193b --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/query.sql @@ -0,0 +1,2 @@ +-- name: GetAll :many +SELECT * FROM users; diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/schema.sql b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/schema.sql new file mode 100644 index 0000000000..ae8e46e25e --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/schema.sql @@ -0,0 +1,6 @@ +CREATE TABLE users ( + id integer NOT NULL GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + first_name varchar(255) NOT NULL, + last_name varchar(255), + age integer NOT NULL +); diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/sqlc.json b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/sqlc.json new file mode 100644 index 0000000000..069e8bc5bc --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/pgx/sqlc.json @@ -0,0 +1,14 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v4", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql", + "emit_methods_with_db_argument": true + } + ] +} diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/go/db.go new file mode 100644 index 0000000000..151985bc5b --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/go/db.go @@ -0,0 +1,22 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New() *Queries { + return &Queries{} +} + +type Queries struct { +} diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/go/models.go new file mode 100644 index 0000000000..4c02540e98 --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/go/models.go @@ -0,0 +1,14 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" +) + +type User struct { + ID int32 + FirstName string + LastName sql.NullString + Age int32 +} diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/go/query.sql.go new file mode 100644 index 0000000000..dedcf2bb08 --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/go/query.sql.go @@ -0,0 +1,40 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" +) + +const getAll = `-- name: GetAll :many +SELECT id, first_name, last_name, age FROM users +` + +func (q *Queries) GetAll(ctx context.Context, db DBTX) ([]User, error) { + rows, err := db.QueryContext(ctx, getAll) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan( + &i.ID, + &i.FirstName, + &i.LastName, + &i.Age, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/query.sql b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/query.sql new file mode 100644 index 0000000000..237b20193b --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/query.sql @@ -0,0 +1,2 @@ +-- name: GetAll :many +SELECT * FROM users; diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/schema.sql b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/schema.sql new file mode 100644 index 0000000000..ae8e46e25e --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/schema.sql @@ -0,0 +1,6 @@ +CREATE TABLE users ( + id integer NOT NULL GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + first_name varchar(255) NOT NULL, + last_name varchar(255), + age integer NOT NULL +); diff --git a/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/sqlc.json b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/sqlc.json new file mode 100644 index 0000000000..6acde89538 --- /dev/null +++ b/internal/endtoend/testdata/emit_methods_with_db_argument/postgresql/stdlib/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql", + "emit_methods_with_db_argument": true + } + ] +}